mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
add experiment runner
This commit is contained in:
parent
7d50b31e34
commit
840c6b0ad9
16 changed files with 943 additions and 80 deletions
85
README.md
85
README.md
|
|
@ -236,13 +236,92 @@ python lm_game.py --max_year 1910 --planning_phase --num_negotiation_rounds 2
|
|||
# Custom model assignment (order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY)
|
||||
python lm_game.py --models "claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022"
|
||||
|
||||
# Output to specific file
|
||||
python lm_game.py --output results/my_game.json
|
||||
|
||||
# Run until game completion or specific year
|
||||
python lm_game.py --num_negotiation_rounds 2 --planning_phase
|
||||
|
||||
# Write all artefacts to a chosen directory (auto-resumes if it already exists)
|
||||
python lm_game.py --run_dir results/game_run_001
|
||||
|
||||
# Resume an interrupted game from a specific phase
|
||||
python lm_game.py --run_dir results/game_run_001 --resume_from_phase S1902M
|
||||
|
||||
# Critical-state analysis: resume from an existing run but save new results elsewhere
|
||||
python lm_game.py \
|
||||
--run_dir results/game_run_001 \
|
||||
--critical_state_analysis_dir results/critical_analysis_001 \
|
||||
--resume_from_phase F1903M
|
||||
|
||||
# End the simulation after a particular phase regardless of remaining years
|
||||
python lm_game.py --run_dir results/game_run_002 --end_at_phase F1905M
|
||||
|
||||
# Set the global max_tokens generation limit
|
||||
python lm_game.py --run_dir results/game_run_003 --max_tokens 8000
|
||||
|
||||
# Per-model token limits (AU,EN,FR,GE,IT,RU,TR)
|
||||
python lm_game.py --run_dir results/game_run_004 \
|
||||
--max_tokens_per_model "8000,8000,16000,8000,8000,16000,8000"
|
||||
|
||||
# Use a custom prompts directory
|
||||
python lm_game.py --run_dir results/game_run_005 --prompts_dir ./prompts/my_variants
|
||||
```
|
||||
|
||||
### Running Batch Experiments with **`experiment_runner.py`**
|
||||
|
||||
`experiment_runner.py` is a lightweight orchestrator: it spins up many `lm_game.py` runs in parallel, gathers their artefacts under one *experiment directory*, and then executes the analysis modules you specify.
|
||||
All flags that belong to **`lm_game.py`** can be passed straight through; the runner validates them and forwards them unchanged to every game instance.
|
||||
|
||||
---
|
||||
|
||||
#### Examples
|
||||
|
||||
```bash
|
||||
# Run 10 independent games (iterations) in parallel, using a custom prompts dir
|
||||
# and a single model (GPT-4o) for all seven powers.
|
||||
python3 experiment_runner.py \
|
||||
--experiment_dir "results/exp001" \
|
||||
--iterations 10 \
|
||||
--parallel 10 \
|
||||
--max_year 1905 \
|
||||
--num_negotiation_rounds 0 \
|
||||
--prompts_dir "ai_diplomacy/prompts" \
|
||||
--models "gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o"
|
||||
|
||||
|
||||
# Critical-state analysis: resume every run from W1901A (taken from an existing
|
||||
# base run) and stop after S1902M. Two analysis modules will be executed:
|
||||
# • summary → aggregated results & scores
|
||||
# • critical_state → before/after snapshots around the critical phase
|
||||
python3 experiment_runner.py \
|
||||
--experiment_dir "results/exp002" \
|
||||
--iterations 10 \
|
||||
--parallel 10 \
|
||||
--resume_from_phase W1901A \
|
||||
--end_at_phase S1902M \
|
||||
--num_negotiation_rounds 0 \
|
||||
--critical_state_base_run "results/test1" \
|
||||
--prompts_dir "ai_diplomacy/prompts" \
|
||||
--analysis_modules "summary,critical_state" \
|
||||
--models "gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o"
|
||||
```
|
||||
|
||||
*(Any other `lm_game.py` flags—`--planning_phase`, `--max_tokens`, etc.—can be added exactly where you’d use them on a single-game run.)*
|
||||
|
||||
---
|
||||
|
||||
#### Experiment-runner–specific arguments
|
||||
|
||||
| Flag | Type / Default | Description |
|
||||
| --------------------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `--experiment_dir` **(required)** | `Path` | Root folder for the experiment; sub-folders `runs/` and `analysis/` are managed automatically. Re-running with the same directory will **resume** existing runs and regenerate analysis. |
|
||||
| `--iterations` | `int`, default `1` | How many individual games to launch for this experiment. |
|
||||
| `--parallel` | `int`, default `1` | Max number of games to execute concurrently (uses a process pool). |
|
||||
| `--analysis_modules` | `str`, default `"summary"` | Comma-separated list of analysis modules to run after all games finish. Modules are imported from `experiment_runner.analysis.<name>` and must provide `run(experiment_dir, ctx)`. |
|
||||
| `--critical_state_base_run` | `Path`, optional | Path to an **existing** `run_dir` produced by a previous `lm_game` run. Each iteration resumes from that snapshot; new artefacts are written under the current `experiment_dir`. |
|
||||
| `--seed_base`                     | `int`, default `42`        | Base random seed. Run *i* receives seed = `seed_base + i`, enabling reproducible batches.                                                                                                 |
|
||||
|
||||
*(All other command-line flags belong to `lm_game.py` and are forwarded unchanged.)*
|
||||
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a `.env` file with your API keys:
|
||||
|
|
|
|||
|
|
@ -22,13 +22,17 @@ ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUS
|
|||
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
|
||||
|
||||
# == New: Helper function to load prompt files reliably ==
|
||||
def _load_prompt_file(filename: str) -> Optional[str]:
|
||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
|
||||
"""Loads a prompt template from the prompts directory."""
|
||||
try:
|
||||
# Construct path relative to this file's location
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
prompts_dir = os.path.join(current_dir, 'prompts')
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
if prompts_dir:
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
# Construct path relative to this file's location
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
||||
filepath = os.path.join(default_prompts_dir, filename)
|
||||
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
|
|
@ -50,6 +54,7 @@ class DiplomacyAgent:
|
|||
client: BaseModelClient,
|
||||
initial_goals: Optional[List[str]] = None,
|
||||
initial_relationships: Optional[Dict[str, str]] = None,
|
||||
prompts_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the DiplomacyAgent.
|
||||
|
|
@ -60,12 +65,14 @@ class DiplomacyAgent:
|
|||
initial_goals: An optional list of initial strategic goals.
|
||||
initial_relationships: An optional dictionary mapping other power names to
|
||||
relationship statuses (e.g., 'ALLY', 'ENEMY', 'NEUTRAL').
|
||||
prompts_dir: Optional path to the prompts directory.
|
||||
"""
|
||||
if power_name not in ALL_POWERS:
|
||||
raise ValueError(f"Invalid power name: {power_name}. Must be one of {ALL_POWERS}")
|
||||
|
||||
self.power_name: str = power_name
|
||||
self.client: BaseModelClient = client
|
||||
self.prompts_dir: Optional[str] = prompts_dir
|
||||
# Initialize goals as empty list, will be populated by initialize_agent_state
|
||||
self.goals: List[str] = initial_goals if initial_goals is not None else []
|
||||
# Initialize relationships to Neutral if not provided
|
||||
|
|
@ -85,16 +92,21 @@ class DiplomacyAgent:
|
|||
# Get the directory containing the current file (agent.py)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# Construct path relative to the current file's directory
|
||||
prompts_dir = os.path.join(current_dir, "prompts")
|
||||
power_prompt_filename = os.path.join(prompts_dir, f"{power_name.lower()}_system_prompt.txt")
|
||||
default_prompt_filename = os.path.join(prompts_dir, "system_prompt.txt")
|
||||
default_prompts_path = os.path.join(current_dir, "prompts")
|
||||
power_prompt_filename = f"{power_name.lower()}_system_prompt.txt"
|
||||
default_prompt_filename = "system_prompt.txt"
|
||||
|
||||
system_prompt_content = load_prompt(power_prompt_filename)
|
||||
# Use the provided prompts_dir if available, otherwise use the default
|
||||
prompts_path_to_use = self.prompts_dir if self.prompts_dir else default_prompts_path
|
||||
|
||||
power_prompt_filepath = os.path.join(prompts_path_to_use, power_prompt_filename)
|
||||
default_prompt_filepath = os.path.join(prompts_path_to_use, default_prompt_filename)
|
||||
|
||||
system_prompt_content = load_prompt(power_prompt_filepath, prompts_dir=self.prompts_dir)
|
||||
|
||||
if not system_prompt_content:
|
||||
logger.warning(f"Power-specific prompt '{power_prompt_filename}' not found or empty. Loading default system prompt.")
|
||||
# system_prompt_content = load_prompt("system_prompt.txt")
|
||||
system_prompt_content = load_prompt(default_prompt_filename)
|
||||
logger.warning(f"Power-specific prompt '{power_prompt_filepath}' not found or empty. Loading default system prompt.")
|
||||
system_prompt_content = load_prompt(default_prompt_filepath, prompts_dir=self.prompts_dir)
|
||||
else:
|
||||
logger.info(f"Loaded power-specific system prompt for {power_name}.")
|
||||
# ----------------------------------------------------
|
||||
|
|
@ -416,7 +428,7 @@ class DiplomacyAgent:
|
|||
|
||||
try:
|
||||
# Load the template file but safely preprocess it first
|
||||
prompt_template_content = _load_prompt_file('negotiation_diary_prompt.txt')
|
||||
prompt_template_content = _load_prompt_file('negotiation_diary_prompt.txt', prompts_dir=self.prompts_dir)
|
||||
if not prompt_template_content:
|
||||
logger.error(f"[{self.power_name}] Could not load negotiation_diary_prompt.txt. Skipping diary entry.")
|
||||
success_status = "Failure: Prompt file not loaded"
|
||||
|
|
@ -611,7 +623,7 @@ class DiplomacyAgent:
|
|||
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
|
||||
|
||||
# Load the template but we'll use it carefully with string interpolation
|
||||
prompt_template = _load_prompt_file('order_diary_prompt.txt')
|
||||
prompt_template = _load_prompt_file('order_diary_prompt.txt', prompts_dir=self.prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(f"[{self.power_name}] Could not load order_diary_prompt.txt. Skipping diary entry.")
|
||||
return
|
||||
|
|
@ -756,7 +768,7 @@ class DiplomacyAgent:
|
|||
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
|
||||
|
||||
# Load the template
|
||||
prompt_template = _load_prompt_file('phase_result_diary_prompt.txt')
|
||||
prompt_template = _load_prompt_file('phase_result_diary_prompt.txt', prompts_dir=self.prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
|
||||
return
|
||||
|
|
@ -859,7 +871,7 @@ class DiplomacyAgent:
|
|||
|
||||
try:
|
||||
# 1. Construct the prompt using the dedicated state update prompt file
|
||||
prompt_template = _load_prompt_file('state_update_prompt.txt')
|
||||
prompt_template = _load_prompt_file('state_update_prompt.txt', prompts_dir=self.prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(f"[{power_name}] Could not load state_update_prompt.txt. Skipping state update.")
|
||||
return
|
||||
|
|
@ -893,6 +905,7 @@ class DiplomacyAgent:
|
|||
agent_goals=self.goals,
|
||||
agent_relationships=self.relationships,
|
||||
agent_private_diary=formatted_diary, # Pass formatted diary
|
||||
prompts_dir=self.prompts_dir,
|
||||
)
|
||||
|
||||
# Add previous phase summary to the information provided to the LLM
|
||||
|
|
|
|||
|
|
@ -44,10 +44,11 @@ class BaseModelClient:
|
|||
- get_conversation_reply(power_name, conversation_so_far, game_phase) -> str
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
self.model_name = model_name
|
||||
self.prompts_dir = prompts_dir
|
||||
# Load a default initially, can be overwritten by set_system_prompt
|
||||
self.system_prompt = load_prompt("system_prompt.txt")
|
||||
self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir)
|
||||
self.max_tokens = 16000 # default unless overridden
|
||||
|
||||
def set_system_prompt(self, content: str):
|
||||
|
|
@ -97,6 +98,7 @@ class BaseModelClient:
|
|||
agent_goals=agent_goals,
|
||||
agent_relationships=agent_relationships,
|
||||
agent_private_diary_str=agent_private_diary_str,
|
||||
prompts_dir=self.prompts_dir,
|
||||
)
|
||||
|
||||
raw_response = ""
|
||||
|
|
@ -423,7 +425,7 @@ class BaseModelClient:
|
|||
agent_private_diary_str: Optional[str] = None, # Added
|
||||
) -> str:
|
||||
|
||||
instructions = load_prompt("planning_instructions.txt")
|
||||
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
||||
|
||||
context = self.build_context_prompt(
|
||||
game,
|
||||
|
|
@ -434,6 +436,7 @@ class BaseModelClient:
|
|||
agent_goals=agent_goals,
|
||||
agent_relationships=agent_relationships,
|
||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||
prompts_dir=self.prompts_dir,
|
||||
)
|
||||
|
||||
return context + "\n\n" + instructions
|
||||
|
|
@ -451,7 +454,7 @@ class BaseModelClient:
|
|||
agent_relationships: Optional[Dict[str, str]] = None,
|
||||
agent_private_diary_str: Optional[str] = None, # Added
|
||||
) -> str:
|
||||
instructions = load_prompt("conversation_instructions.txt")
|
||||
instructions = load_prompt("conversation_instructions.txt", prompts_dir=self.prompts_dir)
|
||||
|
||||
context = build_context_prompt(
|
||||
game,
|
||||
|
|
@ -462,6 +465,7 @@ class BaseModelClient:
|
|||
agent_goals=agent_goals,
|
||||
agent_relationships=agent_relationships,
|
||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||
prompts_dir=self.prompts_dir,
|
||||
)
|
||||
|
||||
# Get recent messages targeting this power to prioritize responses
|
||||
|
|
@ -699,7 +703,7 @@ class BaseModelClient:
|
|||
"""
|
||||
logger.info(f"Client generating strategic plan for {power_name}...")
|
||||
|
||||
planning_instructions = load_prompt("planning_instructions.txt")
|
||||
planning_instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
||||
if not planning_instructions:
|
||||
logger.error("Could not load planning_instructions.txt! Cannot generate plan.")
|
||||
return "Error: Planning instructions not found."
|
||||
|
|
@ -718,6 +722,7 @@ class BaseModelClient:
|
|||
agent_goals=agent_goals,
|
||||
agent_relationships=agent_relationships,
|
||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||
prompts_dir=self.prompts_dir,
|
||||
)
|
||||
|
||||
full_prompt = f"{context_prompt}\n\n{planning_instructions}"
|
||||
|
|
@ -772,8 +777,8 @@ class OpenAIClient(BaseModelClient):
|
|||
For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
||||
|
|
@ -819,8 +824,8 @@ class ClaudeClient(BaseModelClient):
|
|||
For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
self.client = AsyncAnthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
|
||||
|
||||
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
||||
|
|
@ -861,8 +866,8 @@ class GeminiClient(BaseModelClient):
|
|||
For 'gemini-1.5-flash' or other Google Generative AI models.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
# Configure and get the model (corrected initialization)
|
||||
api_key = os.environ.get("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
|
|
@ -905,8 +910,8 @@ class DeepSeekClient(BaseModelClient):
|
|||
For DeepSeek R1 'deepseek-reasoner'
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
|
||||
self.client = AsyncDeepSeekOpenAI(
|
||||
api_key=self.api_key,
|
||||
|
|
@ -961,8 +966,8 @@ class OpenAIResponsesClient(BaseModelClient):
|
|||
This client makes direct HTTP requests to the v1/responses endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
self.api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("OPENAI_API_KEY environment variable is required")
|
||||
|
|
@ -1068,14 +1073,14 @@ class OpenRouterClient(BaseModelClient):
|
|||
For OpenRouter models, with default being 'openrouter/quasar-alpha'
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "openrouter/quasar-alpha"):
|
||||
def __init__(self, model_name: str = "openrouter/quasar-alpha", prompts_dir: Optional[str] = None):
|
||||
# Allow specifying just the model identifier or the full path
|
||||
if not model_name.startswith("openrouter/") and "/" not in model_name:
|
||||
model_name = f"openrouter/{model_name}"
|
||||
if model_name.startswith("openrouter-"):
|
||||
model_name = model_name.replace("openrouter-", "")
|
||||
|
||||
super().__init__(model_name)
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
self.api_key = os.environ.get("OPENROUTER_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||
|
|
@ -1146,8 +1151,8 @@ class TogetherAIClient(BaseModelClient):
|
|||
Model names should be passed without the 'together-' prefix.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name) # model_name here is the actual Together AI model identifier
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir) # model_name here is the actual Together AI model identifier
|
||||
self.api_key = os.environ.get("TOGETHER_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("TOGETHER_API_KEY environment variable is required for TogetherAIClient")
|
||||
|
|
@ -1198,12 +1203,13 @@ class TogetherAIClient(BaseModelClient):
|
|||
##############################################################################
|
||||
|
||||
|
||||
def load_model_client(model_id: str) -> BaseModelClient:
|
||||
def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient:
|
||||
"""
|
||||
Returns the appropriate LLM client for a given model_id string.
|
||||
|
||||
Args:
|
||||
model_id: The model identifier
|
||||
prompts_dir: Optional path to the prompts directory.
|
||||
|
||||
Example usage:
|
||||
client = load_model_client("claude-3-5-sonnet-20241022")
|
||||
|
|
@ -1213,23 +1219,23 @@ def load_model_client(model_id: str) -> BaseModelClient:
|
|||
|
||||
# Check for o3-pro model specifically - it needs the Responses API
|
||||
if lower_id == "o3-pro":
|
||||
return OpenAIResponsesClient(model_id)
|
||||
return OpenAIResponsesClient(model_id, prompts_dir=prompts_dir)
|
||||
# Check for OpenRouter first to handle prefixed models like openrouter-deepseek
|
||||
elif model_id.startswith("together-"):
|
||||
actual_model_name = model_id.split("together-", 1)[1]
|
||||
logger.info(f"Loading TogetherAI client for model: {actual_model_name} (original ID: {model_id})")
|
||||
return TogetherAIClient(actual_model_name)
|
||||
return TogetherAIClient(actual_model_name, prompts_dir=prompts_dir)
|
||||
elif "openrouter" in model_id.lower() or "/" in model_id: # More general check for OpenRouterClient(model_id)
|
||||
return OpenRouterClient(model_id)
|
||||
return OpenRouterClient(model_id, prompts_dir=prompts_dir)
|
||||
elif "claude" in lower_id:
|
||||
return ClaudeClient(model_id)
|
||||
return ClaudeClient(model_id, prompts_dir=prompts_dir)
|
||||
elif "gemini" in lower_id:
|
||||
return GeminiClient(model_id)
|
||||
return GeminiClient(model_id, prompts_dir=prompts_dir)
|
||||
elif "deepseek" in lower_id:
|
||||
return DeepSeekClient(model_id)
|
||||
return DeepSeekClient(model_id, prompts_dir=prompts_dir)
|
||||
else:
|
||||
# Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
|
||||
return OpenAIClient(model_id)
|
||||
return OpenAIClient(model_id, prompts_dir=prompts_dir)
|
||||
|
||||
|
||||
##############################################################################
|
||||
|
|
@ -1249,4 +1255,4 @@ def get_visible_messages_for_power(conversation_messages, power_name):
|
|||
or msg["recipient"] == power_name
|
||||
):
|
||||
visible.append(msg)
|
||||
return visible # already in chronological order if appended that way
|
||||
return visible # already in chronological order if appended that way
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# ai_diplomacy/diary_logic.py
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .utils import run_llm_and_log, log_llm_response
|
||||
|
||||
|
|
@ -11,13 +11,17 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _load_prompt_file(filename: str) -> str | None:
|
||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||
import os
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
prompts_dir = os.path.join(current_dir, 'prompts')
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
if prompts_dir:
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
||||
filepath = os.path.join(default_prompts_dir, filename)
|
||||
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
|
|
@ -29,6 +33,7 @@ async def run_diary_consolidation(
|
|||
game: "Game",
|
||||
log_file_path: str,
|
||||
entries_to_keep_unsummarized: int = 15,
|
||||
prompts_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Consolidate older diary entries while keeping recent ones.
|
||||
|
|
@ -93,7 +98,7 @@ async def run_diary_consolidation(
|
|||
)
|
||||
return
|
||||
|
||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt")
|
||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(
|
||||
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
||||
|
|
|
|||
|
|
@ -31,16 +31,17 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
|
|||
"private_diary": agent.private_diary,
|
||||
}
|
||||
|
||||
def deserialize_agent(agent_data: dict) -> DiplomacyAgent:
|
||||
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None) -> DiplomacyAgent:
|
||||
"""Recreates an agent object from a dictionary."""
|
||||
client = load_model_client(agent_data["model_id"])
|
||||
client = load_model_client(agent_data["model_id"], prompts_dir=prompts_dir)
|
||||
client.max_tokens = agent_data.get("max_tokens", 16000) # Default for older saves
|
||||
|
||||
agent = DiplomacyAgent(
|
||||
power_name=agent_data["power_name"],
|
||||
client=client,
|
||||
initial_goals=agent_data.get("goals", []),
|
||||
initial_relationships=agent_data.get("relationships", None)
|
||||
initial_relationships=agent_data.get("relationships", None),
|
||||
prompts_dir=prompts_dir
|
||||
)
|
||||
# Restore the diary.
|
||||
agent.full_private_diary = agent_data.get("full_private_diary", [])
|
||||
|
|
@ -179,7 +180,7 @@ def load_game_state(
|
|||
run_dir: str,
|
||||
game_file_name: str,
|
||||
resume_from_phase: Optional[str] = None
|
||||
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Namespace]:
|
||||
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Optional[Namespace]]:
|
||||
"""Loads and reconstructs the game state from a saved game file."""
|
||||
game_file_path = os.path.join(run_dir, game_file_name)
|
||||
if not os.path.exists(game_file_path):
|
||||
|
|
@ -189,6 +190,14 @@ def load_game_state(
|
|||
with open(game_file_path, 'r') as f:
|
||||
saved_game_data = json.load(f)
|
||||
|
||||
# Find the latest config saved in the file
|
||||
run_config = None
|
||||
if saved_game_data.get("phases"):
|
||||
for phase in reversed(saved_game_data["phases"]):
|
||||
if "config" in phase:
|
||||
run_config = Namespace(**phase["config"])
|
||||
logger.info(f"Loaded run configuration from phase {phase['name']}.")
|
||||
break
|
||||
|
||||
# If resuming, find the specified phase and truncate the data after it
|
||||
if resume_from_phase:
|
||||
|
|
@ -231,8 +240,6 @@ def load_game_state(
|
|||
logger.info("No previous phases found. Initializing fresh agents and history.")
|
||||
agents = {} # Will be created by the main loop
|
||||
game_history = GameHistory()
|
||||
# No config to load, will use current run's args
|
||||
run_config = None
|
||||
else:
|
||||
# We save the game state up to & including the current (uncompleted) phase.
|
||||
# So we need to grab the agent state from the previous (completed) phase.
|
||||
|
|
@ -245,8 +252,9 @@ def load_game_state(
|
|||
agents = {}
|
||||
if 'state_agents' in last_phase_data:
|
||||
logger.info("Rebuilding agents from saved state...")
|
||||
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, 'prompts_dir') else None
|
||||
for power_name, agent_data in last_phase_data['state_agents'].items():
|
||||
agents[power_name] = deserialize_agent(agent_data)
|
||||
agents[power_name] = deserialize_agent(agent_data, prompts_dir=prompts_dir_from_config)
|
||||
logger.info(f"Rebuilt {len(agents)} agents.")
|
||||
else:
|
||||
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
||||
|
|
@ -267,7 +275,7 @@ def load_game_state(
|
|||
logger.info("Game history rebuilt.")
|
||||
|
||||
|
||||
return game, agents, game_history
|
||||
return game, agents, game_history, run_config
|
||||
|
||||
|
||||
async def initialize_new_game(
|
||||
|
|
@ -308,12 +316,12 @@ async def initialize_new_game(
|
|||
for power_name, model_id in game.power_model_map.items():
|
||||
if not game.powers[power_name].is_eliminated():
|
||||
try:
|
||||
client = load_model_client(model_id)
|
||||
client = load_model_client(model_id, prompts_dir=args.prompts_dir)
|
||||
client.max_tokens = model_max_tokens[power_name]
|
||||
agent = DiplomacyAgent(power_name=power_name, client=client)
|
||||
agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir)
|
||||
agents[power_name] = agent
|
||||
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
||||
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path))
|
||||
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# ai_diplomacy/initialization.py
|
||||
import logging
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
# Forward declaration for type hinting, actual imports in function if complex
|
||||
if False: # TYPE_CHECKING
|
||||
|
|
@ -18,7 +19,8 @@ async def initialize_agent_state_ext(
|
|||
agent: 'DiplomacyAgent',
|
||||
game: 'Game',
|
||||
game_history: 'GameHistory',
|
||||
log_file_path: str
|
||||
log_file_path: str,
|
||||
prompts_dir: Optional[str] = None,
|
||||
):
|
||||
"""Uses the LLM to set initial goals and relationships for the agent."""
|
||||
power_name = agent.power_name
|
||||
|
|
@ -56,7 +58,8 @@ async def initialize_agent_state_ext(
|
|||
game_history=game_history,
|
||||
agent_goals=None,
|
||||
agent_relationships=None,
|
||||
agent_private_diary=formatted_diary,
|
||||
agent_private_diary=formatted_diary,
|
||||
prompts_dir=prompts_dir,
|
||||
)
|
||||
full_prompt = initial_prompt + "\n\n" + context
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ def build_context_prompt(
|
|||
agent_goals: Optional[List[str]] = None,
|
||||
agent_relationships: Optional[Dict[str, str]] = None,
|
||||
agent_private_diary: Optional[str] = None,
|
||||
prompts_dir: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Builds the detailed context part of the prompt.
|
||||
|
||||
|
|
@ -35,11 +36,12 @@ def build_context_prompt(
|
|||
agent_goals: Optional list of agent's goals.
|
||||
agent_relationships: Optional dictionary of agent's relationships with other powers.
|
||||
agent_private_diary: Optional string of agent's private diary.
|
||||
prompts_dir: Optional path to the prompts directory.
|
||||
|
||||
Returns:
|
||||
A string containing the formatted context.
|
||||
"""
|
||||
context_template = load_prompt("context_prompt.txt")
|
||||
context_template = load_prompt("context_prompt.txt", prompts_dir=prompts_dir)
|
||||
|
||||
# === Agent State Debug Logging ===
|
||||
if agent_goals:
|
||||
|
|
@ -112,6 +114,7 @@ def construct_order_generation_prompt(
|
|||
agent_goals: Optional[List[str]] = None,
|
||||
agent_relationships: Optional[Dict[str, str]] = None,
|
||||
agent_private_diary_str: Optional[str] = None,
|
||||
prompts_dir: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Constructs the final prompt for order generation.
|
||||
|
||||
|
|
@ -125,13 +128,14 @@ def construct_order_generation_prompt(
|
|||
agent_goals: Optional list of agent's goals.
|
||||
agent_relationships: Optional dictionary of agent's relationships with other powers.
|
||||
agent_private_diary_str: Optional string of agent's private diary.
|
||||
prompts_dir: Optional path to the prompts directory.
|
||||
|
||||
Returns:
|
||||
A string containing the complete prompt for the LLM.
|
||||
"""
|
||||
# Load prompts
|
||||
_ = load_prompt("few_shot_example.txt") # Loaded but not used, as per original logic
|
||||
instructions = load_prompt("order_instructions.txt")
|
||||
_ = load_prompt("few_shot_example.txt", prompts_dir=prompts_dir) # Loaded but not used, as per original logic
|
||||
instructions = load_prompt("order_instructions.txt", prompts_dir=prompts_dir)
|
||||
|
||||
# Build the context prompt
|
||||
context = build_context_prompt(
|
||||
|
|
@ -143,7 +147,8 @@ def construct_order_generation_prompt(
|
|||
agent_goals=agent_goals,
|
||||
agent_relationships=agent_relationships,
|
||||
agent_private_diary=agent_private_diary_str,
|
||||
prompts_dir=prompts_dir,
|
||||
)
|
||||
|
||||
final_prompt = system_prompt + "\n\n" + context + "\n\n" + instructions
|
||||
return final_prompt
|
||||
return final_prompt
|
||||
|
|
@ -293,11 +293,14 @@ def normalize_and_compare_orders(
|
|||
|
||||
|
||||
# Helper to load prompt text from file relative to the expected 'prompts' dir
|
||||
def load_prompt(filename: str) -> str:
|
||||
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
||||
"""Helper to load prompt text from file"""
|
||||
# Assuming execution from the root or that the path resolves correctly
|
||||
# Consider using absolute paths or pkg_resources if needed for robustness
|
||||
prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', filename)
|
||||
if prompts_dir:
|
||||
prompt_path = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
# Default behavior: relative to this file's location in the 'prompts' subdir
|
||||
prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', filename)
|
||||
|
||||
try:
|
||||
with open(prompt_path, "r", encoding='utf-8') as f: # Added encoding
|
||||
return f.read().strip()
|
||||
|
|
|
|||
474
experiment_runner.py
Normal file
474
experiment_runner.py
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Experiment orchestration for Diplomacy self-play.
|
||||
Launches many `lm_game` runs in parallel, captures their artefacts,
|
||||
and executes a pluggable post-analysis pipeline.
|
||||
|
||||
Run `python experiment_runner.py --help` for CLI details.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import concurrent.futures
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Iterable, List
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Logging #
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Shared log line format for the whole runner process.
LOG_FMT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"
# NOTE(review): basicConfig at import time configures the *root* logger.
# Fine for a CLI entry point, but surprising if this module is imported
# by other code — confirm this module is only ever run as a script.
logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%H:%M:%S")
log = logging.getLogger("experiment_runner")
|
||||
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
# Flag definitions – full, un-shortened help strings #
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
def _add_experiment_flags(p: argparse.ArgumentParser) -> None:
|
||||
p.add_argument(
|
||||
"--experiment_dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help=(
|
||||
"Directory that will hold all experiment artefacts. "
|
||||
"A 'runs/' sub-folder is created for individual game runs and an "
|
||||
"'analysis/' folder for aggregated outputs. Must be writable."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--iterations",
|
||||
type=int,
|
||||
default=1,
|
||||
help=(
|
||||
"Number of lm_game instances to launch for this experiment. "
|
||||
"Each instance gets its own sub-directory under runs/."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--parallel",
|
||||
type=int,
|
||||
default=1,
|
||||
help=(
|
||||
"Maximum number of game instances to run concurrently. "
|
||||
"Uses a ProcessPoolExecutor under the hood."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--analysis_modules",
|
||||
type=str,
|
||||
default="summary",
|
||||
help=(
|
||||
"Comma-separated list of analysis module names to execute after all "
|
||||
"runs finish. Modules are imported from "
|
||||
"'experiment_runner.analysis.<name>' and must expose "
|
||||
"run(experiment_dir: Path, ctx: dict)."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--critical_state_base_run",
|
||||
type=Path,
|
||||
default=None,
|
||||
help=(
|
||||
"Path to an *existing* run directory produced by a previous lm_game "
|
||||
"execution. When supplied, every iteration resumes from that "
|
||||
"snapshot using lm_game's --critical_state_analysis_dir mechanism."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--seed_base",
|
||||
type=int,
|
||||
default=42,
|
||||
help=(
|
||||
"Base RNG seed. Run i will receive seed = seed_base + i. "
|
||||
"Forwarded to lm_game via its --seed flag (you must have added that "
|
||||
"flag to lm_game for deterministic behaviour)."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _add_lm_game_flags(p: argparse.ArgumentParser) -> None:
|
||||
# ---- all flags copied verbatim from lm_game.parse_arguments() ----
|
||||
p.add_argument(
|
||||
"--resume_from_phase",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Phase to resume from (e.g., 'S1902M'). Requires --run_dir. "
|
||||
"IMPORTANT: This option clears any existing phase results ahead of "
|
||||
"& including the specified resume phase."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--end_at_phase",
|
||||
type=str,
|
||||
default="",
|
||||
help="Phase to end the simulation after (e.g., 'F1905M').",
|
||||
)
|
||||
p.add_argument(
|
||||
"--max_year",
|
||||
type=int,
|
||||
default=1910, # Increased default in lm_game
|
||||
help="Maximum year to simulate. The game will stop once this year is reached.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_negotiation_rounds",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of negotiation rounds per phase.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--models",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated list of model names to assign to powers in order. "
|
||||
"The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--planning_phase",
|
||||
action="store_true",
|
||||
help="Enable the planning phase for each power to set strategic directives.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--max_tokens",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="Maximum number of new tokens to generate per LLM call (default: 16000).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--max_tokens_per_model",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, "
|
||||
"FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--prompts_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Path to the directory containing prompt files. "
|
||||
"Defaults to the packaged prompts directory."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
# One combined parser for banner printing #
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
def _build_full_parser() -> argparse.ArgumentParser:
    """Build the combined parser used solely for `--help` / usage output."""

    def _formatter(prog: str) -> argparse.RawTextHelpFormatter:
        # Widen the help column so the long flag names stay readable.
        return argparse.RawTextHelpFormatter(prog, max_help_position=45)

    combined = argparse.ArgumentParser(
        prog="experiment_runner.py",
        formatter_class=_formatter,
        description=(
            "Batch-runner for Diplomacy self-play experiments. "
            "All lm_game flags are accepted here as-is; they are validated "
            "before any game runs start."
        ),
    )
    _add_experiment_flags(combined)
    _add_lm_game_flags(combined)
    return combined
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
# Robust parsing that always shows *full* help on error #
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
def _parse_cli() -> tuple[argparse.Namespace, list[str], argparse.Namespace]:
    """Parse sys.argv into (experiment args, forwarded game tokens, game args).

    The combined parser from _build_full_parser() is used only for help and
    error banners. The real parse splits argv by flag ownership: experiment
    flags are consumed here; everything else is treated as an lm_game flag,
    validated, and later forwarded verbatim to each subprocess.

    Returns:
        exp_args:  Namespace of experiment-level options.
        game_tok:  Raw argv tokens belonging to lm_game (forwarded as-is).
        game_args: Parsed Namespace of those tokens (used for config dumps).
    """
    full_parser = _build_full_parser()

    # Show full banner when no args
    if len(sys.argv) == 1:
        full_parser.print_help(sys.stderr)
        sys.exit(2)

    # Show full banner on explicit help
    if any(tok in ("-h", "--help") for tok in sys.argv[1:]):
        full_parser.print_help(sys.stderr)
        sys.exit(0)

    # Sub-parsers for separating experiment vs game flags. Both print the
    # *combined* help on error so the user sees every available flag.
    class _ErrParser(argparse.ArgumentParser):
        def error(self, msg):
            full_parser.print_help(sys.stderr)
            self.exit(2, f"{self.prog}: error: {msg}\n")

    exp_parser = _ErrParser(add_help=False)
    game_parser = _ErrParser(add_help=False)
    _add_experiment_flags(exp_parser)
    _add_lm_game_flags(game_parser)

    # Split argv tokens by flag ownership
    argv = sys.argv[1:]
    exp_flag_set = {opt for a in exp_parser._actions for opt in a.option_strings}

    exp_tok, game_tok, i = [], [], 0
    while i < len(argv):
        tok = argv[i]
        if tok in exp_flag_set:
            exp_tok.append(tok)
            action = exp_parser._option_string_actions[tok]
            needs_val = (
                action.nargs is None  # default: exactly one value
                or (isinstance(action.nargs, int) and action.nargs > 0)
                or action.nargs in ("+", "*", "?")  # variable-length cases
            )
            if needs_val:
                # BUGFIX: a value-taking flag given as the final token used to
                # crash with IndexError on argv[i + 1]; report a proper usage
                # error (with the full help banner) instead.
                if i + 1 >= len(argv):
                    exp_parser.error(f"argument {tok}: expected one argument")
                exp_tok.append(argv[i + 1])
                i += 2
            else:  # store_true / store_false
                i += 1

        else:
            game_tok.append(tok)
            i += 1

    exp_args = exp_parser.parse_args(exp_tok)
    game_args = game_parser.parse_args(game_tok)
    return exp_args, game_tok, game_args
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Immutable record of one finished game launch: iteration index, artefact
# directory, RNG seed, the exact command line used, the subprocess return
# code, and wall-clock duration in seconds.
_RunInfo = collections.namedtuple(
    "_RunInfo", "index run_dir seed cmd_line returncode elapsed_s"
)
|
||||
|
||||
|
||||
def _mk_run_dir(exp_dir: Path, idx: int) -> Path:
|
||||
run_dir = exp_dir / "runs" / f"run_{idx:05d}"
|
||||
# Just ensure it exists; don't raise if it already does.
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
return run_dir
|
||||
|
||||
|
||||
def _dump_seed(seed: int, run_dir: Path) -> None:
|
||||
seed_file = run_dir / "seed.txt"
|
||||
if not seed_file.exists():
|
||||
seed_file.write_text(str(seed))
|
||||
|
||||
|
||||
def _build_cmd(
|
||||
lm_game_script: Path,
|
||||
base_cli: List[str],
|
||||
run_dir: Path,
|
||||
seed: int,
|
||||
critical_base: Path | None,
|
||||
resume_from_phase: str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns a list suitable for subprocess.run([...]).
|
||||
"""
|
||||
cmd = [sys.executable, str(lm_game_script)]
|
||||
|
||||
# Forward user CLI
|
||||
cmd.extend(base_cli)
|
||||
|
||||
# Per-run mandatory overrides
|
||||
cmd.extend(["--run_dir", str(run_dir)])
|
||||
cmd.extend(["--seed", str(seed)]) # you may need to add a --seed flag to lm_game
|
||||
|
||||
# Critical-state mode
|
||||
if critical_base:
|
||||
cmd.extend([
|
||||
"--critical_state_analysis_dir", str(run_dir),
|
||||
"--run_dir", str(critical_base) # base run dir (already completed)
|
||||
])
|
||||
if resume_from_phase:
|
||||
cmd.extend(["--resume_from_phase", resume_from_phase])
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def _launch_one(args) -> _RunInfo:
    """Pool worker: launch a single lm_game subprocess and capture its output.

    *args* is the 7-tuple assembled in main(); a tuple is used because pool
    workers receive a single picklable argument. Console output is captured
    to run_dir/console.log and the exact command to run_dir/command.txt.
    Never raises: launch failures are recorded as return code 1.
    """
    idx, lm_game_script, base_cli, run_dir, seed, critical_base, resume_phase = args

    cmd = _build_cmd(
        lm_game_script, base_cli, run_dir, seed, critical_base, resume_phase
    )
    started = time.perf_counter()
    log.debug("Run %05d: CMD = %s", idx, " ".join(cmd))

    # Write out full command for traceability
    (run_dir / "command.txt").write_text(" ".join(cmd))

    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            check=False,
        )
        (run_dir / "console.log").write_text(proc.stdout)
        exit_code = proc.returncode
    except Exception as exc:  # noqa: broad-except
        (run_dir / "console.log").write_text(f"Exception launching run:\n{exc}\n")
        exit_code = 1

    return _RunInfo(
        idx, run_dir, seed, " ".join(cmd), exit_code, time.perf_counter() - started
    )
|
||||
|
||||
|
||||
def _load_analysis_fns(module_names: Iterable[str]):
    """Yield the `run` callable of each requested analysis module.

    Module names are resolved under experiment_runner.analysis.<name>.
    Missing modules, or modules without a `run` attribute, are logged at
    WARNING level and skipped rather than aborting the experiment.
    """
    for raw_name in module_names:
        dotted = f"experiment_runner.analysis.{raw_name.strip()}"
        try:
            module = importlib.import_module(dotted)
        except ModuleNotFoundError as e:
            log.warning("Analysis module %s not found (%s) – skipping", dotted, e)
            continue

        if not hasattr(module, "run"):
            log.warning("%s has no `run()` function – skipping", dotted)
            continue
        yield module.run
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Main driver #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def main() -> None:
    """Drive a full experiment: launch N game runs, then run analysis.

    Steps:
      1. Parse CLI into experiment flags and forwarded lm_game tokens.
      2. Create/append the experiment directory; persist config.json once.
      3. Launch `iterations` lm_game subprocesses, at most `parallel`
         concurrently, each with its own run dir and seed.
      4. Write runs_summary.json, then execute each analysis module.

    BUGFIX: removed a leftover unconditional "Config saved" log line that
    fired even when the config was intentionally left unchanged,
    contradicting the "Config already exists" message just logged.
    """
    exp_args, leftover_cli, game_args = _parse_cli()

    exp_dir: Path = exp_args.experiment_dir.expanduser().resolve()
    if exp_dir.exists():
        log.info("Appending to existing experiment: %s", exp_dir)
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Persist experiment-level config — only on first creation, so a
    # re-launched experiment keeps its original record.
    cfg_path = exp_dir / "config.json"
    if not cfg_path.exists():
        with cfg_path.open("w", encoding="utf-8") as fh:
            json.dump(
                {"experiment": vars(exp_args),
                 "lm_game": vars(game_args),
                 "forwarded_cli": leftover_cli},
                fh, indent=2, default=str,
            )
        log.info("Config saved to %s", cfg_path)
    else:
        log.info("Config already exists – leaving unchanged")

    # ------------------------------------------------------------------ #
    #  Launch games                                                      #
    # ------------------------------------------------------------------ #
    lm_game_script = Path(__file__).parent / "lm_game.py"
    if not lm_game_script.exists():
        log.error("lm_game.py not found at %s – abort", lm_game_script)
        sys.exit(1)

    run_args = []
    for i in range(exp_args.iterations):
        run_dir = _mk_run_dir(exp_dir, i)
        seed = exp_args.seed_base + i  # deterministic per-run seed
        _dump_seed(seed, run_dir)

        run_args.append(
            (
                i, lm_game_script, leftover_cli, run_dir, seed,
                exp_args.critical_state_base_run,
                game_args.resume_from_phase,
            )
        )

    log.info(
        "Launching %d runs (max %d parallel, critical_state=%s)",
        exp_args.iterations,
        exp_args.parallel,
        bool(exp_args.critical_state_base_run),
    )

    runs_meta: list[_RunInfo] = []
    # "spawn" avoids fork-related issues with threaded/async libraries
    # loaded by the parent process.
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=exp_args.parallel,
        mp_context=mp.get_context("spawn"),
    ) as pool:
        for res in pool.map(_launch_one, run_args):
            runs_meta.append(res)
            status = "OK" if res.returncode == 0 else f"RC={res.returncode}"
            log.info(
                "run_%05d finished in %.1fs %s", res.index, res.elapsed_s, status
            )

    # Persist per-run status summary
    summary_path = exp_dir / "runs_summary.json"
    with open(summary_path, "w", encoding="utf-8") as fh:
        json.dump([res._asdict() for res in runs_meta], fh, indent=2, default=str)
    log.info("Run summary written → %s", summary_path)

    # ------------------------------------------------------------------ #
    #  Post-analysis pipeline                                            #
    # ------------------------------------------------------------------ #
    mods = list(_load_analysis_fns(exp_args.analysis_modules.split(",")))
    if not mods:
        log.warning("No analysis modules loaded – done.")
        return

    analysis_root = exp_dir / "analysis"
    if analysis_root.exists():
        shutil.rmtree(analysis_root)  # wipe stale outputs from earlier passes
    analysis_root.mkdir(exist_ok=True)

    # Common context handed to every analysis module
    ctx: dict = {
        "exp_dir": str(exp_dir),
        "runs": [str(r.run_dir) for r in runs_meta],
        "critical_state_base": str(exp_args.critical_state_base_run or ""),
        "resume_from_phase": game_args.resume_from_phase,
    }

    for fn in mods:
        name = fn.__module__.rsplit(".", 1)[-1]
        log.info("Running analysis module: %s", name)
        try:
            fn(exp_dir, ctx)
            log.info("✓ %s complete", name)
        except Exception as exc:  # noqa: broad-except
            # One failing module must not kill the remaining pipeline.
            log.exception("Analysis module %s failed: %s", name, exc)

    log.info("Experiment finished – artefacts in %s", exp_dir)
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
0
experiment_runner/__init__.py
Normal file
0
experiment_runner/__init__.py
Normal file
0
experiment_runner/analysis/__init__.py
Normal file
0
experiment_runner/analysis/__init__.py
Normal file
60
experiment_runner/analysis/critical_state.py
Normal file
60
experiment_runner/analysis/critical_state.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
Extracts the board state *before* and *after* a critical phase
|
||||
for every critical-analysis run produced by experiment_runner.
|
||||
|
||||
Each run must have:
|
||||
• lmvsgame.json – containing a phase named ctx["resume_from_phase"]
|
||||
|
||||
Outputs live in analysis/critical_state/<run_name>_{before|after}.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _phase_by_name(phases: List[dict], name: str) -> dict | None:
|
||||
for ph in phases:
|
||||
if ph["state"]["name"] == name:
|
||||
return ph
|
||||
return None
|
||||
|
||||
|
||||
def run(exp_dir: Path, ctx: Dict) -> None:
    """Dump before/after board snapshots around the critical phase.

    For every run under exp_dir/runs with an lmvsgame.json, writes
    analysis/critical_state/<run>_before.json (the resumed phase) and
    <run>_after.json (the final phase). Skips silently when
    ctx["resume_from_phase"] was not supplied.
    """
    resume_phase = ctx.get("resume_from_phase")
    if not resume_phase:
        log.info("critical_state: --resume_from_phase not supplied – skipping")
        return

    out_dir = exp_dir / "analysis" / "critical_state"
    out_dir.mkdir(parents=True, exist_ok=True)

    for run_dir in (exp_dir / "runs").iterdir():
        game_json = run_dir / "lmvsgame.json"
        if not game_json.exists():
            continue

        with game_json.open("r") as fh:
            game = json.load(fh)

        phases = game.get("phases", [])
        if not phases:
            continue

        snapshot_before = _phase_by_name(phases, resume_phase)
        # phases is non-empty here, so the final phase always exists
        snapshot_after = phases[-1]
        if snapshot_before is None:
            log.warning("Run %s missing expected phases – skipped", run_dir.name)
            continue

        (out_dir / f"{run_dir.name}_before.json").write_text(
            json.dumps(snapshot_before, indent=2)
        )
        (out_dir / f"{run_dir.name}_after.json").write_text(
            json.dumps(snapshot_after, indent=2)
        )

    log.info("critical_state: snapshots written to %s", out_dir)
|
||||
177
experiment_runner/analysis/summary.py
Normal file
177
experiment_runner/analysis/summary.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Aggregates results across all runs and writes:
|
||||
|
||||
• analysis/aggregated_results.csv
|
||||
• analysis/score_summary_by_power.csv
|
||||
• analysis/results_summary.png (if matplotlib available)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Helpers copied (and lightly cleaned) from run_games.py #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _extract_diplomacy_results(game_json: Path) -> Dict[str, Dict[str, Any]]:
|
||||
with game_json.open("r", encoding="utf-8") as fh:
|
||||
gd = json.load(fh)
|
||||
|
||||
phases = gd.get("phases", [])
|
||||
if not phases:
|
||||
raise ValueError("no phases")
|
||||
|
||||
first_state = phases[0]["state"]
|
||||
last_state = phases[-1]["state"]
|
||||
|
||||
powers = (
|
||||
list(first_state.get("homes", {}).keys())
|
||||
or list(first_state.get("centers", {}).keys())
|
||||
)
|
||||
if not powers:
|
||||
raise ValueError("cannot determine powers")
|
||||
|
||||
scs_to_win = 18
|
||||
solo_winner = next(
|
||||
(p for p, sc in last_state["centers"].items() if len(sc) >= scs_to_win), None
|
||||
)
|
||||
|
||||
results: Dict[str, Dict[str, Any]] = {}
|
||||
for p in powers:
|
||||
sc_count = len(last_state["centers"].get(p, []))
|
||||
units = len(last_state["units"].get(p, []))
|
||||
|
||||
if solo_winner:
|
||||
if p == solo_winner:
|
||||
outcome, cat = f"Won solo", "Solo Win"
|
||||
else:
|
||||
outcome, cat = f"Lost to {solo_winner}", "Loss"
|
||||
elif sc_count == 0 and units == 0:
|
||||
outcome, cat = "Eliminated", "Eliminated"
|
||||
else:
|
||||
outcome, cat = "Ongoing/Draw", "Ongoing/Abandoned/Draw"
|
||||
|
||||
results[p] = {
|
||||
"OutcomeCategory": cat,
|
||||
"StatusDetail": outcome,
|
||||
"SupplyCenters": sc_count,
|
||||
"LastPhase": last_state["name"],
|
||||
}
|
||||
return results
|
||||
|
||||
|
||||
# simplistic "Diplobench" scoring from previous discussion ------------------ #
|
||||
def _year(name: str) -> int | None:
|
||||
m = re.search(r"(\d{4})", name)
|
||||
return int(m.group(1)) if m else None
|
||||
|
||||
|
||||
def _score_game(game_json: Path) -> Dict[str, int]:
    """Compute a per-power integer score for one finished game.

    Scoring scheme (as implemented below):
      * solo winner   -> max_turns + 17 + (max_turns - win_turn)
      * loser to solo -> turn on which the solo occurred
      * survivor      -> max_turns + final supply-center count
      * eliminated    -> the turn (1-based) on which its centers hit zero

    Turns are counted in years from the first phase's year; when no year
    can be parsed from the phase names, phase indices are used instead.
    """
    with open(game_json, "r") as fh:
        game = json.load(fh)
    phases = game["phases"]
    if not phases:
        return {}

    # start/end years parsed from phase names; end_year falls back to start
    # when the final phase name carries no year.
    start = _year(phases[0]["state"]["name"])
    end_year = _year(phases[-1]["state"]["name"]) or start
    max_turns = (end_year - start + 1) if start is not None else len(phases)

    last_state = phases[-1]["state"]
    # A solo win is 18+ supply centers in the final phase.
    solo_winner = next(
        (p for p, scs in last_state["centers"].items() if len(scs) >= 18), None
    )

    # First turn on which each power held zero centers (None = never).
    elim_turn: Dict[str, int | None] = {}
    for p in last_state["centers"].keys():
        e_turn = None
        for idx, ph in enumerate(phases):
            if not ph["state"]["centers"].get(p, []):
                yr = _year(ph["state"]["name"]) or 0
                e_turn = (yr - start + 1) if start is not None else idx + 1
                break
        elim_turn[p] = e_turn

    scores: Dict[str, int] = {}
    for p, scs in last_state["centers"].items():
        if p == solo_winner:
            # NOTE(review): win_turn equals max_turns by construction, so the
            # early-win bonus term is always 0 as written — possibly intended
            # to use the year of the solo phase instead; confirm.
            win_turn = (end_year - start + 1) if start is not None else len(phases)
            scores[p] = max_turns + 17 + (max_turns - win_turn)
        elif solo_winner:
            # losers in a solo game
            yr_win = _year(phases[-1]["state"]["name"]) or end_year
            turn_win = (yr_win - start + 1) if start is not None else len(phases)
            scores[p] = turn_win
        else:
            if elim_turn[p] is None:
                # Survivor: reward longevity plus final center count.
                scores[p] = max_turns + len(scs)
            else:
                # Eliminated: score is simply how long it lasted.
                scores[p] = elim_turn[p]
    return scores
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Public entry point #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def run(exp_dir: Path, ctx: dict):  # pylint: disable=unused-argument
    """Aggregate per-run outcomes and scores into CSVs plus an optional chart.

    Writes analysis/aggregated_results.csv, score_summary_by_power.csv and,
    when matplotlib/seaborn are importable, results_summary.png.
    """
    analysis_dir = exp_dir / "analysis"
    analysis_dir.mkdir(exist_ok=True)

    records: List[Dict[str, Any]] = []
    for run_dir in (exp_dir / "runs").iterdir():
        game_json = run_dir / "lmvsgame.json"
        if not game_json.exists():
            continue

        try:
            outcomes = _extract_diplomacy_results(game_json)
        except Exception as err:  # noqa: broad-except
            log.warning("Could not parse %s (%s)", game_json, err)
            continue

        scores = _score_game(game_json)
        records.extend(
            {"GameID": run_dir.name, "Power": power, **details,
             "Score": scores.get(power, None)}
            for power, details in outcomes.items()
        )

    if not records:
        log.warning("summary: no parsable runs found")
        return

    df = pd.DataFrame(records)
    out_csv = analysis_dir / "aggregated_results.csv"
    df.to_csv(out_csv, index=False)

    per_power = (
        df.groupby("Power")["Score"]
        .agg(["mean", "median", "count"])
        .reset_index()
        .rename(columns={"count": "n"})
    )
    per_power.to_csv(analysis_dir / "score_summary_by_power.csv", index=False)

    log.info("summary: wrote %s rows to %s", len(df), out_csv)

    # Optional charts
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns

        sns.set_style("whitegrid")
        plt.figure(figsize=(10, 7))
        sns.boxplot(x="Power", y="SupplyCenters", data=df, palette="pastel")
        plt.title("Supply-center distribution")
        plt.savefig(analysis_dir / "results_summary.png", dpi=150)
        plt.close()
        log.info("summary: chart saved")
    except Exception as e:  # noqa: broad-except
        log.debug("Chart generation skipped (%s)", e)
|
||||
35
lm_game.py
35
lm_game.py
|
|
@ -55,6 +55,12 @@ def parse_arguments():
|
|||
default="",
|
||||
help="Directory for results. If it exists, the game resumes. If not, it's created. Defaults to a new timestamped directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", # alias for back compatibility
|
||||
dest="run_dir", # write to the same variable as --run_dir
|
||||
type=str,
|
||||
help=argparse.SUPPRESS # hides it from `--help`
|
||||
)
|
||||
parser.add_argument(
|
||||
"--critical_state_analysis_dir",
|
||||
type=str,
|
||||
|
|
@ -105,12 +111,24 @@ def parse_arguments():
|
|||
default=16000,
|
||||
help="Maximum number of new tokens to generate per LLM call (default: 16000)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed_base",
|
||||
type=int,
|
||||
default=42,
|
||||
help="RNG seed placeholder for compatibility with experiment_runner. Currently unused."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens_per_model",
|
||||
type=str,
|
||||
default="",
|
||||
help="Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompts_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the directory containing prompt files. Defaults to the packaged prompts directory."
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -179,8 +197,17 @@ async def main():
|
|||
if is_resuming:
|
||||
try:
|
||||
# When resuming, we load the state and also the config from the last saved phase.
|
||||
# We will continue to use THAT config, ignoring the current script's args.
|
||||
game, agents, game_history = load_game_state(run_dir, game_file_name, args.resume_from_phase)
|
||||
game, agents, game_history, loaded_run_config = load_game_state(run_dir, game_file_name, args.resume_from_phase)
|
||||
if loaded_run_config:
|
||||
# Use the saved config, but allow current CLI args to override control-flow parameters
|
||||
run_config = loaded_run_config
|
||||
run_config.run_dir = args.run_dir
|
||||
run_config.critical_state_analysis_dir = args.critical_state_analysis_dir
|
||||
run_config.resume_from_phase = args.resume_from_phase
|
||||
run_config.end_at_phase = args.end_at_phase
|
||||
# If prompts_dir is specified now, it overrides the saved one.
|
||||
if args.prompts_dir is not None:
|
||||
run_config.prompts_dir = args.prompts_dir
|
||||
logger.info(f"Successfully resumed game from phase: {game.get_current_phase()}.")
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
logger.error(f"Could not resume game: {e}. Starting a new game instead.")
|
||||
|
|
@ -204,7 +231,7 @@ async def main():
|
|||
if year_int > run_config.max_year:
|
||||
logger.info(f"Reached max year {run_config.max_year}, stopping simulation.")
|
||||
break
|
||||
if run_config.end_at_phase and game.phases and game.phases[-1].name == run_config.end_at_phase:
|
||||
if run_config.end_at_phase and current_phase == run_config.end_at_phase:
|
||||
logger.info(f"Reached end phase {run_config.end_at_phase}, stopping simulation.")
|
||||
break
|
||||
|
||||
|
|
@ -289,7 +316,7 @@ async def main():
|
|||
# Diary Consolidation
|
||||
if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
|
||||
consolidation_tasks = [
|
||||
run_diary_consolidation(agent, game, llm_log_file_path)
|
||||
run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir)
|
||||
for agent in agents.values() if not game.powers[agent.power_name].is_eliminated()
|
||||
]
|
||||
if consolidation_tasks:
|
||||
|
|
|
|||
|
|
@ -17,3 +17,5 @@ anthropic
|
|||
google-genai
|
||||
json-repair
|
||||
together
|
||||
matplotlib
|
||||
seaborn
|
||||
5
run.sh
5
run.sh
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
python3 lm_game.py \
|
||||
--max_year 1901 \
|
||||
--num_negotiation_rounds 0 \
|
||||
--num_negotiation_rounds 1 \
|
||||
--models "openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17" \
|
||||
--max_tokens_per_model 16000,16000,16000,16000,16000,16000,16000
|
||||
--max_tokens_per_model 16000,16000,16000,16000,16000,16000,16000 \
|
||||
--prompts_dir "ai_diplomacy/prompts"
|
||||
Loading…
Add table
Add a link
Reference in a new issue