diff --git a/.gitignore b/.gitignore
index 24e2fa1..0217b41 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,4 @@
 analysis_summary_debug.txt
 ./results_alpha
 /results_alpha/20250607_222757
+/ai_diplomacy/prompts/famous_leaders_prompts
diff --git a/ai_diplomacy/agent.py b/ai_diplomacy/agent.py
index c2ca7fa..3790240 100644
--- a/ai_diplomacy/agent.py
+++ b/ai_diplomacy/agent.py
@@ -6,6 +6,7 @@
 import re
 import json_repair
 import json5  # More forgiving JSON parser
 import ast
+import asyncio
 
 from config import config
diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py
index 12b5d0b..4b774c4 100644
--- a/ai_diplomacy/clients.py
+++ b/ai_diplomacy/clients.py
@@ -22,7 +22,7 @@ from together.error import APIError as TogetherAPIError  # For specific error ha
 from config import config
 from .game_history import GameHistory
-from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
+from .utils import load_prompt, run_llm_and_log, log_llm_response, log_llm_response_async, generate_random_seed, get_prompt_path
 
 # Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
 from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
@@ -52,6 +52,7 @@ class BaseModelClient:
     def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
         self.model_name = model_name
         self.prompts_dir = prompts_dir
+        logger.info(f"[{model_name}] BaseModelClient initialized with prompts_dir: {prompts_dir}")
         # Load a default initially, can be overwritten by set_system_prompt
         self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir)
         self.max_tokens = 16000  # default unless overridden
@@ -180,7 +181,7 @@ class BaseModelClient:
         finally:
             # Log the attempt regardless of outcome
             if log_file_path:  # Only log if a path is provided
-                log_llm_response(
+                await log_llm_response_async(
                     log_file_path=log_file_path,
                     model_name=self.model_name,
                     power_name=power_name,
@@ -441,7 +442,18 @@
         agent_private_diary_str: Optional[str] = None,  # Added
     ) -> str:
         # MINIMAL CHANGE: Just change to load unformatted version conditionally
-        instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
+        # Check if country-specific prompts are enabled
+        if config.COUNTRY_SPECIFIC_PROMPTS:
+            # Try to load country-specific version first
+            country_specific_file = get_prompt_path(f"conversation_instructions_{power_name.lower()}.txt")
+            instructions = load_prompt(country_specific_file, prompts_dir=self.prompts_dir)
+
+            # Fall back to generic if country-specific not found
+            if not instructions:
+                instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
+        else:
+            # Load generic conversation instructions
+            instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
 
         # KEEP ORIGINAL: Use build_context_prompt as before
         context = build_context_prompt(
@@ -670,7 +682,7 @@ class BaseModelClient:
             messages_to_return = []  # Ensure empty list on general exception
         finally:
             if log_file_path:
-                log_llm_response(
+                await log_llm_response_async(
                     log_file_path=log_file_path,
                     model_name=self.model_name,
                     power_name=power_name,
@@ -749,7 +761,7 @@ class BaseModelClient:
             plan_to_return = f"Error: Failed to generate plan for {power_name} due to exception: {e}"
         finally:
             if log_file_path:  # Only log if a path is provided
-                log_llm_response(
+                await log_llm_response_async(
                     log_file_path=log_file_path,
                     model_name=self.model_name,
                     power_name=power_name,
@@ -797,27 +809,34 @@ class OpenAIClient(BaseModelClient):
         system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
         prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
 
+        # Determine which parameter to use based on model
+        completion_params = {
+            "model": self.model_name,
+            "messages": [
+                {"role": "system", "content": system_prompt_content},
+                {"role": "user", "content": prompt_with_cta},
+            ],
+        }
-
-        if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
-            response = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=[
-                    {"role": "system", "content": system_prompt_content},
-                    {"role": "user", "content": prompt_with_cta},
-                ],
-                max_completion_tokens=self.max_tokens,
-            )
+        # Handle model-specific parameters
+        # Check if model name starts with 'nectarine' or is in the specific list
+        uses_max_completion_tokens = (
+            self.model_name in ["o4-mini", "o3-mini", "o3", "gpt-4.1"] or
+            self.model_name.startswith("nectarine")
+        )
+
+        if uses_max_completion_tokens:
+            completion_params["max_completion_tokens"] = self.max_tokens
+            # o4-mini, o3-mini, o3 only support default temperature of 1.0
+            if self.model_name in ["o4-mini", "o3-mini", "o3"]:
+                completion_params["temperature"] = 1.0
+            else:
+                completion_params["temperature"] = temperature
         else:
-            response = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=[
-                    {"role": "system", "content": system_prompt_content},
-                    {"role": "user", "content": prompt_with_cta},
-                ],
-                temperature=temperature,
-                max_tokens=self.max_tokens,
-            )
+            completion_params["max_tokens"] = self.max_tokens
+            completion_params["temperature"] = temperature
+
+        response = await self.client.chat.completions.create(**completion_params)
 
         if (
             not response
@@ -971,16 +990,24 @@ class DeepSeekClient(BaseModelClient):
             random_seed = generate_random_seed()
             system_prompt_content = f"{random_seed}\n\n{self.system_prompt}"
 
-            response = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=[
+            # Determine which parameter to use based on model
+            completion_params = {
+                "model": self.model_name,
+                "messages": [
                     {"role": "system", "content": system_prompt_content},
                     {"role": "user", "content": prompt_with_cta},
                 ],
-                stream=False,
-                temperature=temperature,
-                max_tokens=self.max_tokens,
-            )
+                "stream": False,
+                "temperature": temperature,
+            }
+
+            # Use max_completion_tokens for o4-mini, o3-mini models and nectarine models
+            if self.model_name in ["o4-mini", "o3-mini"] or self.model_name.startswith("nectarine"):
+                completion_params["max_completion_tokens"] = self.max_tokens
+            else:
+                completion_params["max_tokens"] = self.max_tokens
+
+            response = await self.client.chat.completions.create(**completion_params)
 
             logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
@@ -1023,7 +1050,7 @@ class OpenAIResponsesClient(BaseModelClient):
     This client makes direct HTTP requests to the v1/responses endpoint.
     """
 
-    def __init__(self, model_name: str, prompts_dir: Optional[str] = None, api_key: Optional[str] = None):
+    def __init__(self, model_name: str, prompts_dir: Optional[str] = None, api_key: Optional[str] = None, reasoning_effort: Optional[str] = None):
         super().__init__(model_name, prompts_dir=prompts_dir)
         if api_key:
             self.api_key = api_key
@@ -1032,7 +1059,20 @@
         if not self.api_key:
             raise ValueError("OPENAI_API_KEY environment variable is required")
         self.base_url = "https://api.openai.com/v1/responses"
-        logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client")
+        self._session = None  # Lazy initialization for connection pooling
+        self.reasoning_effort = reasoning_effort  # For models that support reasoning effort
+        logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client with reasoning_effort={reasoning_effort}")
+
+    async def _get_session(self) -> aiohttp.ClientSession:
+        """Get or create the aiohttp session for connection pooling."""
+        if self._session is None or self._session.closed:
+            self._session = aiohttp.ClientSession()
+        return self._session
+
+    async def close(self):
+        """Close the aiohttp session."""
+        if self._session and not self._session.closed:
+            await self._session.close()
 
     async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
         try:
@@ -1049,51 +1089,59 @@
             payload = {
                 "model": self.model_name,
                 "input": full_prompt,
-                "temperature": temperature,
-                "max_tokens": self.max_tokens,
             }
 
-            if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
-                del payload["temperature"]
-                del payload["max_tokens"]
-                payload["max_completion_tokens"] = self.max_tokens
+            # The Responses API uses max_output_tokens for all models
+            payload["max_output_tokens"] = self.max_tokens
+
+            # Only add temperature for models that support it
+            models_without_temp = ['o3', 'o4-mini', 'gpt-5-reasoning-alpha-2025-07-19', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
+            if self.model_name not in models_without_temp:
+                payload["temperature"] = temperature
+
+            # Add reasoning effort for models that support it
+            reasoning_models = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'o4-mini-alpha-2025-07-11', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
+            if self.reasoning_effort and self.model_name in reasoning_models:
+                payload["reasoning"] = {"effort": self.reasoning_effort}
 
             headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
 
-            # Make the API call using aiohttp
-            async with aiohttp.ClientSession() as session:
-                async with session.post(self.base_url, json=payload, headers=headers) as response:
-                    response.raise_for_status()  # Will raise for non-2xx responses
-                    response_data = await response.json()
+            # Make the API call using the pooled session
+            session = await self._get_session()
+            async with session.post(self.base_url, json=payload, headers=headers) as response:
+                response.raise_for_status()  # Will raise for non-2xx responses
+                response_data = await response.json()
 
-                    # Extract the text from the nested response structure
-                    try:
-                        outputs = response_data.get("output", [])
-                        if len(outputs) < 2:
-                            raise ValueError(f"[{self.model_name}] Unexpected output structure: 'output' list has < 2 items.")
+                # Extract the text from the nested response structure
+                try:
+                    outputs = response_data.get("output", [])
+                    if len(outputs) < 2:
+                        # Log the actual response for debugging
+                        logger.error(f"[{self.model_name}] Response structure: {json.dumps(response_data, indent=2)}")
+                        raise ValueError(f"[{self.model_name}] Unexpected output structure: 'output' list has < 2 items.")
 
-                        message_output = outputs[1]
-                        if message_output.get("type") != "message":
-                            raise ValueError(f"[{self.model_name}] Expected 'message' type in output[1], got '{message_output.get('type')}'.")
+                    message_output = outputs[1]
+                    if message_output.get("type") != "message":
+                        raise ValueError(f"[{self.model_name}] Expected 'message' type in output[1], got '{message_output.get('type')}'.")
 
-                        content_list = message_output.get("content", [])
-                        if not content_list:
-                            raise ValueError(f"[{self.model_name}] Empty 'content' list in message output.")
+                    content_list = message_output.get("content", [])
+                    if not content_list:
+                        raise ValueError(f"[{self.model_name}] Empty 'content' list in message output.")
 
-                        text_content = ""
-                        for content_item in content_list:
-                            if content_item.get("type") == "output_text":
-                                text_content = content_item.get("text", "")
-                                break
+                    text_content = ""
+                    for content_item in content_list:
+                        if content_item.get("type") == "output_text":
+                            text_content = content_item.get("text", "")
+                            break
 
-                        if not text_content:
-                            raise ValueError(f"[{self.model_name}] No 'output_text' found in content or it was empty.")
+                    if not text_content:
+                        raise ValueError(f"[{self.model_name}] No 'output_text' found in content or it was empty.")
 
-                        return text_content.strip()
+                    return text_content.strip()
 
-                    except (KeyError, IndexError, TypeError) as e:
-                        # Wrap parsing error in a more informative exception
-                        raise ValueError(f"[{self.model_name}] Error parsing response structure: {e}") from e
+                except (KeyError, IndexError, TypeError) as e:
+                    # Wrap parsing error in a more informative exception
+                    raise ValueError(f"[{self.model_name}] Error parsing response structure: {e}") from e
 
         except aiohttp.ClientError as e:
             logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
@@ -1302,8 +1350,13 @@ class RequestsOpenAIClient(BaseModelClient):
                 {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
             ],
             "temperature": temperature,
-            "max_tokens": self.max_tokens,
         }
+
+        # Use max_completion_tokens for o4-mini, o3-mini, o3, gpt-4.1 models and nectarine models
+        if self.model_name in ["o4-mini", "o3-mini", "o3", "gpt-4.1"] or self.model_name.startswith("nectarine"):
+            payload["max_completion_tokens"] = self.max_tokens
+        else:
+            payload["max_tokens"] = self.max_tokens
 
         #if self.model_name == "qwen/qwen3-235b-a22b" and self.base_url == "https://openrouter.ai/api/v1":
         #    payload["provider"] = {
@@ -1313,7 +1366,8 @@
         if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
             del payload["temperature"]
-            del payload["max_tokens"]
+            if "max_tokens" in payload:
+                del payload["max_tokens"]
             payload["max_completion_tokens"] = self.max_tokens
 
         loop = asyncio.get_running_loop()
@@ -1381,13 +1435,33 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
         gpt-4o
         anthropic:claude-3.7-sonnet
         openai:llama-3-2-3b@https://localhost:8000#myapikey
+        gpt-5-reasoning-alpha-2025-07-19:minimal
 
     and returns the appropriate client.
 
     • If a prefix is omitted the function falls back to the original heuristic mapping exactly as before.
-    • If an inline API-key (‘#…’) is present it overrides environment vars.
+    • If an inline API-key ('#…') is present it overrides environment vars.
+    • For reasoning models, effort can be specified with :minimal, :medium, or :high
     """
-    spec = _parse_model_spec(model_id)
+    # Extract reasoning effort if present (before general parsing)
+    reasoning_effort = None
+    actual_model_id = model_id
+
+    # Check if this is a reasoning model with effort specified
+    reasoning_models = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
+    for model in reasoning_models:
+        if model_id.startswith(model + ':'):
+            parts = model_id.split(':', 1)
+            effort_part = parts[1]
+            # Check if the effort part is valid before treating it as effort
+            # (it could be a prefix like "openai:")
+            if effort_part.lower() in ['minimal', 'medium', 'high']:
+                actual_model_id = parts[0]
+                reasoning_effort = effort_part.lower()
+            break
+
+    spec = _parse_model_spec(actual_model_id)
+    logger.info(f"[load_model_client] Loading client for model_id='{model_id}', parsed spec: prefix={spec.prefix}, model={spec.model}, reasoning_effort={reasoning_effort}")
 
     # Inline key overrides env; otherwise fall back as usual *per client*
     inline_key = spec.key
@@ -1421,7 +1495,7 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
                 api_key=inline_key,
             )
         case Prefix.OPENAI_RESPONSES:
-            return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
+            return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key, reasoning_effort=reasoning_effort)
         case Prefix.ANTHROPIC:
             return ClaudeClient(spec.model, prompts_dir)
         case Prefix.GEMINI:
@@ -1437,27 +1511,41 @@
     # 2. Heuristic fallback path (identical to the original behaviour)      #
     # ------------------------------------------------------------------ #
     lower_id = spec.model.lower()
+    logger.info(f"[load_model_client] Heuristic path: checking model='{spec.model}', lower_id='{lower_id}'")
+
+    # Check if this is a reasoning model that should use Responses API
+    reasoning_models_requiring_responses = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
+    if spec.model in reasoning_models_requiring_responses:
+        logger.info(f"[load_model_client] Selected OpenAIResponsesClient for reasoning model '{spec.model}'")
+        return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key, reasoning_effort=reasoning_effort)
 
     if lower_id == "o3-pro":
+        logger.info(f"[load_model_client] Selected OpenAIResponsesClient for '{spec.model}'")
         return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
 
     if spec.model.startswith("together-"):  # e.g. "together-mixtral-8x7b"
+        logger.info(f"[load_model_client] Selected TogetherAIClient for '{spec.model}'")
         return TogetherAIClient(spec.model.split("together-", 1)[1], prompts_dir)
 
     if "openrouter" in lower_id:
+        logger.info(f"[load_model_client] Selected OpenRouterClient for '{spec.model}'")
        return OpenRouterClient(spec.model, prompts_dir)
 
     if "claude" in lower_id:
+        logger.info(f"[load_model_client] Selected ClaudeClient for '{spec.model}'")
         return ClaudeClient(spec.model, prompts_dir)
 
     if "gemini" in lower_id:
+        logger.info(f"[load_model_client] Selected GeminiClient for '{spec.model}'")
         return GeminiClient(spec.model, prompts_dir)
 
     if "deepseek" in lower_id:
+        logger.info(f"[load_model_client] Selected DeepSeekClient for '{spec.model}'")
         return DeepSeekClient(spec.model, prompts_dir)
 
     # Default: OpenAI-compatible async client
+    logger.info(f"[load_model_client] No specific match found, using default OpenAIClient for '{spec.model}'")
     return OpenAIClient(
         model_name=spec.model,
         prompts_dir=prompts_dir,
diff --git a/ai_diplomacy/game_logic.py b/ai_diplomacy/game_logic.py
index fdbb431..df0997d 100644
--- a/ai_diplomacy/game_logic.py
+++ b/ai_diplomacy/game_logic.py
@@ -35,19 +35,23 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
     }
 
 
-def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
+def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None, override_max_tokens: Optional[int] = None) -> DiplomacyAgent:
     """
     Recreates an agent object from a dictionary.
 
     If *override_model_id* is provided (e.g. because the CLI argument
     ``--models`` was used when resuming a game), that model is loaded
     instead of the one stored in the save file.
+
+    If *override_max_tokens* is provided (e.g. because the CLI argument
+    ``--max_tokens`` was used when resuming a game), that value is used
+    instead of the one stored in the save file.
     """
     model_id = override_model_id or agent_data["model_id"]
     client = load_model_client(model_id, prompts_dir=prompts_dir)
 
-    # Keep the original or fallback token limit exactly as before.
-    client.max_tokens = agent_data.get("max_tokens", 16000)
+    # Use override if provided, otherwise use saved value, otherwise default to 16000
+    client.max_tokens = override_max_tokens or agent_data.get("max_tokens", 16000)
 
     agent = DiplomacyAgent(
         power_name=agent_data["power_name"],
@@ -208,9 +212,22 @@
     # --- Rebuild agents -------------------------------------------------------
     agents: Dict[str, "DiplomacyAgent"] = {}
     power_model_map: Dict[str, str] = {}
+    powers_order = sorted(list(ALL_POWERS))
+
+    # Parse token limits from run_config
+    default_max_tokens = run_config.max_tokens if run_config and hasattr(run_config, 'max_tokens') else 16000
+    model_max_tokens = {p: default_max_tokens for p in powers_order}
+
+    if run_config and hasattr(run_config, 'max_tokens_per_model') and run_config.max_tokens_per_model:
+        per_model_values = [s.strip() for s in run_config.max_tokens_per_model.split(",")]
+        if len(per_model_values) == 7:
+            for power, token_val_str in zip(powers_order, per_model_values):
+                model_max_tokens[power] = int(token_val_str)
+        else:
+            logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
+
     if run_config and getattr(run_config, "models", None):
         provided = [m.strip() for m in run_config.models.split(",")]
-        powers_order = sorted(list(ALL_POWERS))
         if len(provided) == len(powers_order):
             power_model_map = dict(zip(powers_order, provided))
         elif len(provided) == 1:
@@ -234,6 +251,7 @@
             agent_data,
             prompts_dir=prompts_dir_from_config,
             override_model_id=override_id,
+            override_max_tokens=model_max_tokens.get(power_name),
         )
 
     # --- Rebuild GameHistory --------------------------------------------------
diff --git a/ai_diplomacy/prompt_constructor.py b/ai_diplomacy/prompt_constructor.py
index 471334e..24780ac 100644
--- a/ai_diplomacy/prompt_constructor.py
+++ b/ai_diplomacy/prompt_constructor.py
@@ -175,16 +175,31 @@ def construct_order_generation_prompt(
     _ = load_prompt("few_shot_example.txt", prompts_dir=prompts_dir)  # Loaded but not used, as per original logic
     # Pick the phase-specific instruction file (using unformatted versions)
     phase_code = board_state["phase"][-1]  # 'M' (movement), 'R', or 'A' / 'B'
+
+    # Determine base instruction file name
     if phase_code == "M":
-        instructions_file = get_prompt_path("order_instructions_movement_phase.txt")
+        base_instruction_file = "order_instructions_movement_phase"
     elif phase_code in ("A", "B"):  # builds / adjustments
-        instructions_file = get_prompt_path("order_instructions_adjustment_phase.txt")
+        base_instruction_file = "order_instructions_adjustment_phase"
     elif phase_code == "R":  # retreats
-        instructions_file = get_prompt_path("order_instructions_retreat_phase.txt")
+        base_instruction_file = "order_instructions_retreat_phase"
     else:  # unexpected – default to movement rules
-        instructions_file = get_prompt_path("order_instructions_movement_phase.txt")
-
-    instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
+        base_instruction_file = "order_instructions_movement_phase"
+
+    # Check if country-specific prompts are enabled
+    if config.COUNTRY_SPECIFIC_PROMPTS:
+        # Try to load country-specific version first
+        country_specific_file = get_prompt_path(f"{base_instruction_file}_{power_name.lower()}.txt")
+        instructions = load_prompt(country_specific_file, prompts_dir=prompts_dir)
+
+        # Fall back to generic if country-specific not found
+        if not instructions:
+            instructions_file = get_prompt_path(f"{base_instruction_file}.txt")
+            instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
+    else:
+        # Load generic instruction file
+        instructions_file = get_prompt_path(f"{base_instruction_file}.txt")
+        instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
 
     _use_simple = config.SIMPLE_PROMPTS
     include_order_history = False  # defaulting to not include order history in order generation prompt for now
diff --git a/config.py b/config.py
index 694823f..ba6d538 100644
--- a/config.py
+++ b/config.py
@@ -11,6 +11,7 @@ class Configuration(BaseSettings):
     log_file_path: Path | None = None
     USE_UNFORMATTED_PROMPTS: bool = False
     SIMPLE_PROMPTS: bool = True
+    COUNTRY_SPECIFIC_PROMPTS: bool = False
 
     # Default models for tasks
     AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20"
diff --git a/lm_game.py b/lm_game.py
index b7c0b97..2162057 100644
--- a/lm_game.py
+++ b/lm_game.py
@@ -173,6 +173,18 @@ def parse_arguments():
             "Set to false (0 / false / no) to use original single-step formatted prompts."
         ),
     )
+    parser.add_argument(
+        "--country_specific_prompts",
+        type=_str2bool,
+        nargs="?",
+        const=True,
+        default=False,
+        help=(
+            "When true (1 / true / yes) enables country-specific order and conversation prompts. "
+            "Each power will use their own custom prompts if available (e.g., order_instructions_movement_phase_france.txt). "
+            "Falls back to generic prompts if country-specific not found."
+        ),
+    )
 
     return parser.parse_args()
 
@@ -223,6 +235,14 @@ async def main():
     else:
         config.USE_UNFORMATTED_PROMPTS = False
         logger.info("Using original single-step formatted prompts")
+
+    # Handle country-specific prompts flag
+    if args.country_specific_prompts:
+        config.COUNTRY_SPECIFIC_PROMPTS = True
+        logger.info("Country-specific prompts enabled - powers will use their custom prompts when available")
+    else:
+        config.COUNTRY_SPECIFIC_PROMPTS = False
+        logger.info("Using generic prompts for all powers")
 
     if args.max_year == None:
         if args.end_at_phase:
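Usage note: per the docstring added to load_model_client above, a reasoning effort can be appended to a model id with a colon (e.g. gpt-5-reasoning-alpha-2025-07-19:minimal). Below is a minimal standalone sketch of that suffix handling, mirroring the loop added in ai_diplomacy/clients.py; the helper name split_reasoning_effort is illustrative and not part of the codebase, and the model names are copied from the diff.

```python
# Illustrative sketch only: mirrors the effort-suffix handling added to
# load_model_client() in ai_diplomacy/clients.py; not part of the codebase.
REASONING_MODELS = [
    "gpt-5-reasoning-alpha-2025-07-19",
    "o4-mini",
    "nectarine-alpha-2025-07-25",
    "nectarine-alpha-new-reasoning-effort-2025-07-25",
]

def split_reasoning_effort(model_id: str) -> tuple[str, str | None]:
    """Return (model_id_without_suffix, effort) where effort is minimal/medium/high or None."""
    for model in REASONING_MODELS:
        if model_id.startswith(model + ":"):
            base, suffix = model_id.split(":", 1)
            # Only treat the suffix as an effort level if it is a known value;
            # otherwise leave the id untouched (it could be something else entirely).
            if suffix.lower() in ("minimal", "medium", "high"):
                return base, suffix.lower()
            break
    return model_id, None

assert split_reasoning_effort("gpt-5-reasoning-alpha-2025-07-19:minimal") == (
    "gpt-5-reasoning-alpha-2025-07-19",
    "minimal",
)
assert split_reasoning_effort("openai:gpt-4o") == ("openai:gpt-4o", None)
```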
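Similarly, the load_game_state change maps the comma-separated --max_tokens_per_model values onto the seven powers in alphabetical order. A self-contained sketch of that mapping follows; the function name parse_max_tokens_per_model is illustrative only, and the power list is assumed to match ALL_POWERS in the codebase (the standard Diplomacy set).

```python
# Illustrative sketch only: reproduces the zip-by-alphabetical-order mapping
# added to load_game_state() in ai_diplomacy/game_logic.py.
ALL_POWERS = {"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"}

def parse_max_tokens_per_model(arg: str | None, default: int = 16000) -> dict[str, int]:
    powers_order = sorted(ALL_POWERS)
    limits = {p: default for p in powers_order}
    if not arg:
        return limits
    values = [s.strip() for s in arg.split(",")]
    if len(values) == 7:
        for power, value in zip(powers_order, values):
            limits[power] = int(value)
    # Anything other than exactly seven values keeps the defaults,
    # matching the warning branch in the diff.
    return limits

# AUSTRIA..TURKEY receive 8000, 8000, 16000, 16000, 8000, 32000, 16000 respectively.
print(parse_max_tokens_per_model("8000,8000,16000,16000,8000,32000,16000"))
```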