diff --git a/README.md b/README.md index bd62266..5859ff9 100644 --- a/README.md +++ b/README.md @@ -236,13 +236,92 @@ python lm_game.py --max_year 1910 --planning_phase --num_negotiation_rounds 2 # Custom model assignment (order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY) python lm_game.py --models "claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022" -# Output to specific file -python lm_game.py --output results/my_game.json - # Run until game completion or specific year python lm_game.py --num_negotiation_rounds 2 --planning_phase + +# Write all artefacts to a chosen directory (auto-resumes if it already exists) +python lm_game.py --run_dir results/game_run_001 + +# Resume an interrupted game from a specific phase +python lm_game.py --run_dir results/game_run_001 --resume_from_phase S1902M + +# Critical-state analysis: resume from an existing run but save new results elsewhere +python lm_game.py \ + --run_dir results/game_run_001 \ + --critical_state_analysis_dir results/critical_analysis_001 \ + --resume_from_phase F1903M + +# End the simulation after a particular phase regardless of remaining years +python lm_game.py --run_dir results/game_run_002 --end_at_phase F1905M + +# Set the global max_tokens generation limit +python lm_game.py --run_dir results/game_run_003 --max_tokens 8000 + +# Per-model token limits (AU,EN,FR,GE,IT,RU,TR) +python lm_game.py --run_dir results/game_run_004 \ + --max_tokens_per_model "8000,8000,16000,8000,8000,16000,8000" + +# Use a custom prompts directory +python lm_game.py --run_dir results/game_run_005 --prompts_dir ./prompts/my_variants ``` +### Running Batch Experiments with **`experiment_runner.py`** + +`experiment_runner.py` is a lightweight orchestrator: it spins up many `lm_game.py` runs in parallel, gathers their artefacts under one *experiment directory*, and then executes the analysis modules you specify. 
+All flags that belong to **`lm_game.py`** can be passed straight through; the runner validates them and forwards them unchanged to every game instance. + +--- + +#### Examples + +```bash +# Run 10 independent games (iterations) in parallel, using a custom prompts dir +# and a single model (GPT-4o) for all seven powers. +python3 experiment_runner.py \ + --experiment_dir "results/exp001" \ + --iterations 10 \ + --parallel 10 \ + --max_year 1905 \ + --num_negotiation_rounds 0 \ + --prompts_dir "ai_diplomacy/prompts" \ + --models "gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o" + + +# Critical-state analysis: resume every run from W1901A (taken from an existing +# base run) and stop after S1902M. Two analysis modules will be executed: +# • summary → aggregated results & scores +# • critical_state → before/after snapshots around the critical phase +python3 experiment_runner.py \ + --experiment_dir "results/exp002" \ + --iterations 10 \ + --parallel 10 \ + --resume_from_phase W1901A \ + --end_at_phase S1902M \ + --num_negotiation_rounds 0 \ + --critical_state_base_run "results/test1" \ + --prompts_dir "ai_diplomacy/prompts" \ + --analysis_modules "summary,critical_state" \ + --models "gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o" +``` + +*(Any other `lm_game.py` flags—`--planning_phase`, `--max_tokens`, etc.—can be added exactly where you’d use them on a single-game run.)* + +--- + +#### Experiment-runner–specific arguments + +| Flag | Type / Default | Description | +| --------------------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `--experiment_dir` **(required)** | `Path` | Root folder for the experiment; sub-folders `runs/` and `analysis/` are managed automatically. Re-running with the same directory will **resume** existing runs and regenerate analysis. 
| +| `--iterations` | `int`, default `1` | How many individual games to launch for this experiment. | +| `--parallel` | `int`, default `1` | Max number of games to execute concurrently (uses a process pool). | +| `--analysis_modules` | `str`, default `"summary"` | Comma-separated list of analysis modules to run after all games finish. Modules are imported from `experiment_runner.analysis.` and must provide `run(experiment_dir, ctx)`. | +| `--critical_state_base_run` | `Path`, optional | Path to an **existing** `run_dir` produced by a previous `lm_game` run. Each iteration resumes from that snapshot; new artefacts are written under the current `experiment_dir`. | +| `--seed_base` | `int`, default `42` | Base random seed. Run *ɪ* receives seed = `seed_base + ɪ`, enabling reproducible batches. | + +*(All other command-line flags belong to `lm_game.py` and are forwarded unchanged.)* + + ### Environment Setup Create a `.env` file with your API keys: diff --git a/ai_diplomacy/agent.py b/ai_diplomacy/agent.py index a2b66a0..00701f4 100644 --- a/ai_diplomacy/agent.py +++ b/ai_diplomacy/agent.py @@ -22,13 +22,17 @@ ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUS ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"] # == New: Helper function to load prompt files reliably == -def _load_prompt_file(filename: str) -> Optional[str]: +def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]: """Loads a prompt template from the prompts directory.""" try: - # Construct path relative to this file's location - current_dir = os.path.dirname(os.path.abspath(__file__)) - prompts_dir = os.path.join(current_dir, 'prompts') - filepath = os.path.join(prompts_dir, filename) + if prompts_dir: + filepath = os.path.join(prompts_dir, filename) + else: + # Construct path relative to this file's location + current_dir = os.path.dirname(os.path.abspath(__file__)) + default_prompts_dir = 
os.path.join(current_dir, 'prompts') + filepath = os.path.join(default_prompts_dir, filename) + with open(filepath, 'r', encoding='utf-8') as f: return f.read() except FileNotFoundError: @@ -50,6 +54,7 @@ class DiplomacyAgent: client: BaseModelClient, initial_goals: Optional[List[str]] = None, initial_relationships: Optional[Dict[str, str]] = None, + prompts_dir: Optional[str] = None, ): """ Initializes the DiplomacyAgent. @@ -60,12 +65,14 @@ class DiplomacyAgent: initial_goals: An optional list of initial strategic goals. initial_relationships: An optional dictionary mapping other power names to relationship statuses (e.g., 'ALLY', 'ENEMY', 'NEUTRAL'). + prompts_dir: Optional path to the prompts directory. """ if power_name not in ALL_POWERS: raise ValueError(f"Invalid power name: {power_name}. Must be one of {ALL_POWERS}") self.power_name: str = power_name self.client: BaseModelClient = client + self.prompts_dir: Optional[str] = prompts_dir # Initialize goals as empty list, will be populated by initialize_agent_state self.goals: List[str] = initial_goals if initial_goals is not None else [] # Initialize relationships to Neutral if not provided @@ -85,16 +92,21 @@ class DiplomacyAgent: # Get the directory containing the current file (agent.py) current_dir = os.path.dirname(os.path.abspath(__file__)) # Construct path relative to the current file's directory - prompts_dir = os.path.join(current_dir, "prompts") - power_prompt_filename = os.path.join(prompts_dir, f"{power_name.lower()}_system_prompt.txt") - default_prompt_filename = os.path.join(prompts_dir, "system_prompt.txt") + default_prompts_path = os.path.join(current_dir, "prompts") + power_prompt_filename = f"{power_name.lower()}_system_prompt.txt" + default_prompt_filename = "system_prompt.txt" - system_prompt_content = load_prompt(power_prompt_filename) + # Use the provided prompts_dir if available, otherwise use the default + prompts_path_to_use = self.prompts_dir if self.prompts_dir else 
default_prompts_path + + power_prompt_filepath = os.path.join(prompts_path_to_use, power_prompt_filename) + default_prompt_filepath = os.path.join(prompts_path_to_use, default_prompt_filename) + + system_prompt_content = load_prompt(power_prompt_filepath, prompts_dir=self.prompts_dir) if not system_prompt_content: - logger.warning(f"Power-specific prompt '{power_prompt_filename}' not found or empty. Loading default system prompt.") - # system_prompt_content = load_prompt("system_prompt.txt") - system_prompt_content = load_prompt(default_prompt_filename) + logger.warning(f"Power-specific prompt '{power_prompt_filepath}' not found or empty. Loading default system prompt.") + system_prompt_content = load_prompt(default_prompt_filepath, prompts_dir=self.prompts_dir) else: logger.info(f"Loaded power-specific system prompt for {power_name}.") # ---------------------------------------------------- @@ -416,7 +428,7 @@ class DiplomacyAgent: try: # Load the template file but safely preprocess it first - prompt_template_content = _load_prompt_file('negotiation_diary_prompt.txt') + prompt_template_content = _load_prompt_file('negotiation_diary_prompt.txt', prompts_dir=self.prompts_dir) if not prompt_template_content: logger.error(f"[{self.power_name}] Could not load negotiation_diary_prompt.txt. Skipping diary entry.") success_status = "Failure: Prompt file not loaded" @@ -611,7 +623,7 @@ class DiplomacyAgent: logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...") # Load the template but we'll use it carefully with string interpolation - prompt_template = _load_prompt_file('order_diary_prompt.txt') + prompt_template = _load_prompt_file('order_diary_prompt.txt', prompts_dir=self.prompts_dir) if not prompt_template: logger.error(f"[{self.power_name}] Could not load order_diary_prompt.txt. 
Skipping diary entry.") return @@ -756,7 +768,7 @@ class DiplomacyAgent: logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...") # Load the template - prompt_template = _load_prompt_file('phase_result_diary_prompt.txt') + prompt_template = _load_prompt_file('phase_result_diary_prompt.txt', prompts_dir=self.prompts_dir) if not prompt_template: logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.") return @@ -859,7 +871,7 @@ class DiplomacyAgent: try: # 1. Construct the prompt using the dedicated state update prompt file - prompt_template = _load_prompt_file('state_update_prompt.txt') + prompt_template = _load_prompt_file('state_update_prompt.txt', prompts_dir=self.prompts_dir) if not prompt_template: logger.error(f"[{power_name}] Could not load state_update_prompt.txt. Skipping state update.") return @@ -893,6 +905,7 @@ class DiplomacyAgent: agent_goals=self.goals, agent_relationships=self.relationships, agent_private_diary=formatted_diary, # Pass formatted diary + prompts_dir=self.prompts_dir, ) # Add previous phase summary to the information provided to the LLM diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py index e4154e2..bff7849 100644 --- a/ai_diplomacy/clients.py +++ b/ai_diplomacy/clients.py @@ -44,10 +44,11 @@ class BaseModelClient: - get_conversation_reply(power_name, conversation_so_far, game_phase) -> str """ - def __init__(self, model_name: str): + def __init__(self, model_name: str, prompts_dir: Optional[str] = None): self.model_name = model_name + self.prompts_dir = prompts_dir # Load a default initially, can be overwritten by set_system_prompt - self.system_prompt = load_prompt("system_prompt.txt") + self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir) self.max_tokens = 16000 # default unless overridden def set_system_prompt(self, content: str): @@ -97,6 +98,7 @@ class BaseModelClient: 
agent_goals=agent_goals, agent_relationships=agent_relationships, agent_private_diary_str=agent_private_diary_str, + prompts_dir=self.prompts_dir, ) raw_response = "" @@ -423,7 +425,7 @@ class BaseModelClient: agent_private_diary_str: Optional[str] = None, # Added ) -> str: - instructions = load_prompt("planning_instructions.txt") + instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir) context = self.build_context_prompt( game, @@ -434,6 +436,7 @@ class BaseModelClient: agent_goals=agent_goals, agent_relationships=agent_relationships, agent_private_diary=agent_private_diary_str, # Pass diary string + prompts_dir=self.prompts_dir, ) return context + "\n\n" + instructions @@ -451,7 +454,7 @@ class BaseModelClient: agent_relationships: Optional[Dict[str, str]] = None, agent_private_diary_str: Optional[str] = None, # Added ) -> str: - instructions = load_prompt("conversation_instructions.txt") + instructions = load_prompt("conversation_instructions.txt", prompts_dir=self.prompts_dir) context = build_context_prompt( game, @@ -462,6 +465,7 @@ class BaseModelClient: agent_goals=agent_goals, agent_relationships=agent_relationships, agent_private_diary=agent_private_diary_str, # Pass diary string + prompts_dir=self.prompts_dir, ) # Get recent messages targeting this power to prioritize responses @@ -699,7 +703,7 @@ class BaseModelClient: """ logger.info(f"Client generating strategic plan for {power_name}...") - planning_instructions = load_prompt("planning_instructions.txt") + planning_instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir) if not planning_instructions: logger.error("Could not load planning_instructions.txt! Cannot generate plan.") return "Error: Planning instructions not found." 
@@ -718,6 +722,7 @@ class BaseModelClient: agent_goals=agent_goals, agent_relationships=agent_relationships, agent_private_diary=agent_private_diary_str, # Pass diary string + prompts_dir=self.prompts_dir, ) full_prompt = f"{context_prompt}\n\n{planning_instructions}" @@ -772,8 +777,8 @@ class OpenAIClient(BaseModelClient): For 'o3-mini', 'gpt-4o', or other OpenAI model calls. """ - def __init__(self, model_name: str): - super().__init__(model_name) + def __init__(self, model_name: str, prompts_dir: Optional[str] = None): + super().__init__(model_name, prompts_dir=prompts_dir) self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str: @@ -819,8 +824,8 @@ class ClaudeClient(BaseModelClient): For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc. """ - def __init__(self, model_name: str): - super().__init__(model_name) + def __init__(self, model_name: str, prompts_dir: Optional[str] = None): + super().__init__(model_name, prompts_dir=prompts_dir) self.client = AsyncAnthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str: @@ -861,8 +866,8 @@ class GeminiClient(BaseModelClient): For 'gemini-1.5-flash' or other Google Generative AI models. 
""" - def __init__(self, model_name: str): - super().__init__(model_name) + def __init__(self, model_name: str, prompts_dir: Optional[str] = None): + super().__init__(model_name, prompts_dir=prompts_dir) # Configure and get the model (corrected initialization) api_key = os.environ.get("GEMINI_API_KEY") if not api_key: @@ -905,8 +910,8 @@ class DeepSeekClient(BaseModelClient): For DeepSeek R1 'deepseek-reasoner' """ - def __init__(self, model_name: str): - super().__init__(model_name) + def __init__(self, model_name: str, prompts_dir: Optional[str] = None): + super().__init__(model_name, prompts_dir=prompts_dir) self.api_key = os.environ.get("DEEPSEEK_API_KEY") self.client = AsyncDeepSeekOpenAI( api_key=self.api_key, @@ -961,8 +966,8 @@ class OpenAIResponsesClient(BaseModelClient): This client makes direct HTTP requests to the v1/responses endpoint. """ - def __init__(self, model_name: str): - super().__init__(model_name) + def __init__(self, model_name: str, prompts_dir: Optional[str] = None): + super().__init__(model_name, prompts_dir=prompts_dir) self.api_key = os.environ.get("OPENAI_API_KEY") if not self.api_key: raise ValueError("OPENAI_API_KEY environment variable is required") @@ -1068,14 +1073,14 @@ class OpenRouterClient(BaseModelClient): For OpenRouter models, with default being 'openrouter/quasar-alpha' """ - def __init__(self, model_name: str = "openrouter/quasar-alpha"): + def __init__(self, model_name: str = "openrouter/quasar-alpha", prompts_dir: Optional[str] = None): # Allow specifying just the model identifier or the full path if not model_name.startswith("openrouter/") and "/" not in model_name: model_name = f"openrouter/{model_name}" if model_name.startswith("openrouter-"): model_name = model_name.replace("openrouter-", "") - super().__init__(model_name) + super().__init__(model_name, prompts_dir=prompts_dir) self.api_key = os.environ.get("OPENROUTER_API_KEY") if not self.api_key: raise ValueError("OPENROUTER_API_KEY environment variable is 
required") @@ -1146,8 +1151,8 @@ class TogetherAIClient(BaseModelClient): Model names should be passed without the 'together-' prefix. """ - def __init__(self, model_name: str): - super().__init__(model_name) # model_name here is the actual Together AI model identifier + def __init__(self, model_name: str, prompts_dir: Optional[str] = None): + super().__init__(model_name, prompts_dir=prompts_dir) # model_name here is the actual Together AI model identifier self.api_key = os.environ.get("TOGETHER_API_KEY") if not self.api_key: raise ValueError("TOGETHER_API_KEY environment variable is required for TogetherAIClient") @@ -1198,12 +1203,13 @@ class TogetherAIClient(BaseModelClient): ############################################################################## -def load_model_client(model_id: str) -> BaseModelClient: +def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient: """ Returns the appropriate LLM client for a given model_id string. Args: model_id: The model identifier + prompts_dir: Optional path to the prompts directory. 
Example usage: client = load_model_client("claude-3-5-sonnet-20241022") @@ -1213,23 +1219,23 @@ def load_model_client(model_id: str) -> BaseModelClient: # Check for o3-pro model specifically - it needs the Responses API if lower_id == "o3-pro": - return OpenAIResponsesClient(model_id) + return OpenAIResponsesClient(model_id, prompts_dir=prompts_dir) # Check for OpenRouter first to handle prefixed models like openrouter-deepseek elif model_id.startswith("together-"): actual_model_name = model_id.split("together-", 1)[1] logger.info(f"Loading TogetherAI client for model: {actual_model_name} (original ID: {model_id})") - return TogetherAIClient(actual_model_name) + return TogetherAIClient(actual_model_name, prompts_dir=prompts_dir) elif "openrouter" in model_id.lower() or "/" in model_id: # More general check for OpenRouterClient(model_id) - return OpenRouterClient(model_id) + return OpenRouterClient(model_id, prompts_dir=prompts_dir) elif "claude" in lower_id: - return ClaudeClient(model_id) + return ClaudeClient(model_id, prompts_dir=prompts_dir) elif "gemini" in lower_id: - return GeminiClient(model_id) + return GeminiClient(model_id, prompts_dir=prompts_dir) elif "deepseek" in lower_id: - return DeepSeekClient(model_id) + return DeepSeekClient(model_id, prompts_dir=prompts_dir) else: # Default to OpenAI (for models like o3-mini, gpt-4o, etc.) 
- return OpenAIClient(model_id) + return OpenAIClient(model_id, prompts_dir=prompts_dir) ############################################################################## @@ -1249,4 +1255,4 @@ def get_visible_messages_for_power(conversation_messages, power_name): or msg["recipient"] == power_name ): visible.append(msg) - return visible # already in chronological order if appended that way \ No newline at end of file + return visible # already in chronological order if appended that way diff --git a/ai_diplomacy/diary_logic.py b/ai_diplomacy/diary_logic.py index 4ae814e..1dbe113 100644 --- a/ai_diplomacy/diary_logic.py +++ b/ai_diplomacy/diary_logic.py @@ -1,7 +1,7 @@ # ai_diplomacy/diary_logic.py import logging import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from .utils import run_llm_and_log, log_llm_response @@ -11,13 +11,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _load_prompt_file(filename: str) -> str | None: +def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None: """A local copy of the helper from agent.py to avoid circular imports.""" import os try: - current_dir = os.path.dirname(os.path.abspath(__file__)) - prompts_dir = os.path.join(current_dir, 'prompts') - filepath = os.path.join(prompts_dir, filename) + if prompts_dir: + filepath = os.path.join(prompts_dir, filename) + else: + current_dir = os.path.dirname(os.path.abspath(__file__)) + default_prompts_dir = os.path.join(current_dir, 'prompts') + filepath = os.path.join(default_prompts_dir, filename) + with open(filepath, 'r', encoding='utf-8') as f: return f.read() except Exception as e: @@ -29,6 +33,7 @@ async def run_diary_consolidation( game: "Game", log_file_path: str, entries_to_keep_unsummarized: int = 15, + prompts_dir: Optional[str] = None, ): """ Consolidate older diary entries while keeping recent ones. 
@@ -93,7 +98,7 @@ async def run_diary_consolidation( ) return - prompt_template = _load_prompt_file("diary_consolidation_prompt.txt") + prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir) if not prompt_template: logger.error( f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting" diff --git a/ai_diplomacy/game_logic.py b/ai_diplomacy/game_logic.py index 641186d..4132782 100644 --- a/ai_diplomacy/game_logic.py +++ b/ai_diplomacy/game_logic.py @@ -31,16 +31,17 @@ def serialize_agent(agent: DiplomacyAgent) -> dict: "private_diary": agent.private_diary, } -def deserialize_agent(agent_data: dict) -> DiplomacyAgent: +def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None) -> DiplomacyAgent: """Recreates an agent object from a dictionary.""" - client = load_model_client(agent_data["model_id"]) + client = load_model_client(agent_data["model_id"], prompts_dir=prompts_dir) client.max_tokens = agent_data.get("max_tokens", 16000) # Default for older saves agent = DiplomacyAgent( power_name=agent_data["power_name"], client=client, initial_goals=agent_data.get("goals", []), - initial_relationships=agent_data.get("relationships", None) + initial_relationships=agent_data.get("relationships", None), + prompts_dir=prompts_dir ) # Restore the diary. 
agent.full_private_diary = agent_data.get("full_private_diary", []) @@ -179,7 +180,7 @@ def load_game_state( run_dir: str, game_file_name: str, resume_from_phase: Optional[str] = None -) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Namespace]: +) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Optional[Namespace]]: """Loads and reconstructs the game state from a saved game file.""" game_file_path = os.path.join(run_dir, game_file_name) if not os.path.exists(game_file_path): @@ -189,6 +190,14 @@ def load_game_state( with open(game_file_path, 'r') as f: saved_game_data = json.load(f) + # Find the latest config saved in the file + run_config = None + if saved_game_data.get("phases"): + for phase in reversed(saved_game_data["phases"]): + if "config" in phase: + run_config = Namespace(**phase["config"]) + logger.info(f"Loaded run configuration from phase {phase['name']}.") + break # If resuming, find the specified phase and truncate the data after it if resume_from_phase: @@ -231,8 +240,6 @@ def load_game_state( logger.info("No previous phases found. Initializing fresh agents and history.") agents = {} # Will be created by the main loop game_history = GameHistory() - # No config to load, will use current run's args - run_config = None else: # We save the game state up to & including the current (uncompleted) phase. # So we need to grab the agent state from the previous (completed) phase. 
@@ -245,8 +252,9 @@ def load_game_state( agents = {} if 'state_agents' in last_phase_data: logger.info("Rebuilding agents from saved state...") + prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, 'prompts_dir') else None for power_name, agent_data in last_phase_data['state_agents'].items(): - agents[power_name] = deserialize_agent(agent_data) + agents[power_name] = deserialize_agent(agent_data, prompts_dir=prompts_dir_from_config) logger.info(f"Rebuilt {len(agents)} agents.") else: raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.") @@ -267,7 +275,7 @@ def load_game_state( logger.info("Game history rebuilt.") - return game, agents, game_history + return game, agents, game_history, run_config async def initialize_new_game( @@ -308,12 +316,12 @@ async def initialize_new_game( for power_name, model_id in game.power_model_map.items(): if not game.powers[power_name].is_eliminated(): try: - client = load_model_client(model_id) + client = load_model_client(model_id, prompts_dir=args.prompts_dir) client.max_tokens = model_max_tokens[power_name] - agent = DiplomacyAgent(power_name=power_name, client=client) + agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir) agents[power_name] = agent logger.info(f"Preparing initialization task for {power_name} with model {model_id}") - initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path)) + initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir)) except Exception as e: logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True) diff --git a/ai_diplomacy/initialization.py b/ai_diplomacy/initialization.py index e69add9..3c84ae4 100644 --- a/ai_diplomacy/initialization.py +++ b/ai_diplomacy/initialization.py @@ -1,6 +1,7 @@ # 
ai_diplomacy/initialization.py import logging import json +from typing import Optional # Forward declaration for type hinting, actual imports in function if complex if False: # TYPE_CHECKING @@ -18,7 +19,8 @@ async def initialize_agent_state_ext( agent: 'DiplomacyAgent', game: 'Game', game_history: 'GameHistory', - log_file_path: str + log_file_path: str, + prompts_dir: Optional[str] = None, ): """Uses the LLM to set initial goals and relationships for the agent.""" power_name = agent.power_name @@ -56,7 +58,8 @@ async def initialize_agent_state_ext( game_history=game_history, agent_goals=None, agent_relationships=None, - agent_private_diary=formatted_diary, + agent_private_diary=formatted_diary, + prompts_dir=prompts_dir, ) full_prompt = initial_prompt + "\n\n" + context diff --git a/ai_diplomacy/prompt_constructor.py b/ai_diplomacy/prompt_constructor.py index 5666333..ea6ff2c 100644 --- a/ai_diplomacy/prompt_constructor.py +++ b/ai_diplomacy/prompt_constructor.py @@ -23,6 +23,7 @@ def build_context_prompt( agent_goals: Optional[List[str]] = None, agent_relationships: Optional[Dict[str, str]] = None, agent_private_diary: Optional[str] = None, + prompts_dir: Optional[str] = None, ) -> str: """Builds the detailed context part of the prompt. @@ -35,11 +36,12 @@ def build_context_prompt( agent_goals: Optional list of agent's goals. agent_relationships: Optional dictionary of agent's relationships with other powers. agent_private_diary: Optional string of agent's private diary. + prompts_dir: Optional path to the prompts directory. Returns: A string containing the formatted context. 
""" - context_template = load_prompt("context_prompt.txt") + context_template = load_prompt("context_prompt.txt", prompts_dir=prompts_dir) # === Agent State Debug Logging === if agent_goals: @@ -112,6 +114,7 @@ def construct_order_generation_prompt( agent_goals: Optional[List[str]] = None, agent_relationships: Optional[Dict[str, str]] = None, agent_private_diary_str: Optional[str] = None, + prompts_dir: Optional[str] = None, ) -> str: """Constructs the final prompt for order generation. @@ -125,13 +128,14 @@ def construct_order_generation_prompt( agent_goals: Optional list of agent's goals. agent_relationships: Optional dictionary of agent's relationships with other powers. agent_private_diary_str: Optional string of agent's private diary. + prompts_dir: Optional path to the prompts directory. Returns: A string containing the complete prompt for the LLM. """ # Load prompts - _ = load_prompt("few_shot_example.txt") # Loaded but not used, as per original logic - instructions = load_prompt("order_instructions.txt") + _ = load_prompt("few_shot_example.txt", prompts_dir=prompts_dir) # Loaded but not used, as per original logic + instructions = load_prompt("order_instructions.txt", prompts_dir=prompts_dir) # Build the context prompt context = build_context_prompt( @@ -143,7 +147,8 @@ def construct_order_generation_prompt( agent_goals=agent_goals, agent_relationships=agent_relationships, agent_private_diary=agent_private_diary_str, + prompts_dir=prompts_dir, ) final_prompt = system_prompt + "\n\n" + context + "\n\n" + instructions - return final_prompt + return final_prompt \ No newline at end of file diff --git a/ai_diplomacy/utils.py b/ai_diplomacy/utils.py index 2a71714..5229dca 100644 --- a/ai_diplomacy/utils.py +++ b/ai_diplomacy/utils.py @@ -293,11 +293,14 @@ def normalize_and_compare_orders( # Helper to load prompt text from file relative to the expected 'prompts' dir -def load_prompt(filename: str) -> str: +def load_prompt(filename: str, prompts_dir: Optional[str] 
= None) -> str: """Helper to load prompt text from file""" - # Assuming execution from the root or that the path resolves correctly - # Consider using absolute paths or pkg_resources if needed for robustness - prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', filename) + if prompts_dir: + prompt_path = os.path.join(prompts_dir, filename) + else: + # Default behavior: relative to this file's location in the 'prompts' subdir + prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', filename) + try: with open(prompt_path, "r", encoding='utf-8') as f: # Added encoding return f.read().strip() diff --git a/experiment_runner.py b/experiment_runner.py new file mode 100644 index 0000000..f56f185 --- /dev/null +++ b/experiment_runner.py @@ -0,0 +1,474 @@ +#!/usr/bin/env python3 +""" +Experiment orchestration for Diplomacy self-play. +Launches many `lm_game` runs in parallel, captures their artefacts, +and executes a pluggable post-analysis pipeline. + +Run `python experiment_runner.py --help` for CLI details. 
+""" +from __future__ import annotations + +import argparse +import collections +import concurrent.futures +import importlib +import json +import logging +import math +import os +import shutil +import subprocess +import sys +import textwrap +import time +import multiprocessing as mp +from datetime import datetime +from pathlib import Path +from types import SimpleNamespace +from typing import Iterable, List + +# --------------------------------------------------------------------------- # +# Logging # +# --------------------------------------------------------------------------- # +LOG_FMT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s" +logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%H:%M:%S") +log = logging.getLogger("experiment_runner") + + + +# ──────────────────────────────────────────────────────────────────────────── +# Flag definitions – full, un-shortened help strings # +# ──────────────────────────────────────────────────────────────────────────── +def _add_experiment_flags(p: argparse.ArgumentParser) -> None: + p.add_argument( + "--experiment_dir", + type=Path, + required=True, + help=( + "Directory that will hold all experiment artefacts. " + "A 'runs/' sub-folder is created for individual game runs and an " + "'analysis/' folder for aggregated outputs. Must be writable." + ), + ) + p.add_argument( + "--iterations", + type=int, + default=1, + help=( + "Number of lm_game instances to launch for this experiment. " + "Each instance gets its own sub-directory under runs/." + ), + ) + p.add_argument( + "--parallel", + type=int, + default=1, + help=( + "Maximum number of game instances to run concurrently. " + "Uses a ProcessPoolExecutor under the hood." + ), + ) + p.add_argument( + "--analysis_modules", + type=str, + default="summary", + help=( + "Comma-separated list of analysis module names to execute after all " + "runs finish. Modules are imported from " + "'experiment_runner.analysis.' 
and must expose " + "run(experiment_dir: Path, ctx: dict)." + ), + ) + p.add_argument( + "--critical_state_base_run", + type=Path, + default=None, + help=( + "Path to an *existing* run directory produced by a previous lm_game " + "execution. When supplied, every iteration resumes from that " + "snapshot using lm_game's --critical_state_analysis_dir mechanism." + ), + ) + p.add_argument( + "--seed_base", + type=int, + default=42, + help=( + "Base RNG seed. Run i will receive seed = seed_base + i. " + "Forwarded to lm_game via its --seed flag (you must have added that " + "flag to lm_game for deterministic behaviour)." + ), + ) + + +def _add_lm_game_flags(p: argparse.ArgumentParser) -> None: + # ---- all flags copied verbatim from lm_game.parse_arguments() ---- + p.add_argument( + "--resume_from_phase", + type=str, + default="", + help=( + "Phase to resume from (e.g., 'S1902M'). Requires --run_dir. " + "IMPORTANT: This option clears any existing phase results ahead of " + "& including the specified resume phase." + ), + ) + p.add_argument( + "--end_at_phase", + type=str, + default="", + help="Phase to end the simulation after (e.g., 'F1905M').", + ) + p.add_argument( + "--max_year", + type=int, + default=1910, # Increased default in lm_game + help="Maximum year to simulate. The game will stop once this year is reached.", + ) + p.add_argument( + "--num_negotiation_rounds", + type=int, + default=0, + help="Number of negotiation rounds per phase.", + ) + p.add_argument( + "--models", + type=str, + default="", + help=( + "Comma-separated list of model names to assign to powers in order. " + "The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY." 
+ ), + ) + p.add_argument( + "--planning_phase", + action="store_true", + help="Enable the planning phase for each power to set strategic directives.", + ) + p.add_argument( + "--max_tokens", + type=int, + default=16000, + help="Maximum number of new tokens to generate per LLM call (default: 16000).", + ) + p.add_argument( + "--max_tokens_per_model", + type=str, + default="", + help=( + "Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, " + "FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens." + ), + ) + p.add_argument( + "--prompts_dir", + type=str, + default=None, + help=( + "Path to the directory containing prompt files. " + "Defaults to the packaged prompts directory." + ), + ) + + +# ──────────────────────────────────────────────────────────────────────────── +# One combined parser for banner printing # +# ──────────────────────────────────────────────────────────────────────────── +def _build_full_parser() -> argparse.ArgumentParser: + fp = argparse.ArgumentParser( + prog="experiment_runner.py", + formatter_class=lambda prog: argparse.RawTextHelpFormatter( + prog, max_help_position=45 + ), + description=( + "Batch-runner for Diplomacy self-play experiments. " + "All lm_game flags are accepted here as-is; they are validated " + "before any game runs start." 
+ ), + ) + _add_experiment_flags(fp) + _add_lm_game_flags(fp) + return fp + + +# ──────────────────────────────────────────────────────────────────────────── +# Robust parsing that always shows *full* help on error # +# ──────────────────────────────────────────────────────────────────────────── +def _parse_cli() -> tuple[argparse.Namespace, list[str], argparse.Namespace]: + full_parser = _build_full_parser() + + # Show full banner when no args + if len(sys.argv) == 1: + full_parser.print_help(sys.stderr) + sys.exit(2) + + # Show full banner on explicit help + if any(tok in ("-h", "--help") for tok in sys.argv[1:]): + full_parser.print_help(sys.stderr) + sys.exit(0) + + # Sub-parsers for separating experiment vs game flags + class _ErrParser(argparse.ArgumentParser): + def error(self, msg): + full_parser.print_help(sys.stderr) + self.exit(2, f"{self.prog}: error: {msg}\n") + + exp_parser = _ErrParser(add_help=False) + game_parser = _ErrParser(add_help=False) + _add_experiment_flags(exp_parser) + _add_lm_game_flags(game_parser) + + # Split argv tokens by flag ownership + argv = sys.argv[1:] + exp_flag_set = {opt for a in exp_parser._actions for opt in a.option_strings} + + exp_tok, game_tok, i = [], [], 0 + while i < len(argv): + tok = argv[i] + if tok in exp_flag_set: + exp_tok.append(tok) + action = exp_parser._option_string_actions[tok] + needs_val = ( + action.nargs is None # default: exactly one value + or (isinstance(action.nargs, int) and action.nargs > 0) + or action.nargs in ("+", "*", "?") # variable-length cases + ) + if needs_val: + exp_tok.append(argv[i + 1]) + i += 2 + else: # store_true / store_false + i += 1 + + else: + game_tok.append(tok) + i += 1 + + exp_args = exp_parser.parse_args(exp_tok) + game_args = game_parser.parse_args(game_tok) + return exp_args, game_tok, game_args + + +# --------------------------------------------------------------------------- # +# Helpers # +# 
--------------------------------------------------------------------------- #
+_RunInfo = collections.namedtuple(
+    "_RunInfo", "index run_dir seed cmd_line returncode elapsed_s"
+)
+
+
+def _mk_run_dir(exp_dir: Path, idx: int) -> Path:
+    run_dir = exp_dir / "runs" / f"run_{idx:05d}"
+    # Just ensure it exists; don't raise if it already does.
+    run_dir.mkdir(parents=True, exist_ok=True)
+    return run_dir
+
+
+def _dump_seed(seed: int, run_dir: Path) -> None:
+    seed_file = run_dir / "seed.txt"
+    if not seed_file.exists():
+        seed_file.write_text(str(seed))
+
+
+def _build_cmd(
+    lm_game_script: Path,
+    base_cli: List[str],
+    run_dir: Path,
+    seed: int,
+    critical_base: Path | None,
+    resume_from_phase: str,
+) -> List[str]:
+    """
+    Returns a list suitable for subprocess.run([...]).
+    """
+    cmd = [sys.executable, str(lm_game_script)]
+
+    # Forward user CLI
+    cmd.extend(base_cli)
+
+    # Per-run overrides. In critical-state mode lm_game reads the completed
+    # base run via --run_dir and writes this run's new artefacts to
+    # --critical_state_analysis_dir; otherwise the run owns its --run_dir.
+    # (Previously --run_dir was passed twice, relying on argparse's
+    # last-occurrence-wins behaviour.)
+    if critical_base:
+        cmd.extend(["--run_dir", str(critical_base)])
+        cmd.extend(["--critical_state_analysis_dir", str(run_dir)])
+    else:
+        cmd.extend(["--run_dir", str(run_dir)])
+
+    # lm_game defines --seed_base (not --seed); forward the per-run seed there.
+    cmd.extend(["--seed_base", str(seed)])
+
+    if resume_from_phase:
+        cmd.extend(["--resume_from_phase", resume_from_phase])
+
+    return cmd
+
+
+def _launch_one(args) -> _RunInfo:
+    """
+    Worker executed inside a ProcessPool; runs one game via subprocess. 
+ """ + ( + idx, + lm_game_script, + base_cli, + run_dir, + seed, + critical_base, + resume_phase, + ) = args + + cmd = _build_cmd( + lm_game_script, base_cli, run_dir, seed, critical_base, resume_phase + ) + start = time.perf_counter() + log.debug("Run %05d: CMD = %s", idx, " ".join(cmd)) + + # Write out full command for traceability + (run_dir / "command.txt").write_text(" ".join(cmd)) + + try: + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + check=False, + ) + (run_dir / "console.log").write_text(result.stdout) + rc = result.returncode + except Exception as exc: # noqa: broad-except + (run_dir / "console.log").write_text(f"Exception launching run:\n{exc}\n") + rc = 1 + + elapsed = time.perf_counter() - start + return _RunInfo(idx, run_dir, seed, " ".join(cmd), rc, elapsed) + + +def _load_analysis_fns(module_names: Iterable[str]): + """ + Dynamically import analysis modules. + Each module must expose `run(experiment_dir: Path, cfg: dict)`. 
+ """ + for name in module_names: + mod_name = f"experiment_runner.analysis.{name.strip()}" + try: + mod = importlib.import_module(mod_name) + except ModuleNotFoundError as e: + log.warning("Analysis module %s not found (%s) – skipping", mod_name, e) + continue + + if not hasattr(mod, "run"): + log.warning("%s has no `run()` function – skipping", mod_name) + continue + yield mod.run + + +# --------------------------------------------------------------------------- # +# Main driver # +# --------------------------------------------------------------------------- # +def main() -> None: + exp_args, leftover_cli, game_args = _parse_cli() + + exp_dir: Path = exp_args.experiment_dir.expanduser().resolve() + if exp_dir.exists(): + log.info("Appending to existing experiment: %s", exp_dir) + exp_dir.mkdir(parents=True, exist_ok=True) + + # Persist experiment-level config + cfg_path = exp_dir / "config.json" + if not cfg_path.exists(): # ← new guard + with cfg_path.open("w", encoding="utf-8") as fh: + json.dump( + {"experiment": vars(exp_args), + "lm_game": vars(game_args), + "forwarded_cli": leftover_cli}, + fh, indent=2, default=str, + ) + log.info("Config saved to %s", cfg_path) + else: + log.info("Config already exists – leaving unchanged") + + + log.info("Config saved to %s", cfg_path) + + # ------------------------------------------------------------------ # + # Launch games # + # ------------------------------------------------------------------ # + lm_game_script = Path(__file__).parent / "lm_game.py" + if not lm_game_script.exists(): + log.error("lm_game.py not found at %s – abort", lm_game_script) + sys.exit(1) + + run_args = [] + for i in range(exp_args.iterations): + run_dir = _mk_run_dir(exp_dir, i) + seed = exp_args.seed_base + i + _dump_seed(seed, run_dir) + + run_args.append( + ( + i, lm_game_script, leftover_cli, run_dir, seed, + exp_args.critical_state_base_run, + game_args.resume_from_phase, + ) + ) + + + log.info( + "Launching %d runs (max %d parallel, 
critical_state=%s)", + exp_args.iterations, + exp_args.parallel, + bool(exp_args.critical_state_base_run), + ) + + runs_meta: list[_RunInfo] = [] + with concurrent.futures.ProcessPoolExecutor( + max_workers=exp_args.parallel, + mp_context=mp.get_context("spawn"), + ) as pool: + for res in pool.map(_launch_one, run_args): + runs_meta.append(res) + status = "OK" if res.returncode == 0 else f"RC={res.returncode}" + log.info( + "run_%05d finished in %.1fs %s", res.index, res.elapsed_s, status + ) + + # Persist per-run status summary + summary_path = exp_dir / "runs_summary.json" + with open(summary_path, "w", encoding="utf-8") as fh: + json.dump([res._asdict() for res in runs_meta], fh, indent=2, default=str) + log.info("Run summary written → %s", summary_path) + + # ------------------------------------------------------------------ # + # Post-analysis pipeline # + # ------------------------------------------------------------------ # + mods = list(_load_analysis_fns(exp_args.analysis_modules.split(","))) + if not mods: + log.warning("No analysis modules loaded – done.") + return + + analysis_root = exp_dir / "analysis" + if analysis_root.exists(): + shutil.rmtree(analysis_root) # ← wipes old outputs + analysis_root.mkdir(exist_ok=True) + + # Collect common context + ctx: dict = { + "exp_dir": str(exp_dir), + "runs": [str(r.run_dir) for r in runs_meta], + "critical_state_base": str(exp_args.critical_state_base_run or ""), + "resume_from_phase": game_args.resume_from_phase, + } + + for fn in mods: + name = fn.__module__.rsplit(".", 1)[-1] + log.info("Running analysis module: %s", name) + try: + fn(exp_dir, ctx) + log.info("✓ %s complete", name) + except Exception as exc: # noqa: broad-except + log.exception("Analysis module %s failed: %s", name, exc) + + log.info("Experiment finished – artefacts in %s", exp_dir) + + +if __name__ == "__main__": + main() diff --git a/experiment_runner/__init__.py b/experiment_runner/__init__.py new file mode 100644 index 0000000..e69de29 
diff --git a/experiment_runner/analysis/__init__.py b/experiment_runner/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiment_runner/analysis/critical_state.py b/experiment_runner/analysis/critical_state.py new file mode 100644 index 0000000..771de4c --- /dev/null +++ b/experiment_runner/analysis/critical_state.py @@ -0,0 +1,60 @@ +""" +Extracts the board state *before* and *after* a critical phase +for every critical-analysis run produced by experiment_runner. + +Each run must have: +• lmvsgame.json – containing a phase named ctx["resume_from_phase"] + +Outputs live in analysis/critical_state/_{before|after}.json +""" +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Dict, List + +log = logging.getLogger(__name__) + + +def _phase_by_name(phases: List[dict], name: str) -> dict | None: + for ph in phases: + if ph["state"]["name"] == name: + return ph + return None + + +def run(exp_dir: Path, ctx: Dict) -> None: + resume_phase = ctx.get("resume_from_phase") + if not resume_phase: + log.info("critical_state: --resume_from_phase not supplied – skipping") + return + + out_dir = exp_dir / "analysis" / "critical_state" + out_dir.mkdir(parents=True, exist_ok=True) + + for run_dir in (exp_dir / "runs").iterdir(): + game_json = run_dir / "lmvsgame.json" + if not game_json.exists(): + continue + + with game_json.open("r") as fh: + game = json.load(fh) + phases = game.get("phases", []) + if not phases: + continue + + before = _phase_by_name(phases, resume_phase) + after = phases[-1] if phases else None + if before is None or after is None: + log.warning("Run %s missing expected phases – skipped", run_dir.name) + continue + + (out_dir / f"{run_dir.name}_before.json").write_text( + json.dumps(before, indent=2) + ) + (out_dir / f"{run_dir.name}_after.json").write_text( + json.dumps(after, indent=2) + ) + + log.info("critical_state: snapshots written to %s", out_dir) diff --git 
a/experiment_runner/analysis/summary.py b/experiment_runner/analysis/summary.py new file mode 100644 index 0000000..e26a687 --- /dev/null +++ b/experiment_runner/analysis/summary.py @@ -0,0 +1,177 @@ +""" +Aggregates results across all runs and writes: + +• analysis/aggregated_results.csv +• analysis/score_summary_by_power.csv +• analysis/results_summary.png (if matplotlib available) +""" +from __future__ import annotations + +import json +import logging +import re +from pathlib import Path +from typing import Any, Dict, List + +import pandas as pd + +log = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- # +# Helpers copied (and lightly cleaned) from run_games.py # +# --------------------------------------------------------------------------- # +def _extract_diplomacy_results(game_json: Path) -> Dict[str, Dict[str, Any]]: + with game_json.open("r", encoding="utf-8") as fh: + gd = json.load(fh) + + phases = gd.get("phases", []) + if not phases: + raise ValueError("no phases") + + first_state = phases[0]["state"] + last_state = phases[-1]["state"] + + powers = ( + list(first_state.get("homes", {}).keys()) + or list(first_state.get("centers", {}).keys()) + ) + if not powers: + raise ValueError("cannot determine powers") + + scs_to_win = 18 + solo_winner = next( + (p for p, sc in last_state["centers"].items() if len(sc) >= scs_to_win), None + ) + + results: Dict[str, Dict[str, Any]] = {} + for p in powers: + sc_count = len(last_state["centers"].get(p, [])) + units = len(last_state["units"].get(p, [])) + + if solo_winner: + if p == solo_winner: + outcome, cat = f"Won solo", "Solo Win" + else: + outcome, cat = f"Lost to {solo_winner}", "Loss" + elif sc_count == 0 and units == 0: + outcome, cat = "Eliminated", "Eliminated" + else: + outcome, cat = "Ongoing/Draw", "Ongoing/Abandoned/Draw" + + results[p] = { + "OutcomeCategory": cat, + "StatusDetail": outcome, + "SupplyCenters": sc_count, + "LastPhase": 
last_state["name"], + } + return results + + +# simplistic "Diplobench" scoring from previous discussion ------------------ # +def _year(name: str) -> int | None: + m = re.search(r"(\d{4})", name) + return int(m.group(1)) if m else None + + +def _score_game(game_json: Path) -> Dict[str, int]: + with open(game_json, "r") as fh: + game = json.load(fh) + phases = game["phases"] + if not phases: + return {} + + start = _year(phases[0]["state"]["name"]) + end_year = _year(phases[-1]["state"]["name"]) or start + max_turns = (end_year - start + 1) if start is not None else len(phases) + + last_state = phases[-1]["state"] + solo_winner = next( + (p for p, scs in last_state["centers"].items() if len(scs) >= 18), None + ) + + elim_turn: Dict[str, int | None] = {} + for p in last_state["centers"].keys(): + e_turn = None + for idx, ph in enumerate(phases): + if not ph["state"]["centers"].get(p, []): + yr = _year(ph["state"]["name"]) or 0 + e_turn = (yr - start + 1) if start is not None else idx + 1 + break + elim_turn[p] = e_turn + + scores: Dict[str, int] = {} + for p, scs in last_state["centers"].items(): + if p == solo_winner: + win_turn = (end_year - start + 1) if start is not None else len(phases) + scores[p] = max_turns + 17 + (max_turns - win_turn) + elif solo_winner: + # losers in a solo game + yr_win = _year(phases[-1]["state"]["name"]) or end_year + turn_win = (yr_win - start + 1) if start is not None else len(phases) + scores[p] = turn_win + else: + if elim_turn[p] is None: + scores[p] = max_turns + len(scs) + else: + scores[p] = elim_turn[p] + return scores + + +# --------------------------------------------------------------------------- # +# Public entry point # +# --------------------------------------------------------------------------- # +def run(exp_dir: Path, ctx: dict): # pylint: disable=unused-argument + analysis_dir = exp_dir / "analysis" + analysis_dir.mkdir(exist_ok=True) + + rows: List[Dict[str, Any]] = [] + for run_dir in (exp_dir / 
"runs").iterdir(): + game_json = run_dir / "lmvsgame.json" + if not game_json.exists(): + continue + + gid = run_dir.name + try: + res = _extract_diplomacy_results(game_json) + except Exception as e: # noqa: broad-except + log.warning("Could not parse %s (%s)", game_json, e) + continue + + scores = _score_game(game_json) + for pwr, info in res.items(): + out = {"GameID": gid, "Power": pwr, **info, "Score": scores.get(pwr, None)} + rows.append(out) + + if not rows: + log.warning("summary: no parsable runs found") + return + + df = pd.DataFrame(rows) + out_csv = analysis_dir / "aggregated_results.csv" + df.to_csv(out_csv, index=False) + + summary = ( + df.groupby("Power")["Score"] + .agg(["mean", "median", "count"]) + .reset_index() + .rename(columns={"count": "n"}) + ) + summary.to_csv(analysis_dir / "score_summary_by_power.csv", index=False) + + log.info("summary: wrote %s rows to %s", len(df), out_csv) + + # Optional charts + try: + import matplotlib.pyplot as plt + import seaborn as sns + + sns.set_style("whitegrid") + plt.figure(figsize=(10, 7)) + sns.boxplot(x="Power", y="SupplyCenters", data=df, palette="pastel") + plt.title("Supply-center distribution") + plt.savefig(analysis_dir / "results_summary.png", dpi=150) + plt.close() + log.info("summary: chart saved") + except Exception as e: # noqa: broad-except + log.debug("Chart generation skipped (%s)", e) diff --git a/lm_game.py b/lm_game.py index ed1e935..5da9678 100644 --- a/lm_game.py +++ b/lm_game.py @@ -55,6 +55,12 @@ def parse_arguments(): default="", help="Directory for results. If it exists, the game resumes. If not, it's created. 
Defaults to a new timestamped directory.", ) + parser.add_argument( + "--output", # alias for back compatibility + dest="run_dir", # write to the same variable as --run_dir + type=str, + help=argparse.SUPPRESS # hides it from `--help` + ) parser.add_argument( "--critical_state_analysis_dir", type=str, @@ -105,12 +111,24 @@ def parse_arguments(): default=16000, help="Maximum number of new tokens to generate per LLM call (default: 16000)." ) + parser.add_argument( + "--seed_base", + type=int, + default=42, + help="RNG seed placeholder for compatibility with experiment_runner. Currently unused." + ) parser.add_argument( "--max_tokens_per_model", type=str, default="", help="Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens." ) + parser.add_argument( + "--prompts_dir", + type=str, + default=None, + help="Path to the directory containing prompt files. Defaults to the packaged prompts directory." + ) return parser.parse_args() @@ -179,8 +197,17 @@ async def main(): if is_resuming: try: # When resuming, we load the state and also the config from the last saved phase. - # We will continue to use THAT config, ignoring the current script's args. - game, agents, game_history = load_game_state(run_dir, game_file_name, args.resume_from_phase) + game, agents, game_history, loaded_run_config = load_game_state(run_dir, game_file_name, args.resume_from_phase) + if loaded_run_config: + # Use the saved config, but allow current CLI args to override control-flow parameters + run_config = loaded_run_config + run_config.run_dir = args.run_dir + run_config.critical_state_analysis_dir = args.critical_state_analysis_dir + run_config.resume_from_phase = args.resume_from_phase + run_config.end_at_phase = args.end_at_phase + # If prompts_dir is specified now, it overrides the saved one. 
+ if args.prompts_dir is not None: + run_config.prompts_dir = args.prompts_dir logger.info(f"Successfully resumed game from phase: {game.get_current_phase()}.") except (FileNotFoundError, ValueError) as e: logger.error(f"Could not resume game: {e}. Starting a new game instead.") @@ -204,7 +231,7 @@ async def main(): if year_int > run_config.max_year: logger.info(f"Reached max year {run_config.max_year}, stopping simulation.") break - if run_config.end_at_phase and game.phases and game.phases[-1].name == run_config.end_at_phase: + if run_config.end_at_phase and current_phase == run_config.end_at_phase: logger.info(f"Reached end phase {run_config.end_at_phase}, stopping simulation.") break @@ -289,7 +316,7 @@ async def main(): # Diary Consolidation if current_short_phase.startswith("S") and current_short_phase.endswith("M"): consolidation_tasks = [ - run_diary_consolidation(agent, game, llm_log_file_path) + run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir) for agent in agents.values() if not game.powers[agent.power_name].is_eliminated() ] if consolidation_tasks: diff --git a/requirements.txt b/requirements.txt index c544085..9e8eb8d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,5 @@ anthropic google-genai json-repair together +matplotlib +seaborn \ No newline at end of file diff --git a/run.sh b/run.sh index c62769d..1eb4d17 100755 --- a/run.sh +++ b/run.sh @@ -2,6 +2,7 @@ python3 lm_game.py \ --max_year 1901 \ - --num_negotiation_rounds 0 \ + --num_negotiation_rounds 1 \ --models "openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17" \ - --max_tokens_per_model 
16000,16000,16000,16000,16000,16000,16000 \ No newline at end of file + --max_tokens_per_model 16000,16000,16000,16000,16000,16000,16000 \ + --prompts_dir "ai_diplomacy/prompts" \ No newline at end of file