Added country specific prompts and more async to speed up

This commit is contained in:
AlxAI 2025-08-02 14:48:03 -04:00
parent 9fc25f2fec
commit 3b5f3015c1
7 changed files with 225 additions and 81 deletions

1
.gitignore vendored
View file

@ -165,3 +165,4 @@ analysis_summary_debug.txt
./results_alpha
/results_alpha/20250607_222757
/ai_diplomacy/prompts/famous_leaders_prompts

View file

@ -6,6 +6,7 @@ import re
import json_repair
import json5 # More forgiving JSON parser
import ast
import asyncio
from config import config

View file

@ -22,7 +22,7 @@ from together.error import APIError as TogetherAPIError # For specific error ha
from config import config
from .game_history import GameHistory
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
from .utils import load_prompt, run_llm_and_log, log_llm_response, log_llm_response_async, generate_random_seed, get_prompt_path
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
@ -52,6 +52,7 @@ class BaseModelClient:
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
self.model_name = model_name
self.prompts_dir = prompts_dir
logger.info(f"[{model_name}] BaseModelClient initialized with prompts_dir: {prompts_dir}")
# Load a default initially, can be overwritten by set_system_prompt
self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir)
self.max_tokens = 16000 # default unless overridden
@ -180,7 +181,7 @@ class BaseModelClient:
finally:
# Log the attempt regardless of outcome
if log_file_path: # Only log if a path is provided
log_llm_response(
await log_llm_response_async(
log_file_path=log_file_path,
model_name=self.model_name,
power_name=power_name,
@ -441,7 +442,18 @@ class BaseModelClient:
agent_private_diary_str: Optional[str] = None, # Added
) -> str:
# MINIMAL CHANGE: Just change to load unformatted version conditionally
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
# Check if country-specific prompts are enabled
if config.COUNTRY_SPECIFIC_PROMPTS:
# Try to load country-specific version first
country_specific_file = get_prompt_path(f"conversation_instructions_{power_name.lower()}.txt")
instructions = load_prompt(country_specific_file, prompts_dir=self.prompts_dir)
# Fall back to generic if country-specific not found
if not instructions:
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
else:
# Load generic conversation instructions
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
# KEEP ORIGINAL: Use build_context_prompt as before
context = build_context_prompt(
@ -670,7 +682,7 @@ class BaseModelClient:
messages_to_return = [] # Ensure empty list on general exception
finally:
if log_file_path:
log_llm_response(
await log_llm_response_async(
log_file_path=log_file_path,
model_name=self.model_name,
power_name=power_name,
@ -749,7 +761,7 @@ class BaseModelClient:
plan_to_return = f"Error: Failed to generate plan for {power_name} due to exception: {e}"
finally:
if log_file_path: # Only log if a path is provided
log_llm_response(
await log_llm_response_async(
log_file_path=log_file_path,
model_name=self.model_name,
power_name=power_name,
@ -797,27 +809,34 @@ class OpenAIClient(BaseModelClient):
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
# Determine which parameter to use based on model
completion_params = {
"model": self.model_name,
"messages": [
{"role": "system", "content": system_prompt_content},
{"role": "user", "content": prompt_with_cta},
],
}
# Handle model-specific parameters
# Check if model name starts with 'nectarine' or is in the specific list
uses_max_completion_tokens = (
self.model_name in ["o4-mini", "o3-mini", "o3", "gpt-4.1"] or
self.model_name.startswith("nectarine")
)
if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
response = await self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": system_prompt_content},
{"role": "user", "content": prompt_with_cta},
],
max_completion_tokens=self.max_tokens,
)
if uses_max_completion_tokens:
completion_params["max_completion_tokens"] = self.max_tokens
# o4-mini, o3-mini, o3 only support default temperature of 1.0
if self.model_name in ["o4-mini", "o3-mini", "o3"]:
completion_params["temperature"] = 1.0
else:
completion_params["temperature"] = temperature
else:
response = await self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": system_prompt_content},
{"role": "user", "content": prompt_with_cta},
],
temperature=temperature,
max_tokens=self.max_tokens,
)
completion_params["max_tokens"] = self.max_tokens
completion_params["temperature"] = temperature
response = await self.client.chat.completions.create(**completion_params)
if (
not response
@ -971,16 +990,24 @@ class DeepSeekClient(BaseModelClient):
random_seed = generate_random_seed()
system_prompt_content = f"{random_seed}\n\n{self.system_prompt}"
response = await self.client.chat.completions.create(
model=self.model_name,
messages=[
# Determine which parameter to use based on model
completion_params = {
"model": self.model_name,
"messages": [
{"role": "system", "content": system_prompt_content},
{"role": "user", "content": prompt_with_cta},
],
stream=False,
temperature=temperature,
max_tokens=self.max_tokens,
)
"stream": False,
"temperature": temperature,
}
# Use max_completion_tokens for o4-mini, o3-mini models and nectarine models
if self.model_name in ["o4-mini", "o3-mini"] or self.model_name.startswith("nectarine"):
completion_params["max_completion_tokens"] = self.max_tokens
else:
completion_params["max_tokens"] = self.max_tokens
response = await self.client.chat.completions.create(**completion_params)
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
@ -1023,7 +1050,7 @@ class OpenAIResponsesClient(BaseModelClient):
This client makes direct HTTP requests to the v1/responses endpoint.
"""
def __init__(self, model_name: str, prompts_dir: Optional[str] = None, api_key: Optional[str] = None):
def __init__(self, model_name: str, prompts_dir: Optional[str] = None, api_key: Optional[str] = None, reasoning_effort: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir)
if api_key:
self.api_key = api_key
@ -1032,7 +1059,20 @@ class OpenAIResponsesClient(BaseModelClient):
if not self.api_key:
raise ValueError("OPENAI_API_KEY environment variable is required")
self.base_url = "https://api.openai.com/v1/responses"
logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client")
self._session = None # Lazy initialization for connection pooling
self.reasoning_effort = reasoning_effort # For models that support reasoning effort
logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client with reasoning_effort={reasoning_effort}")
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create the aiohttp session for connection pooling."""
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
return self._session
async def close(self):
"""Close the aiohttp session."""
if self._session and not self._session.closed:
await self._session.close()
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
try:
@ -1049,51 +1089,59 @@ class OpenAIResponsesClient(BaseModelClient):
payload = {
"model": self.model_name,
"input": full_prompt,
"temperature": temperature,
"max_tokens": self.max_tokens,
}
if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
del payload["temperature"]
del payload["max_tokens"]
payload["max_completion_tokens"] = self.max_tokens
# The Responses API uses max_output_tokens for all models
payload["max_output_tokens"] = self.max_tokens
# Only add temperature for models that support it
models_without_temp = ['o3', 'o4-mini', 'gpt-5-reasoning-alpha-2025-07-19', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
if self.model_name not in models_without_temp:
payload["temperature"] = temperature
# Add reasoning effort for models that support it
reasoning_models = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'o4-mini-alpha-2025-07-11', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
if self.reasoning_effort and self.model_name in reasoning_models:
payload["reasoning"] = {"effort": self.reasoning_effort}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
# Make the API call using aiohttp
async with aiohttp.ClientSession() as session:
async with session.post(self.base_url, json=payload, headers=headers) as response:
response.raise_for_status() # Will raise for non-2xx responses
response_data = await response.json()
# Make the API call using the pooled session
session = await self._get_session()
async with session.post(self.base_url, json=payload, headers=headers) as response:
response.raise_for_status() # Will raise for non-2xx responses
response_data = await response.json()
# Extract the text from the nested response structure
try:
outputs = response_data.get("output", [])
if len(outputs) < 2:
raise ValueError(f"[{self.model_name}] Unexpected output structure: 'output' list has < 2 items.")
# Extract the text from the nested response structure
try:
outputs = response_data.get("output", [])
if len(outputs) < 2:
# Log the actual response for debugging
logger.error(f"[{self.model_name}] Response structure: {json.dumps(response_data, indent=2)}")
raise ValueError(f"[{self.model_name}] Unexpected output structure: 'output' list has < 2 items.")
message_output = outputs[1]
if message_output.get("type") != "message":
raise ValueError(f"[{self.model_name}] Expected 'message' type in output[1], got '{message_output.get('type')}'.")
message_output = outputs[1]
if message_output.get("type") != "message":
raise ValueError(f"[{self.model_name}] Expected 'message' type in output[1], got '{message_output.get('type')}'.")
content_list = message_output.get("content", [])
if not content_list:
raise ValueError(f"[{self.model_name}] Empty 'content' list in message output.")
content_list = message_output.get("content", [])
if not content_list:
raise ValueError(f"[{self.model_name}] Empty 'content' list in message output.")
text_content = ""
for content_item in content_list:
if content_item.get("type") == "output_text":
text_content = content_item.get("text", "")
break
text_content = ""
for content_item in content_list:
if content_item.get("type") == "output_text":
text_content = content_item.get("text", "")
break
if not text_content:
raise ValueError(f"[{self.model_name}] No 'output_text' found in content or it was empty.")
if not text_content:
raise ValueError(f"[{self.model_name}] No 'output_text' found in content or it was empty.")
return text_content.strip()
return text_content.strip()
except (KeyError, IndexError, TypeError) as e:
# Wrap parsing error in a more informative exception
raise ValueError(f"[{self.model_name}] Error parsing response structure: {e}") from e
except (KeyError, IndexError, TypeError) as e:
# Wrap parsing error in a more informative exception
raise ValueError(f"[{self.model_name}] Error parsing response structure: {e}") from e
except aiohttp.ClientError as e:
logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
@ -1302,9 +1350,14 @@ class RequestsOpenAIClient(BaseModelClient):
{"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
],
"temperature": temperature,
"max_tokens": self.max_tokens,
}
# Use max_completion_tokens for o4-mini, o3-mini, o3, gpt-4.1 models and nectarine models
if self.model_name in ["o4-mini", "o3-mini", "o3", "gpt-4.1"] or self.model_name.startswith("nectarine"):
payload["max_completion_tokens"] = self.max_tokens
else:
payload["max_tokens"] = self.max_tokens
#if self.model_name == "qwen/qwen3-235b-a22b" and self.base_url == "https://openrouter.ai/api/v1":
# payload["provider"] = {
# "order": ["Cerebras"], # fast qwen-2-35B
@ -1313,7 +1366,8 @@ class RequestsOpenAIClient(BaseModelClient):
if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
del payload["temperature"]
del payload["max_tokens"]
if "max_tokens" in payload:
del payload["max_tokens"]
payload["max_completion_tokens"] = self.max_tokens
loop = asyncio.get_running_loop()
@ -1381,13 +1435,33 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
gpt-4o
anthropic:claude-3.7-sonnet
openai:llama-3-2-3b@https://localhost:8000#myapikey
gpt-5-reasoning-alpha-2025-07-19:minimal
and returns the appropriate client.
If a prefix is omitted the function falls back to the original
heuristic mapping exactly as before.
If an inline API-key (#…’) is present it overrides environment vars.
If an inline API-key ('#…') is present it overrides environment vars.
For reasoning models, effort can be specified with :minimal, :medium, or :high
"""
spec = _parse_model_spec(model_id)
# Extract reasoning effort if present (before general parsing)
reasoning_effort = None
actual_model_id = model_id
# Check if this is a reasoning model with effort specified
reasoning_models = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
for model in reasoning_models:
if model_id.startswith(model + ':'):
parts = model_id.split(':', 1)
effort_part = parts[1]
# Check if the effort part is valid before treating it as effort
# (it could be a prefix like "openai:")
if effort_part.lower() in ['minimal', 'medium', 'high']:
actual_model_id = parts[0]
reasoning_effort = effort_part.lower()
break
spec = _parse_model_spec(actual_model_id)
logger.info(f"[load_model_client] Loading client for model_id='{model_id}', parsed spec: prefix={spec.prefix}, model={spec.model}, reasoning_effort={reasoning_effort}")
# Inline key overrides env; otherwise fall back as usual *per client*
inline_key = spec.key
@ -1421,7 +1495,7 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
api_key=inline_key,
)
case Prefix.OPENAI_RESPONSES:
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key, reasoning_effort=reasoning_effort)
case Prefix.ANTHROPIC:
return ClaudeClient(spec.model, prompts_dir)
case Prefix.GEMINI:
@ -1437,27 +1511,41 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
# 2. Heuristic fallback path (identical to the original behaviour) #
# ------------------------------------------------------------------ #
lower_id = spec.model.lower()
logger.info(f"[load_model_client] Heuristic path: checking model='{spec.model}', lower_id='{lower_id}'")
# Check if this is a reasoning model that should use Responses API
reasoning_models_requiring_responses = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
if spec.model in reasoning_models_requiring_responses:
logger.info(f"[load_model_client] Selected OpenAIResponsesClient for reasoning model '{spec.model}'")
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key, reasoning_effort=reasoning_effort)
if lower_id == "o3-pro":
logger.info(f"[load_model_client] Selected OpenAIResponsesClient for '{spec.model}'")
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
if spec.model.startswith("together-"):
# e.g. "together-mixtral-8x7b"
logger.info(f"[load_model_client] Selected TogetherAIClient for '{spec.model}'")
return TogetherAIClient(spec.model.split("together-", 1)[1], prompts_dir)
if "openrouter" in lower_id:
logger.info(f"[load_model_client] Selected OpenRouterClient for '{spec.model}'")
return OpenRouterClient(spec.model, prompts_dir)
if "claude" in lower_id:
logger.info(f"[load_model_client] Selected ClaudeClient for '{spec.model}'")
return ClaudeClient(spec.model, prompts_dir)
if "gemini" in lower_id:
logger.info(f"[load_model_client] Selected GeminiClient for '{spec.model}'")
return GeminiClient(spec.model, prompts_dir)
if "deepseek" in lower_id:
logger.info(f"[load_model_client] Selected DeepSeekClient for '{spec.model}'")
return DeepSeekClient(spec.model, prompts_dir)
# Default: OpenAI-compatible async client
logger.info(f"[load_model_client] No specific match found, using default OpenAIClient for '{spec.model}'")
return OpenAIClient(
model_name=spec.model,
prompts_dir=prompts_dir,

View file

@ -35,19 +35,23 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
}
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None, override_max_tokens: Optional[int] = None) -> DiplomacyAgent:
"""
Recreates an agent object from a dictionary.
If *override_model_id* is provided (e.g. because the CLI argument
``--models`` was used when resuming a game), that model is loaded
instead of the one stored in the save file.
If *override_max_tokens* is provided (e.g. because the CLI argument
``--max_tokens`` was used when resuming a game), that value is used
instead of the one stored in the save file.
"""
model_id = override_model_id or agent_data["model_id"]
client = load_model_client(model_id, prompts_dir=prompts_dir)
# Keep the original or fallback token limit exactly as before.
client.max_tokens = agent_data.get("max_tokens", 16000)
# Use override if provided, otherwise use saved value, otherwise default to 16000
client.max_tokens = override_max_tokens or agent_data.get("max_tokens", 16000)
agent = DiplomacyAgent(
power_name=agent_data["power_name"],
@ -208,9 +212,22 @@ def load_game_state(
# --- Rebuild agents -------------------------------------------------------
agents: Dict[str, "DiplomacyAgent"] = {}
power_model_map: Dict[str, str] = {}
powers_order = sorted(list(ALL_POWERS))
# Parse token limits from run_config
default_max_tokens = run_config.max_tokens if run_config and hasattr(run_config, 'max_tokens') else 16000
model_max_tokens = {p: default_max_tokens for p in powers_order}
if run_config and hasattr(run_config, 'max_tokens_per_model') and run_config.max_tokens_per_model:
per_model_values = [s.strip() for s in run_config.max_tokens_per_model.split(",")]
if len(per_model_values) == 7:
for power, token_val_str in zip(powers_order, per_model_values):
model_max_tokens[power] = int(token_val_str)
else:
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
if run_config and getattr(run_config, "models", None):
provided = [m.strip() for m in run_config.models.split(",")]
powers_order = sorted(list(ALL_POWERS))
if len(provided) == len(powers_order):
power_model_map = dict(zip(powers_order, provided))
elif len(provided) == 1:
@ -234,6 +251,7 @@ def load_game_state(
agent_data,
prompts_dir=prompts_dir_from_config,
override_model_id=override_id,
override_max_tokens=model_max_tokens.get(power_name),
)
# --- Rebuild GameHistory --------------------------------------------------

View file

@ -175,16 +175,31 @@ def construct_order_generation_prompt(
_ = load_prompt("few_shot_example.txt", prompts_dir=prompts_dir) # Loaded but not used, as per original logic
# Pick the phase-specific instruction file (using unformatted versions)
phase_code = board_state["phase"][-1] # 'M' (movement), 'R', or 'A' / 'B'
if phase_code == "M":
instructions_file = get_prompt_path("order_instructions_movement_phase.txt")
elif phase_code in ("A", "B"): # builds / adjustments
instructions_file = get_prompt_path("order_instructions_adjustment_phase.txt")
elif phase_code == "R": # retreats
instructions_file = get_prompt_path("order_instructions_retreat_phase.txt")
else: # unexpected default to movement rules
instructions_file = get_prompt_path("order_instructions_movement_phase.txt")
instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
# Determine base instruction file name
if phase_code == "M":
base_instruction_file = "order_instructions_movement_phase"
elif phase_code in ("A", "B"): # builds / adjustments
base_instruction_file = "order_instructions_adjustment_phase"
elif phase_code == "R": # retreats
base_instruction_file = "order_instructions_retreat_phase"
else: # unexpected default to movement rules
base_instruction_file = "order_instructions_movement_phase"
# Check if country-specific prompts are enabled
if config.COUNTRY_SPECIFIC_PROMPTS:
# Try to load country-specific version first
country_specific_file = get_prompt_path(f"{base_instruction_file}_{power_name.lower()}.txt")
instructions = load_prompt(country_specific_file, prompts_dir=prompts_dir)
# Fall back to generic if country-specific not found
if not instructions:
instructions_file = get_prompt_path(f"{base_instruction_file}.txt")
instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
else:
# Load generic instruction file
instructions_file = get_prompt_path(f"{base_instruction_file}.txt")
instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
_use_simple = config.SIMPLE_PROMPTS
include_order_history = False # defaulting to not include order history in order generation prompt for now

View file

@ -11,6 +11,7 @@ class Configuration(BaseSettings):
log_file_path: Path | None = None
USE_UNFORMATTED_PROMPTS: bool = False
SIMPLE_PROMPTS: bool = True
COUNTRY_SPECIFIC_PROMPTS: bool = False
# Default models for tasks
AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20"

View file

@ -173,6 +173,18 @@ def parse_arguments():
"Set to false (0 / false / no) to use original single-step formatted prompts."
),
)
parser.add_argument(
"--country_specific_prompts",
type=_str2bool,
nargs="?",
const=True,
default=False,
help=(
"When true (1 / true / yes) enables country-specific order and conversation prompts. "
"Each power will use their own custom prompts if available (e.g., order_instructions_movement_phase_france.txt). "
"Falls back to generic prompts if country-specific not found."
),
)
return parser.parse_args()
@ -224,6 +236,14 @@ async def main():
config.USE_UNFORMATTED_PROMPTS = False
logger.info("Using original single-step formatted prompts")
# Handle country-specific prompts flag
if args.country_specific_prompts:
config.COUNTRY_SPECIFIC_PROMPTS = True
logger.info("Country-specific prompts enabled - powers will use their custom prompts when available")
else:
config.COUNTRY_SPECIFIC_PROMPTS = False
logger.info("Using generic prompts for all powers")
if args.max_year == None:
if args.end_at_phase:
# infer the max year