Merge pull request #42 from EveryInc/experiment-runner

Experiment runner
This commit is contained in:
Tyler Marques 2025-06-22 09:34:55 -07:00 committed by GitHub
commit 617ccdb1a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 943 additions and 80 deletions

View file

@ -236,13 +236,92 @@ python lm_game.py --max_year 1910 --planning_phase --num_negotiation_rounds 2
# Custom model assignment (order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY)
python lm_game.py --models "claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022"
# Output to specific file
python lm_game.py --output results/my_game.json
# Run until game completion or specific year
python lm_game.py --num_negotiation_rounds 2 --planning_phase
# Write all artefacts to a chosen directory (auto-resumes if it already exists)
python lm_game.py --run_dir results/game_run_001
# Resume an interrupted game from a specific phase
python lm_game.py --run_dir results/game_run_001 --resume_from_phase S1902M
# Critical-state analysis: resume from an existing run but save new results elsewhere
python lm_game.py \
--run_dir results/game_run_001 \
--critical_state_analysis_dir results/critical_analysis_001 \
--resume_from_phase F1903M
# End the simulation after a particular phase regardless of remaining years
python lm_game.py --run_dir results/game_run_002 --end_at_phase F1905M
# Set the global max_tokens generation limit
python lm_game.py --run_dir results/game_run_003 --max_tokens 8000
# Per-model token limits (AU,EN,FR,GE,IT,RU,TR)
python lm_game.py --run_dir results/game_run_004 \
--max_tokens_per_model "8000,8000,16000,8000,8000,16000,8000"
# Use a custom prompts directory
python lm_game.py --run_dir results/game_run_005 --prompts_dir ./prompts/my_variants
```
### Running Batch Experiments with **`experiment_runner.py`**
`experiment_runner.py` is a lightweight orchestrator: it spins up many `lm_game.py` runs in parallel, gathers their artefacts under one *experiment directory*, and then executes the analysis modules you specify.
All flags that belong to **`lm_game.py`** can be passed straight through; the runner validates them and forwards them unchanged to every game instance.
---
#### Examples
```bash
# Run 10 independent games (iterations) in parallel, using a custom prompts dir
# and a single model (GPT-4o) for all seven powers.
python3 experiment_runner.py \
--experiment_dir "results/exp001" \
--iterations 10 \
--parallel 10 \
--max_year 1905 \
--num_negotiation_rounds 0 \
--prompts_dir "ai_diplomacy/prompts" \
--models "gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o"
# Critical-state analysis: resume every run from W1901A (taken from an existing
# base run) and stop after S1902M. Two analysis modules will be executed:
# • summary → aggregated results & scores
# • critical_state → before/after snapshots around the critical phase
python3 experiment_runner.py \
--experiment_dir "results/exp002" \
--iterations 10 \
--parallel 10 \
--resume_from_phase W1901A \
--end_at_phase S1902M \
--num_negotiation_rounds 0 \
--critical_state_base_run "results/test1" \
--prompts_dir "ai_diplomacy/prompts" \
--analysis_modules "summary,critical_state" \
--models "gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o"
```
*(Any other `lm_game.py` flags—`--planning_phase`, `--max_tokens`, etc.—can be added exactly where you'd use them on a single-game run.)*
---
#### Experiment-runner-specific arguments
| Flag | Type / Default | Description |
| --------------------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `--experiment_dir` **(required)** | `Path` | Root folder for the experiment; sub-folders `runs/` and `analysis/` are managed automatically. Re-running with the same directory will **resume** existing runs and regenerate analysis. |
| `--iterations` | `int`, default `1` | How many individual games to launch for this experiment. |
| `--parallel` | `int`, default `1` | Max number of games to execute concurrently (uses a process pool). |
| `--analysis_modules` | `str`, default `"summary"` | Comma-separated list of analysis modules to run after all games finish. Modules are imported from `experiment_runner.analysis.<name>` and must provide `run(experiment_dir, ctx)`. |
| `--critical_state_base_run` | `Path`, optional | Path to an **existing** `run_dir` produced by a previous `lm_game` run. Each iteration resumes from that snapshot; new artefacts are written under the current `experiment_dir`. |
| `--seed_base`                     | `int`, default `42`        | Base random seed. Run *i* receives seed = `seed_base + i`, enabling reproducible batches.                                                                                                 |
*(All other command-line flags belong to `lm_game.py` and are forwarded unchanged.)*
### Environment Setup
Create a `.env` file with your API keys:

View file

@ -22,13 +22,17 @@ ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUS
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
# == New: Helper function to load prompt files reliably ==
def _load_prompt_file(filename: str) -> Optional[str]:
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
"""Loads a prompt template from the prompts directory."""
try:
# Construct path relative to this file's location
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, 'prompts')
filepath = os.path.join(prompts_dir, filename)
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
# Construct path relative to this file's location
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, 'prompts')
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, 'r', encoding='utf-8') as f:
return f.read()
except FileNotFoundError:
@ -50,6 +54,7 @@ class DiplomacyAgent:
client: BaseModelClient,
initial_goals: Optional[List[str]] = None,
initial_relationships: Optional[Dict[str, str]] = None,
prompts_dir: Optional[str] = None,
):
"""
Initializes the DiplomacyAgent.
@ -60,12 +65,14 @@ class DiplomacyAgent:
initial_goals: An optional list of initial strategic goals.
initial_relationships: An optional dictionary mapping other power names to
relationship statuses (e.g., 'ALLY', 'ENEMY', 'NEUTRAL').
prompts_dir: Optional path to the prompts directory.
"""
if power_name not in ALL_POWERS:
raise ValueError(f"Invalid power name: {power_name}. Must be one of {ALL_POWERS}")
self.power_name: str = power_name
self.client: BaseModelClient = client
self.prompts_dir: Optional[str] = prompts_dir
# Initialize goals as empty list, will be populated by initialize_agent_state
self.goals: List[str] = initial_goals if initial_goals is not None else []
# Initialize relationships to Neutral if not provided
@ -85,16 +92,21 @@ class DiplomacyAgent:
# Get the directory containing the current file (agent.py)
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct path relative to the current file's directory
prompts_dir = os.path.join(current_dir, "prompts")
power_prompt_filename = os.path.join(prompts_dir, f"{power_name.lower()}_system_prompt.txt")
default_prompt_filename = os.path.join(prompts_dir, "system_prompt.txt")
default_prompts_path = os.path.join(current_dir, "prompts")
power_prompt_filename = f"{power_name.lower()}_system_prompt.txt"
default_prompt_filename = "system_prompt.txt"
system_prompt_content = load_prompt(power_prompt_filename)
# Use the provided prompts_dir if available, otherwise use the default
prompts_path_to_use = self.prompts_dir if self.prompts_dir else default_prompts_path
power_prompt_filepath = os.path.join(prompts_path_to_use, power_prompt_filename)
default_prompt_filepath = os.path.join(prompts_path_to_use, default_prompt_filename)
system_prompt_content = load_prompt(power_prompt_filepath, prompts_dir=self.prompts_dir)
if not system_prompt_content:
logger.warning(f"Power-specific prompt '{power_prompt_filename}' not found or empty. Loading default system prompt.")
# system_prompt_content = load_prompt("system_prompt.txt")
system_prompt_content = load_prompt(default_prompt_filename)
logger.warning(f"Power-specific prompt '{power_prompt_filepath}' not found or empty. Loading default system prompt.")
system_prompt_content = load_prompt(default_prompt_filepath, prompts_dir=self.prompts_dir)
else:
logger.info(f"Loaded power-specific system prompt for {power_name}.")
# ----------------------------------------------------
@ -416,7 +428,7 @@ class DiplomacyAgent:
try:
# Load the template file but safely preprocess it first
prompt_template_content = _load_prompt_file('negotiation_diary_prompt.txt')
prompt_template_content = _load_prompt_file('negotiation_diary_prompt.txt', prompts_dir=self.prompts_dir)
if not prompt_template_content:
logger.error(f"[{self.power_name}] Could not load negotiation_diary_prompt.txt. Skipping diary entry.")
success_status = "Failure: Prompt file not loaded"
@ -611,7 +623,7 @@ class DiplomacyAgent:
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
# Load the template but we'll use it carefully with string interpolation
prompt_template = _load_prompt_file('order_diary_prompt.txt')
prompt_template = _load_prompt_file('order_diary_prompt.txt', prompts_dir=self.prompts_dir)
if not prompt_template:
logger.error(f"[{self.power_name}] Could not load order_diary_prompt.txt. Skipping diary entry.")
return
@ -756,7 +768,7 @@ class DiplomacyAgent:
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
# Load the template
prompt_template = _load_prompt_file('phase_result_diary_prompt.txt')
prompt_template = _load_prompt_file('phase_result_diary_prompt.txt', prompts_dir=self.prompts_dir)
if not prompt_template:
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
return
@ -859,7 +871,7 @@ class DiplomacyAgent:
try:
# 1. Construct the prompt using the dedicated state update prompt file
prompt_template = _load_prompt_file('state_update_prompt.txt')
prompt_template = _load_prompt_file('state_update_prompt.txt', prompts_dir=self.prompts_dir)
if not prompt_template:
logger.error(f"[{power_name}] Could not load state_update_prompt.txt. Skipping state update.")
return
@ -893,6 +905,7 @@ class DiplomacyAgent:
agent_goals=self.goals,
agent_relationships=self.relationships,
agent_private_diary=formatted_diary, # Pass formatted diary
prompts_dir=self.prompts_dir,
)
# Add previous phase summary to the information provided to the LLM

View file

@ -44,10 +44,11 @@ class BaseModelClient:
- get_conversation_reply(power_name, conversation_so_far, game_phase) -> str
"""
def __init__(self, model_name: str):
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
self.model_name = model_name
self.prompts_dir = prompts_dir
# Load a default initially, can be overwritten by set_system_prompt
self.system_prompt = load_prompt("system_prompt.txt")
self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir)
self.max_tokens = 16000 # default unless overridden
def set_system_prompt(self, content: str):
@ -97,6 +98,7 @@ class BaseModelClient:
agent_goals=agent_goals,
agent_relationships=agent_relationships,
agent_private_diary_str=agent_private_diary_str,
prompts_dir=self.prompts_dir,
)
raw_response = ""
@ -423,7 +425,7 @@ class BaseModelClient:
agent_private_diary_str: Optional[str] = None, # Added
) -> str:
instructions = load_prompt("planning_instructions.txt")
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
context = self.build_context_prompt(
game,
@ -434,6 +436,7 @@ class BaseModelClient:
agent_goals=agent_goals,
agent_relationships=agent_relationships,
agent_private_diary=agent_private_diary_str, # Pass diary string
prompts_dir=self.prompts_dir,
)
return context + "\n\n" + instructions
@ -451,7 +454,7 @@ class BaseModelClient:
agent_relationships: Optional[Dict[str, str]] = None,
agent_private_diary_str: Optional[str] = None, # Added
) -> str:
instructions = load_prompt("conversation_instructions.txt")
instructions = load_prompt("conversation_instructions.txt", prompts_dir=self.prompts_dir)
context = build_context_prompt(
game,
@ -462,6 +465,7 @@ class BaseModelClient:
agent_goals=agent_goals,
agent_relationships=agent_relationships,
agent_private_diary=agent_private_diary_str, # Pass diary string
prompts_dir=self.prompts_dir,
)
# Get recent messages targeting this power to prioritize responses
@ -699,7 +703,7 @@ class BaseModelClient:
"""
logger.info(f"Client generating strategic plan for {power_name}...")
planning_instructions = load_prompt("planning_instructions.txt")
planning_instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
if not planning_instructions:
logger.error("Could not load planning_instructions.txt! Cannot generate plan.")
return "Error: Planning instructions not found."
@ -718,6 +722,7 @@ class BaseModelClient:
agent_goals=agent_goals,
agent_relationships=agent_relationships,
agent_private_diary=agent_private_diary_str, # Pass diary string
prompts_dir=self.prompts_dir,
)
full_prompt = f"{context_prompt}\n\n{planning_instructions}"
@ -772,8 +777,8 @@ class OpenAIClient(BaseModelClient):
For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
"""
def __init__(self, model_name: str):
super().__init__(model_name)
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir)
self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
@ -819,8 +824,8 @@ class ClaudeClient(BaseModelClient):
For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc.
"""
def __init__(self, model_name: str):
super().__init__(model_name)
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir)
self.client = AsyncAnthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
@ -861,8 +866,8 @@ class GeminiClient(BaseModelClient):
For 'gemini-1.5-flash' or other Google Generative AI models.
"""
def __init__(self, model_name: str):
super().__init__(model_name)
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir)
# Configure and get the model (corrected initialization)
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
@ -905,8 +910,8 @@ class DeepSeekClient(BaseModelClient):
For DeepSeek R1 'deepseek-reasoner'
"""
def __init__(self, model_name: str):
super().__init__(model_name)
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir)
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
self.client = AsyncDeepSeekOpenAI(
api_key=self.api_key,
@ -961,8 +966,8 @@ class OpenAIResponsesClient(BaseModelClient):
This client makes direct HTTP requests to the v1/responses endpoint.
"""
def __init__(self, model_name: str):
super().__init__(model_name)
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir)
self.api_key = os.environ.get("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("OPENAI_API_KEY environment variable is required")
@ -1068,14 +1073,14 @@ class OpenRouterClient(BaseModelClient):
For OpenRouter models, with default being 'openrouter/quasar-alpha'
"""
def __init__(self, model_name: str = "openrouter/quasar-alpha"):
def __init__(self, model_name: str = "openrouter/quasar-alpha", prompts_dir: Optional[str] = None):
# Allow specifying just the model identifier or the full path
if not model_name.startswith("openrouter/") and "/" not in model_name:
model_name = f"openrouter/{model_name}"
if model_name.startswith("openrouter-"):
model_name = model_name.replace("openrouter-", "")
super().__init__(model_name)
super().__init__(model_name, prompts_dir=prompts_dir)
self.api_key = os.environ.get("OPENROUTER_API_KEY")
if not self.api_key:
raise ValueError("OPENROUTER_API_KEY environment variable is required")
@ -1146,8 +1151,8 @@ class TogetherAIClient(BaseModelClient):
Model names should be passed without the 'together-' prefix.
"""
def __init__(self, model_name: str):
super().__init__(model_name) # model_name here is the actual Together AI model identifier
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir) # model_name here is the actual Together AI model identifier
self.api_key = os.environ.get("TOGETHER_API_KEY")
if not self.api_key:
raise ValueError("TOGETHER_API_KEY environment variable is required for TogetherAIClient")
@ -1198,12 +1203,13 @@ class TogetherAIClient(BaseModelClient):
##############################################################################
def load_model_client(model_id: str) -> BaseModelClient:
def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient:
"""
Returns the appropriate LLM client for a given model_id string.
Args:
model_id: The model identifier
prompts_dir: Optional path to the prompts directory.
Example usage:
client = load_model_client("claude-3-5-sonnet-20241022")
@ -1213,23 +1219,23 @@ def load_model_client(model_id: str) -> BaseModelClient:
# Check for o3-pro model specifically - it needs the Responses API
if lower_id == "o3-pro":
return OpenAIResponsesClient(model_id)
return OpenAIResponsesClient(model_id, prompts_dir=prompts_dir)
# Check for OpenRouter first to handle prefixed models like openrouter-deepseek
elif model_id.startswith("together-"):
actual_model_name = model_id.split("together-", 1)[1]
logger.info(f"Loading TogetherAI client for model: {actual_model_name} (original ID: {model_id})")
return TogetherAIClient(actual_model_name)
return TogetherAIClient(actual_model_name, prompts_dir=prompts_dir)
elif "openrouter" in model_id.lower() or "/" in model_id: # More general check for OpenRouterClient(model_id)
return OpenRouterClient(model_id)
return OpenRouterClient(model_id, prompts_dir=prompts_dir)
elif "claude" in lower_id:
return ClaudeClient(model_id)
return ClaudeClient(model_id, prompts_dir=prompts_dir)
elif "gemini" in lower_id:
return GeminiClient(model_id)
return GeminiClient(model_id, prompts_dir=prompts_dir)
elif "deepseek" in lower_id:
return DeepSeekClient(model_id)
return DeepSeekClient(model_id, prompts_dir=prompts_dir)
else:
# Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
return OpenAIClient(model_id)
return OpenAIClient(model_id, prompts_dir=prompts_dir)
##############################################################################
@ -1249,4 +1255,4 @@ def get_visible_messages_for_power(conversation_messages, power_name):
or msg["recipient"] == power_name
):
visible.append(msg)
return visible # already in chronological order if appended that way
return visible # already in chronological order if appended that way

View file

@ -1,7 +1,7 @@
# ai_diplomacy/diary_logic.py
import logging
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from .utils import run_llm_and_log, log_llm_response
@ -11,13 +11,17 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _load_prompt_file(filename: str) -> str | None:
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports."""
import os
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, 'prompts')
filepath = os.path.join(prompts_dir, filename)
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, 'prompts')
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
@ -29,6 +33,7 @@ async def run_diary_consolidation(
game: "Game",
log_file_path: str,
entries_to_keep_unsummarized: int = 15,
prompts_dir: Optional[str] = None,
):
"""
Consolidate older diary entries while keeping recent ones.
@ -93,7 +98,7 @@ async def run_diary_consolidation(
)
return
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt")
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
if not prompt_template:
logger.error(
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"

View file

@ -31,16 +31,17 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
"private_diary": agent.private_diary,
}
def deserialize_agent(agent_data: dict) -> DiplomacyAgent:
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None) -> DiplomacyAgent:
"""Recreates an agent object from a dictionary."""
client = load_model_client(agent_data["model_id"])
client = load_model_client(agent_data["model_id"], prompts_dir=prompts_dir)
client.max_tokens = agent_data.get("max_tokens", 16000) # Default for older saves
agent = DiplomacyAgent(
power_name=agent_data["power_name"],
client=client,
initial_goals=agent_data.get("goals", []),
initial_relationships=agent_data.get("relationships", None)
initial_relationships=agent_data.get("relationships", None),
prompts_dir=prompts_dir
)
# Restore the diary.
agent.full_private_diary = agent_data.get("full_private_diary", [])
@ -179,7 +180,7 @@ def load_game_state(
run_dir: str,
game_file_name: str,
resume_from_phase: Optional[str] = None
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Namespace]:
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Optional[Namespace]]:
"""Loads and reconstructs the game state from a saved game file."""
game_file_path = os.path.join(run_dir, game_file_name)
if not os.path.exists(game_file_path):
@ -189,6 +190,14 @@ def load_game_state(
with open(game_file_path, 'r') as f:
saved_game_data = json.load(f)
# Find the latest config saved in the file
run_config = None
if saved_game_data.get("phases"):
for phase in reversed(saved_game_data["phases"]):
if "config" in phase:
run_config = Namespace(**phase["config"])
logger.info(f"Loaded run configuration from phase {phase['name']}.")
break
# If resuming, find the specified phase and truncate the data after it
if resume_from_phase:
@ -231,8 +240,6 @@ def load_game_state(
logger.info("No previous phases found. Initializing fresh agents and history.")
agents = {} # Will be created by the main loop
game_history = GameHistory()
# No config to load, will use current run's args
run_config = None
else:
# We save the game state up to & including the current (uncompleted) phase.
# So we need to grab the agent state from the previous (completed) phase.
@ -245,8 +252,9 @@ def load_game_state(
agents = {}
if 'state_agents' in last_phase_data:
logger.info("Rebuilding agents from saved state...")
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, 'prompts_dir') else None
for power_name, agent_data in last_phase_data['state_agents'].items():
agents[power_name] = deserialize_agent(agent_data)
agents[power_name] = deserialize_agent(agent_data, prompts_dir=prompts_dir_from_config)
logger.info(f"Rebuilt {len(agents)} agents.")
else:
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
@ -267,7 +275,7 @@ def load_game_state(
logger.info("Game history rebuilt.")
return game, agents, game_history
return game, agents, game_history, run_config
async def initialize_new_game(
@ -308,12 +316,12 @@ async def initialize_new_game(
for power_name, model_id in game.power_model_map.items():
if not game.powers[power_name].is_eliminated():
try:
client = load_model_client(model_id)
client = load_model_client(model_id, prompts_dir=args.prompts_dir)
client.max_tokens = model_max_tokens[power_name]
agent = DiplomacyAgent(power_name=power_name, client=client)
agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir)
agents[power_name] = agent
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path))
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
except Exception as e:
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)

View file

@ -1,6 +1,7 @@
# ai_diplomacy/initialization.py
import logging
import json
from typing import Optional
# Forward declaration for type hinting, actual imports in function if complex
if False: # TYPE_CHECKING
@ -18,7 +19,8 @@ async def initialize_agent_state_ext(
agent: 'DiplomacyAgent',
game: 'Game',
game_history: 'GameHistory',
log_file_path: str
log_file_path: str,
prompts_dir: Optional[str] = None,
):
"""Uses the LLM to set initial goals and relationships for the agent."""
power_name = agent.power_name
@ -56,7 +58,8 @@ async def initialize_agent_state_ext(
game_history=game_history,
agent_goals=None,
agent_relationships=None,
agent_private_diary=formatted_diary,
agent_private_diary=formatted_diary,
prompts_dir=prompts_dir,
)
full_prompt = initial_prompt + "\n\n" + context

View file

@ -23,6 +23,7 @@ def build_context_prompt(
agent_goals: Optional[List[str]] = None,
agent_relationships: Optional[Dict[str, str]] = None,
agent_private_diary: Optional[str] = None,
prompts_dir: Optional[str] = None,
) -> str:
"""Builds the detailed context part of the prompt.
@ -35,11 +36,12 @@ def build_context_prompt(
agent_goals: Optional list of agent's goals.
agent_relationships: Optional dictionary of agent's relationships with other powers.
agent_private_diary: Optional string of agent's private diary.
prompts_dir: Optional path to the prompts directory.
Returns:
A string containing the formatted context.
"""
context_template = load_prompt("context_prompt.txt")
context_template = load_prompt("context_prompt.txt", prompts_dir=prompts_dir)
# === Agent State Debug Logging ===
if agent_goals:
@ -112,6 +114,7 @@ def construct_order_generation_prompt(
agent_goals: Optional[List[str]] = None,
agent_relationships: Optional[Dict[str, str]] = None,
agent_private_diary_str: Optional[str] = None,
prompts_dir: Optional[str] = None,
) -> str:
"""Constructs the final prompt for order generation.
@ -125,13 +128,14 @@ def construct_order_generation_prompt(
agent_goals: Optional list of agent's goals.
agent_relationships: Optional dictionary of agent's relationships with other powers.
agent_private_diary_str: Optional string of agent's private diary.
prompts_dir: Optional path to the prompts directory.
Returns:
A string containing the complete prompt for the LLM.
"""
# Load prompts
_ = load_prompt("few_shot_example.txt") # Loaded but not used, as per original logic
instructions = load_prompt("order_instructions.txt")
_ = load_prompt("few_shot_example.txt", prompts_dir=prompts_dir) # Loaded but not used, as per original logic
instructions = load_prompt("order_instructions.txt", prompts_dir=prompts_dir)
# Build the context prompt
context = build_context_prompt(
@ -143,7 +147,8 @@ def construct_order_generation_prompt(
agent_goals=agent_goals,
agent_relationships=agent_relationships,
agent_private_diary=agent_private_diary_str,
prompts_dir=prompts_dir,
)
final_prompt = system_prompt + "\n\n" + context + "\n\n" + instructions
return final_prompt
return final_prompt

View file

@ -293,11 +293,14 @@ def normalize_and_compare_orders(
# Helper to load prompt text from file relative to the expected 'prompts' dir
def load_prompt(filename: str) -> str:
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
"""Helper to load prompt text from file"""
# Assuming execution from the root or that the path resolves correctly
# Consider using absolute paths or pkg_resources if needed for robustness
prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', filename)
if prompts_dir:
prompt_path = os.path.join(prompts_dir, filename)
else:
# Default behavior: relative to this file's location in the 'prompts' subdir
prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', filename)
try:
with open(prompt_path, "r", encoding='utf-8') as f: # Added encoding
return f.read().strip()

474
experiment_runner.py Normal file
View file

@ -0,0 +1,474 @@
#!/usr/bin/env python3
"""
Experiment orchestration for Diplomacy self-play.
Launches many `lm_game` runs in parallel, captures their artefacts,
and executes a pluggable post-analysis pipeline.
Run `python experiment_runner.py --help` for CLI details.
"""
from __future__ import annotations
import argparse
import collections
import concurrent.futures
import importlib
import json
import logging
import math
import os
import shutil
import subprocess
import sys
import textwrap
import time
import multiprocessing as mp
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Iterable, List
# --------------------------------------------------------------------------- #
# Logging #
# --------------------------------------------------------------------------- #
LOG_FMT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%H:%M:%S")
log = logging.getLogger("experiment_runner")
# ────────────────────────────────────────────────────────────────────────────
# Flag definitions — full, un-shortened help strings #
# ────────────────────────────────────────────────────────────────────────────
def _add_experiment_flags(p: argparse.ArgumentParser) -> None:
p.add_argument(
"--experiment_dir",
type=Path,
required=True,
help=(
"Directory that will hold all experiment artefacts. "
"A 'runs/' sub-folder is created for individual game runs and an "
"'analysis/' folder for aggregated outputs. Must be writable."
),
)
p.add_argument(
"--iterations",
type=int,
default=1,
help=(
"Number of lm_game instances to launch for this experiment. "
"Each instance gets its own sub-directory under runs/."
),
)
p.add_argument(
"--parallel",
type=int,
default=1,
help=(
"Maximum number of game instances to run concurrently. "
"Uses a ProcessPoolExecutor under the hood."
),
)
p.add_argument(
"--analysis_modules",
type=str,
default="summary",
help=(
"Comma-separated list of analysis module names to execute after all "
"runs finish. Modules are imported from "
"'experiment_runner.analysis.<name>' and must expose "
"run(experiment_dir: Path, ctx: dict)."
),
)
p.add_argument(
"--critical_state_base_run",
type=Path,
default=None,
help=(
"Path to an *existing* run directory produced by a previous lm_game "
"execution. When supplied, every iteration resumes from that "
"snapshot using lm_game's --critical_state_analysis_dir mechanism."
),
)
p.add_argument(
"--seed_base",
type=int,
default=42,
help=(
"Base RNG seed. Run i will receive seed = seed_base + i. "
"Forwarded to lm_game via its --seed flag (you must have added that "
"flag to lm_game for deterministic behaviour)."
),
)
def _add_lm_game_flags(p: argparse.ArgumentParser) -> None:
    """Register the lm_game pass-through flags on parser *p*.

    These mirror lm_game's own CLI so that values are validated here before
    any subprocess starts; the raw tokens are then forwarded unchanged to
    each lm_game invocation (see _parse_cli / _build_cmd).
    """
    # ---- all flags copied verbatim from lm_game.parse_arguments() ----
    p.add_argument(
        "--resume_from_phase",
        type=str,
        default="",
        help=(
            "Phase to resume from (e.g., 'S1902M'). Requires --run_dir. "
            "IMPORTANT: This option clears any existing phase results ahead of "
            "& including the specified resume phase."
        ),
    )
    p.add_argument(
        "--end_at_phase",
        type=str,
        default="",
        help="Phase to end the simulation after (e.g., 'F1905M').",
    )
    p.add_argument(
        "--max_year",
        type=int,
        default=1910,  # Increased default in lm_game
        help="Maximum year to simulate. The game will stop once this year is reached.",
    )
    p.add_argument(
        "--num_negotiation_rounds",
        type=int,
        default=0,
        help="Number of negotiation rounds per phase.",
    )
    p.add_argument(
        "--models",
        type=str,
        default="",
        help=(
            "Comma-separated list of model names to assign to powers in order. "
            "The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY."
        ),
    )
    p.add_argument(
        "--planning_phase",
        action="store_true",
        help="Enable the planning phase for each power to set strategic directives.",
    )
    p.add_argument(
        "--max_tokens",
        type=int,
        default=16000,
        help="Maximum number of new tokens to generate per LLM call (default: 16000).",
    )
    p.add_argument(
        "--max_tokens_per_model",
        type=str,
        default="",
        help=(
            "Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, "
            "FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens."
        ),
    )
    p.add_argument(
        "--prompts_dir",
        type=str,
        default=None,
        help=(
            "Path to the directory containing prompt files. "
            "Defaults to the packaged prompts directory."
        ),
    )
# ────────────────────────────────────────────────────────────────────────────
# One combined parser for banner printing #
# ────────────────────────────────────────────────────────────────────────────
def _build_full_parser() -> argparse.ArgumentParser:
    """Build one parser carrying every experiment AND lm_game flag.

    Only used to print the combined help/usage banner; the real parsing is
    done by the two split parsers inside _parse_cli().
    """
    parser = argparse.ArgumentParser(
        prog="experiment_runner.py",
        formatter_class=lambda prog: argparse.RawTextHelpFormatter(
            prog, max_help_position=45
        ),
        description=(
            "Batch-runner for Diplomacy self-play experiments. "
            "All lm_game flags are accepted here as-is; they are validated "
            "before any game runs start."
        ),
    )
    for attach in (_add_experiment_flags, _add_lm_game_flags):
        attach(parser)
    return parser
# ────────────────────────────────────────────────────────────────────────────
# Robust parsing that always shows *full* help on error #
# ────────────────────────────────────────────────────────────────────────────
def _parse_cli() -> tuple[argparse.Namespace, list[str], argparse.Namespace]:
    """Parse sys.argv into (experiment_args, forwarded_game_tokens, game_args).

    Tokens are routed to either the experiment parser or the lm_game parser
    by flag ownership; the raw game tokens are also returned so they can be
    forwarded verbatim to subprocesses.  Any error prints the *full* help.
    """
    full_parser = _build_full_parser()
    # Show full banner when no args
    if len(sys.argv) == 1:
        full_parser.print_help(sys.stderr)
        sys.exit(2)
    # Show full banner on explicit help
    if any(tok in ("-h", "--help") for tok in sys.argv[1:]):
        full_parser.print_help(sys.stderr)
        sys.exit(0)

    # Sub-parsers for separating experiment vs game flags
    class _ErrParser(argparse.ArgumentParser):
        def error(self, msg):
            full_parser.print_help(sys.stderr)
            self.exit(2, f"{self.prog}: error: {msg}\n")

    exp_parser = _ErrParser(add_help=False)
    game_parser = _ErrParser(add_help=False)
    _add_experiment_flags(exp_parser)
    _add_lm_game_flags(game_parser)
    # Split argv tokens by flag ownership
    argv = sys.argv[1:]
    exp_flag_set = {opt for a in exp_parser._actions for opt in a.option_strings}
    exp_tok, game_tok, i = [], [], 0
    while i < len(argv):
        tok = argv[i]
        # BUGFIX: also recognise the "--flag=value" form; previously such
        # experiment flags fell through to the game parser and were rejected.
        flag = tok.split("=", 1)[0]
        if flag in exp_flag_set:
            action = exp_parser._option_string_actions[flag]
            needs_val = (
                action.nargs is None  # default: exactly one value
                or (isinstance(action.nargs, int) and action.nargs > 0)
                or action.nargs in ("+", "*", "?")  # variable-length cases
            )
            if "=" in tok or not needs_val:  # inline value or store_true/false
                exp_tok.append(tok)
                i += 1
            else:
                # BUGFIX: a trailing value-taking flag used to raise
                # IndexError on argv[i + 1]; report a proper parse error.
                if i + 1 >= len(argv):
                    exp_parser.error(f"argument {flag}: expected one argument")
                exp_tok.extend([tok, argv[i + 1]])
                i += 2
        else:
            game_tok.append(tok)
            i += 1
    exp_args = exp_parser.parse_args(exp_tok)
    game_args = game_parser.parse_args(game_tok)
    return exp_args, game_tok, game_args
# --------------------------------------------------------------------------- #
# Helpers #
# --------------------------------------------------------------------------- #
# Immutable record describing one completed game run; serialised (via
# _asdict) into the experiment's runs_summary.json.
_RunInfo = collections.namedtuple(
    "_RunInfo", "index run_dir seed cmd_line returncode elapsed_s"
)
def _mk_run_dir(exp_dir: Path, idx: int) -> Path:
run_dir = exp_dir / "runs" / f"run_{idx:05d}"
# Just ensure it exists; don't raise if it already does.
run_dir.mkdir(parents=True, exist_ok=True)
return run_dir
def _dump_seed(seed: int, run_dir: Path) -> None:
seed_file = run_dir / "seed.txt"
if not seed_file.exists():
seed_file.write_text(str(seed))
def _build_cmd(
lm_game_script: Path,
base_cli: List[str],
run_dir: Path,
seed: int,
critical_base: Path | None,
resume_from_phase: str,
) -> List[str]:
"""
Returns a list suitable for subprocess.run([...]).
"""
cmd = [sys.executable, str(lm_game_script)]
# Forward user CLI
cmd.extend(base_cli)
# Per-run mandatory overrides
cmd.extend(["--run_dir", str(run_dir)])
cmd.extend(["--seed", str(seed)]) # you may need to add a --seed flag to lm_game
# Critical-state mode
if critical_base:
cmd.extend([
"--critical_state_analysis_dir", str(run_dir),
"--run_dir", str(critical_base) # base run dir (already completed)
])
if resume_from_phase:
cmd.extend(["--resume_from_phase", resume_from_phase])
return cmd
def _launch_one(args) -> _RunInfo:
    """
    ProcessPool worker: execute one lm_game subprocess and capture its output.
    """
    idx, lm_game_script, base_cli, run_dir, seed, critical_base, resume_phase = args
    cmd = _build_cmd(
        lm_game_script, base_cli, run_dir, seed, critical_base, resume_phase
    )
    cmd_str = " ".join(cmd)
    started = time.perf_counter()
    log.debug("Run %05d: CMD = %s", idx, cmd_str)
    # Persist the exact command line so the run can be reproduced by hand.
    (run_dir / "command.txt").write_text(cmd_str)
    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr into the same log
            text=True,
            check=False,
        )
        (run_dir / "console.log").write_text(proc.stdout)
        rc = proc.returncode
    except Exception as exc:  # noqa: broad-except
        # Launch failures are recorded in the run's log, not raised, so one
        # broken run cannot take down the whole pool.
        (run_dir / "console.log").write_text(f"Exception launching run:\n{exc}\n")
        rc = 1
    return _RunInfo(idx, run_dir, seed, cmd_str, rc, time.perf_counter() - started)
def _load_analysis_fns(module_names: Iterable[str]):
    """
    Lazily yield the `run` callables of the requested analysis modules.
    Each module must expose `run(experiment_dir: Path, cfg: dict)`; modules
    that fail to import or lack `run()` are logged and skipped.
    """
    for raw_name in module_names:
        dotted = f"experiment_runner.analysis.{raw_name.strip()}"
        try:
            module = importlib.import_module(dotted)
        except ModuleNotFoundError as e:
            log.warning("Analysis module %s not found (%s) skipping", dotted, e)
            continue
        if not hasattr(module, "run"):
            log.warning("%s has no `run()` function skipping", dotted)
            continue
        yield module.run
# --------------------------------------------------------------------------- #
# Main driver #
# --------------------------------------------------------------------------- #
def main() -> None:
    """
    Entry point: parse the CLI, launch all game runs (optionally in
    parallel), persist per-run metadata, then execute the requested
    analysis modules over the finished experiment directory.
    """
    exp_args, leftover_cli, game_args = _parse_cli()
    exp_dir: Path = exp_args.experiment_dir.expanduser().resolve()
    if exp_dir.exists():
        log.info("Appending to existing experiment: %s", exp_dir)
    exp_dir.mkdir(parents=True, exist_ok=True)
    # Persist experiment-level config only on first launch, so re-invocations
    # cannot silently change the recorded parameters.
    cfg_path = exp_dir / "config.json"
    if not cfg_path.exists():
        with cfg_path.open("w", encoding="utf-8") as fh:
            json.dump(
                {"experiment": vars(exp_args),
                 "lm_game": vars(game_args),
                 "forwarded_cli": leftover_cli},
                fh, indent=2, default=str,
            )
        log.info("Config saved to %s", cfg_path)
    else:
        # BUGFIX: this branch previously also logged "Config saved to %s",
        # which was misleading -- nothing is written when the file exists.
        log.info("Config already exists leaving unchanged")
    # ------------------------------------------------------------------ #
    #  Launch games                                                      #
    # ------------------------------------------------------------------ #
    lm_game_script = Path(__file__).parent / "lm_game.py"
    if not lm_game_script.exists():
        log.error("lm_game.py not found at %s abort", lm_game_script)
        sys.exit(1)
    run_args = []
    for i in range(exp_args.iterations):
        run_dir = _mk_run_dir(exp_dir, i)
        seed = exp_args.seed_base + i  # deterministic per-run seed
        _dump_seed(seed, run_dir)
        run_args.append(
            (
                i, lm_game_script, leftover_cli, run_dir, seed,
                exp_args.critical_state_base_run,
                game_args.resume_from_phase,
            )
        )
    log.info(
        "Launching %d runs (max %d parallel, critical_state=%s)",
        exp_args.iterations,
        exp_args.parallel,
        bool(exp_args.critical_state_base_run),
    )
    runs_meta: list[_RunInfo] = []
    # "spawn" gives identical worker start-up semantics on all platforms.
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=exp_args.parallel,
        mp_context=mp.get_context("spawn"),
    ) as pool:
        for res in pool.map(_launch_one, run_args):
            runs_meta.append(res)
            status = "OK" if res.returncode == 0 else f"RC={res.returncode}"
            log.info(
                "run_%05d finished in %.1fs %s", res.index, res.elapsed_s, status
            )
    # Persist per-run status summary
    summary_path = exp_dir / "runs_summary.json"
    with open(summary_path, "w", encoding="utf-8") as fh:
        json.dump([res._asdict() for res in runs_meta], fh, indent=2, default=str)
    log.info("Run summary written → %s", summary_path)
    # ------------------------------------------------------------------ #
    #  Post-analysis pipeline                                            #
    # ------------------------------------------------------------------ #
    mods = list(_load_analysis_fns(exp_args.analysis_modules.split(",")))
    if not mods:
        log.warning("No analysis modules loaded done.")
        return
    analysis_root = exp_dir / "analysis"
    if analysis_root.exists():
        shutil.rmtree(analysis_root)  # wipe stale outputs from earlier passes
    analysis_root.mkdir(exist_ok=True)
    # Context shared by every analysis module
    ctx: dict = {
        "exp_dir": str(exp_dir),
        "runs": [str(r.run_dir) for r in runs_meta],
        "critical_state_base": str(exp_args.critical_state_base_run or ""),
        "resume_from_phase": game_args.resume_from_phase,
    }
    for fn in mods:
        name = fn.__module__.rsplit(".", 1)[-1]
        log.info("Running analysis module: %s", name)
        try:
            fn(exp_dir, ctx)
            log.info("%s complete", name)
        except Exception as exc:  # noqa: broad-except
            # One failing module must not abort the remaining analyses.
            log.exception("Analysis module %s failed: %s", name, exc)
    log.info("Experiment finished artefacts in %s", exp_dir)


if __name__ == "__main__":
    main()

View file

View file

View file

@ -0,0 +1,60 @@
"""
Extracts the board state *before* and *after* a critical phase
for every critical-analysis run produced by experiment_runner.
Each run must have:
lmvsgame.json containing a phase named ctx["resume_from_phase"]
Outputs live in analysis/critical_state/<run_name>_{before|after}.json
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Dict, List
log = logging.getLogger(__name__)
def _phase_by_name(phases: List[dict], name: str) -> dict | None:
for ph in phases:
if ph["state"]["name"] == name:
return ph
return None
def run(exp_dir: Path, ctx: Dict) -> None:
    """Write before/after board snapshots for every critical-analysis run.

    'before' is the resumed phase named in ctx["resume_from_phase"];
    'after' is the final phase of the game.
    """
    resume_phase = ctx.get("resume_from_phase")
    if not resume_phase:
        log.info("critical_state: --resume_from_phase not supplied skipping")
        return
    out_dir = exp_dir / "analysis" / "critical_state"
    out_dir.mkdir(parents=True, exist_ok=True)
    for run_dir in (exp_dir / "runs").iterdir():
        game_json = run_dir / "lmvsgame.json"
        if not game_json.exists():
            continue  # not a finished game run
        with game_json.open("r") as fh:
            game = json.load(fh)
        phases = game.get("phases", [])
        if not phases:
            continue
        before = _phase_by_name(phases, resume_phase)
        after = phases[-1]
        if before is None or after is None:
            log.warning("Run %s missing expected phases skipped", run_dir.name)
            continue
        for tag, snapshot in (("before", before), ("after", after)):
            out_path = out_dir / f"{run_dir.name}_{tag}.json"
            out_path.write_text(json.dumps(snapshot, indent=2))
    log.info("critical_state: snapshots written to %s", out_dir)

View file

@ -0,0 +1,177 @@
"""
Aggregates results across all runs and writes:
analysis/aggregated_results.csv
analysis/score_summary_by_power.csv
analysis/results_summary.png (if matplotlib available)
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from typing import Any, Dict, List
import pandas as pd
log = logging.getLogger(__name__)
# --------------------------------------------------------------------------- #
# Helpers copied (and lightly cleaned) from run_games.py #
# --------------------------------------------------------------------------- #
def _extract_diplomacy_results(game_json: Path) -> Dict[str, Dict[str, Any]]:
with game_json.open("r", encoding="utf-8") as fh:
gd = json.load(fh)
phases = gd.get("phases", [])
if not phases:
raise ValueError("no phases")
first_state = phases[0]["state"]
last_state = phases[-1]["state"]
powers = (
list(first_state.get("homes", {}).keys())
or list(first_state.get("centers", {}).keys())
)
if not powers:
raise ValueError("cannot determine powers")
scs_to_win = 18
solo_winner = next(
(p for p, sc in last_state["centers"].items() if len(sc) >= scs_to_win), None
)
results: Dict[str, Dict[str, Any]] = {}
for p in powers:
sc_count = len(last_state["centers"].get(p, []))
units = len(last_state["units"].get(p, []))
if solo_winner:
if p == solo_winner:
outcome, cat = f"Won solo", "Solo Win"
else:
outcome, cat = f"Lost to {solo_winner}", "Loss"
elif sc_count == 0 and units == 0:
outcome, cat = "Eliminated", "Eliminated"
else:
outcome, cat = "Ongoing/Draw", "Ongoing/Abandoned/Draw"
results[p] = {
"OutcomeCategory": cat,
"StatusDetail": outcome,
"SupplyCenters": sc_count,
"LastPhase": last_state["name"],
}
return results
# simplistic "Diplobench" scoring from previous discussion ------------------ #
def _year(name: str) -> int | None:
m = re.search(r"(\d{4})", name)
return int(m.group(1)) if m else None
def _score_game(game_json: Path) -> Dict[str, int]:
    """Compute a per-power integer score for one finished game.

    Scheme (higher is better):
      * solo winner:            max_turns + 17 + (max_turns - win_turn)
      * losers of a solo game:  the turn the solo happened
      * survivors (no solo):    max_turns + final supply-center count
      * eliminated powers:      the turn they were eliminated
    Turn numbers are derived from the 4-digit year in phase names; phase
    indices are used as a fallback when no year can be parsed.
    """
    with open(game_json, "r") as fh:
        game = json.load(fh)
    phases = game["phases"]
    if not phases:
        return {}
    start = _year(phases[0]["state"]["name"])
    end_year = _year(phases[-1]["state"]["name"]) or start
    # Game length in "turns" (years inclusive), or phase count when unparseable.
    max_turns = (end_year - start + 1) if start is not None else len(phases)
    last_state = phases[-1]["state"]
    # A solo win is 18+ supply centers in the final board state.
    solo_winner = next(
        (p for p, scs in last_state["centers"].items() if len(scs) >= 18), None
    )
    # First phase in which a power holds zero centers is taken as its
    # elimination turn (None = never eliminated).
    # NOTE(review): assumes every phase state carries a complete 'centers'
    # map for living powers - confirm against lmvsgame.json semantics.
    elim_turn: Dict[str, int | None] = {}
    for p in last_state["centers"].keys():
        e_turn = None
        for idx, ph in enumerate(phases):
            if not ph["state"]["centers"].get(p, []):
                yr = _year(ph["state"]["name"]) or 0
                e_turn = (yr - start + 1) if start is not None else idx + 1
                break
        elim_turn[p] = e_turn
    scores: Dict[str, int] = {}
    for p, scs in last_state["centers"].items():
        if p == solo_winner:
            # Earlier solos score higher via the (max_turns - win_turn) bonus.
            win_turn = (end_year - start + 1) if start is not None else len(phases)
            scores[p] = max_turns + 17 + (max_turns - win_turn)
        elif solo_winner:
            # losers in a solo game
            yr_win = _year(phases[-1]["state"]["name"]) or end_year
            turn_win = (yr_win - start + 1) if start is not None else len(phases)
            scores[p] = turn_win
        else:
            if elim_turn[p] is None:
                # Survivor: reward longevity plus final supply-center count.
                scores[p] = max_turns + len(scs)
            else:
                scores[p] = elim_turn[p]
    return scores
# --------------------------------------------------------------------------- #
# Public entry point #
# --------------------------------------------------------------------------- #
def run(exp_dir: Path, ctx: dict):  # pylint: disable=unused-argument
    """Aggregate all runs' outcomes and scores into CSVs (plus optional chart)."""
    analysis_dir = exp_dir / "analysis"
    analysis_dir.mkdir(exist_ok=True)
    rows: List[Dict[str, Any]] = []
    for run_dir in (exp_dir / "runs").iterdir():
        game_json = run_dir / "lmvsgame.json"
        if not game_json.exists():
            continue  # unfinished or non-run entry
        gid = run_dir.name
        try:
            outcomes = _extract_diplomacy_results(game_json)
        except Exception as e:  # noqa: broad-except
            log.warning("Could not parse %s (%s)", game_json, e)
            continue
        scores = _score_game(game_json)
        rows.extend(
            {"GameID": gid, "Power": pwr, **info, "Score": scores.get(pwr, None)}
            for pwr, info in outcomes.items()
        )
    if not rows:
        log.warning("summary: no parsable runs found")
        return
    df = pd.DataFrame(rows)
    out_csv = analysis_dir / "aggregated_results.csv"
    df.to_csv(out_csv, index=False)
    summary = (
        df.groupby("Power")["Score"]
        .agg(["mean", "median", "count"])
        .reset_index()
        .rename(columns={"count": "n"})
    )
    summary.to_csv(analysis_dir / "score_summary_by_power.csv", index=False)
    log.info("summary: wrote %s rows to %s", len(df), out_csv)
    # Optional charts: degrade gracefully when matplotlib/seaborn are absent.
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns

        sns.set_style("whitegrid")
        plt.figure(figsize=(10, 7))
        sns.boxplot(x="Power", y="SupplyCenters", data=df, palette="pastel")
        plt.title("Supply-center distribution")
        plt.savefig(analysis_dir / "results_summary.png", dpi=150)
        plt.close()
        log.info("summary: chart saved")
    except Exception as e:  # noqa: broad-except
        log.debug("Chart generation skipped (%s)", e)

View file

@ -55,6 +55,12 @@ def parse_arguments():
default="",
help="Directory for results. If it exists, the game resumes. If not, it's created. Defaults to a new timestamped directory.",
)
parser.add_argument(
"--output", # alias for back compatibility
dest="run_dir", # write to the same variable as --run_dir
type=str,
help=argparse.SUPPRESS # hides it from `--help`
)
parser.add_argument(
"--critical_state_analysis_dir",
type=str,
@ -105,12 +111,24 @@ def parse_arguments():
default=16000,
help="Maximum number of new tokens to generate per LLM call (default: 16000)."
)
parser.add_argument(
"--seed_base",
type=int,
default=42,
help="RNG seed placeholder for compatibility with experiment_runner. Currently unused."
)
parser.add_argument(
"--max_tokens_per_model",
type=str,
default="",
help="Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens."
)
parser.add_argument(
"--prompts_dir",
type=str,
default=None,
help="Path to the directory containing prompt files. Defaults to the packaged prompts directory."
)
return parser.parse_args()
@ -179,8 +197,17 @@ async def main():
if is_resuming:
try:
# When resuming, we load the state and also the config from the last saved phase.
# We will continue to use THAT config, ignoring the current script's args.
game, agents, game_history = load_game_state(run_dir, game_file_name, args.resume_from_phase)
game, agents, game_history, loaded_run_config = load_game_state(run_dir, game_file_name, args.resume_from_phase)
if loaded_run_config:
# Use the saved config, but allow current CLI args to override control-flow parameters
run_config = loaded_run_config
run_config.run_dir = args.run_dir
run_config.critical_state_analysis_dir = args.critical_state_analysis_dir
run_config.resume_from_phase = args.resume_from_phase
run_config.end_at_phase = args.end_at_phase
# If prompts_dir is specified now, it overrides the saved one.
if args.prompts_dir is not None:
run_config.prompts_dir = args.prompts_dir
logger.info(f"Successfully resumed game from phase: {game.get_current_phase()}.")
except (FileNotFoundError, ValueError) as e:
logger.error(f"Could not resume game: {e}. Starting a new game instead.")
@ -204,7 +231,7 @@ async def main():
if year_int > run_config.max_year:
logger.info(f"Reached max year {run_config.max_year}, stopping simulation.")
break
if run_config.end_at_phase and game.phases and game.phases[-1].name == run_config.end_at_phase:
if run_config.end_at_phase and current_phase == run_config.end_at_phase:
logger.info(f"Reached end phase {run_config.end_at_phase}, stopping simulation.")
break
@ -289,7 +316,7 @@ async def main():
# Diary Consolidation
if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
consolidation_tasks = [
run_diary_consolidation(agent, game, llm_log_file_path)
run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir)
for agent in agents.values() if not game.powers[agent.power_name].is_eliminated()
]
if consolidation_tasks:

View file

@ -17,3 +17,5 @@ anthropic
google-genai
json-repair
together
matplotlib
seaborn

5
run.sh
View file

@ -2,6 +2,7 @@
python3 lm_game.py \
--max_year 1901 \
--num_negotiation_rounds 0 \
--num_negotiation_rounds 1 \
--models "openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17" \
--max_tokens_per_model 16000,16000,16000,16000,16000,16000,16000
--max_tokens_per_model 16000,16000,16000,16000,16000,16000,16000 \
--prompts_dir "ai_diplomacy/prompts"