mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
Merge pull request #42 from EveryInc/experiment-runner
Experiment runner
This commit is contained in:
commit
617ccdb1a0
16 changed files with 943 additions and 80 deletions
85
README.md
85
README.md
|
|
@ -236,13 +236,92 @@ python lm_game.py --max_year 1910 --planning_phase --num_negotiation_rounds 2
|
||||||
# Custom model assignment (order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY)
|
# Custom model assignment (order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY)
|
||||||
python lm_game.py --models "claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022"
|
python lm_game.py --models "claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022,gpt-4o,claude-3-5-sonnet-20241022"
|
||||||
|
|
||||||
# Output to specific file
|
|
||||||
python lm_game.py --output results/my_game.json
|
|
||||||
|
|
||||||
# Run until game completion or specific year
|
# Run until game completion or specific year
|
||||||
python lm_game.py --num_negotiation_rounds 2 --planning_phase
|
python lm_game.py --num_negotiation_rounds 2 --planning_phase
|
||||||
|
|
||||||
|
# Write all artefacts to a chosen directory (auto-resumes if it already exists)
|
||||||
|
python lm_game.py --run_dir results/game_run_001
|
||||||
|
|
||||||
|
# Resume an interrupted game from a specific phase
|
||||||
|
python lm_game.py --run_dir results/game_run_001 --resume_from_phase S1902M
|
||||||
|
|
||||||
|
# Critical-state analysis: resume from an existing run but save new results elsewhere
|
||||||
|
python lm_game.py \
|
||||||
|
--run_dir results/game_run_001 \
|
||||||
|
--critical_state_analysis_dir results/critical_analysis_001 \
|
||||||
|
--resume_from_phase F1903M
|
||||||
|
|
||||||
|
# End the simulation after a particular phase regardless of remaining years
|
||||||
|
python lm_game.py --run_dir results/game_run_002 --end_at_phase F1905M
|
||||||
|
|
||||||
|
# Set the global max_tokens generation limit
|
||||||
|
python lm_game.py --run_dir results/game_run_003 --max_tokens 8000
|
||||||
|
|
||||||
|
# Per-model token limits (AU,EN,FR,GE,IT,RU,TR)
|
||||||
|
python lm_game.py --run_dir results/game_run_004 \
|
||||||
|
--max_tokens_per_model "8000,8000,16000,8000,8000,16000,8000"
|
||||||
|
|
||||||
|
# Use a custom prompts directory
|
||||||
|
python lm_game.py --run_dir results/game_run_005 --prompts_dir ./prompts/my_variants
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Running Batch Experiments with **`experiment_runner.py`**
|
||||||
|
|
||||||
|
`experiment_runner.py` is a lightweight orchestrator: it spins up many `lm_game.py` runs in parallel, gathers their artefacts under one *experiment directory*, and then executes the analysis modules you specify.
|
||||||
|
All flags that belong to **`lm_game.py`** can be passed straight through; the runner validates them and forwards them unchanged to every game instance.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run 10 independent games (iterations) in parallel, using a custom prompts dir
|
||||||
|
# and a single model (GPT-4o) for all seven powers.
|
||||||
|
python3 experiment_runner.py \
|
||||||
|
--experiment_dir "results/exp001" \
|
||||||
|
--iterations 10 \
|
||||||
|
--parallel 10 \
|
||||||
|
--max_year 1905 \
|
||||||
|
--num_negotiation_rounds 0 \
|
||||||
|
--prompts_dir "ai_diplomacy/prompts" \
|
||||||
|
--models "gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
# Critical-state analysis: resume every run from W1901A (taken from an existing
|
||||||
|
# base run) and stop after S1902M. Two analysis modules will be executed:
|
||||||
|
# • summary → aggregated results & scores
|
||||||
|
# • critical_state → before/after snapshots around the critical phase
|
||||||
|
python3 experiment_runner.py \
|
||||||
|
--experiment_dir "results/exp002" \
|
||||||
|
--iterations 10 \
|
||||||
|
--parallel 10 \
|
||||||
|
--resume_from_phase W1901A \
|
||||||
|
--end_at_phase S1902M \
|
||||||
|
--num_negotiation_rounds 0 \
|
||||||
|
--critical_state_base_run "results/test1" \
|
||||||
|
--prompts_dir "ai_diplomacy/prompts" \
|
||||||
|
--analysis_modules "summary,critical_state" \
|
||||||
|
--models "gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o,gpt-4o"
|
||||||
|
```
|
||||||
|
|
||||||
|
*(Any other `lm_game.py` flags—`--planning_phase`, `--max_tokens`, etc.—can be added exactly where you’d use them on a single-game run.)*
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### Experiment-runner–specific arguments
|
||||||
|
|
||||||
|
| Flag | Type / Default | Description |
|
||||||
|
| --------------------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `--experiment_dir` **(required)** | `Path` | Root folder for the experiment; sub-folders `runs/` and `analysis/` are managed automatically. Re-running with the same directory will **resume** existing runs and regenerate analysis. |
|
||||||
|
| `--iterations` | `int`, default `1` | How many individual games to launch for this experiment. |
|
||||||
|
| `--parallel` | `int`, default `1` | Max number of games to execute concurrently (uses a process pool). |
|
||||||
|
| `--analysis_modules` | `str`, default `"summary"` | Comma-separated list of analysis modules to run after all games finish. Modules are imported from `experiment_runner.analysis.<name>` and must provide `run(experiment_dir, ctx)`. |
|
||||||
|
| `--critical_state_base_run` | `Path`, optional | Path to an **existing** `run_dir` produced by a previous `lm_game` run. Each iteration resumes from that snapshot; new artefacts are written under the current `experiment_dir`. |
|
||||||
|
| `--seed_base` | `int`, default `42` | Base random seed. Run *ɪ* receives seed = `seed_base + ɪ`, enabling reproducible batches. |
|
||||||
|
|
||||||
|
*(All other command-line flags belong to `lm_game.py` and are forwarded unchanged.)*
|
||||||
|
|
||||||
|
|
||||||
### Environment Setup
|
### Environment Setup
|
||||||
|
|
||||||
Create a `.env` file with your API keys:
|
Create a `.env` file with your API keys:
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,17 @@ ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUS
|
||||||
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
|
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
|
||||||
|
|
||||||
# == New: Helper function to load prompt files reliably ==
|
# == New: Helper function to load prompt files reliably ==
|
||||||
def _load_prompt_file(filename: str) -> Optional[str]:
|
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
|
||||||
"""Loads a prompt template from the prompts directory."""
|
"""Loads a prompt template from the prompts directory."""
|
||||||
try:
|
try:
|
||||||
# Construct path relative to this file's location
|
if prompts_dir:
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.join(prompts_dir, filename)
|
||||||
prompts_dir = os.path.join(current_dir, 'prompts')
|
else:
|
||||||
filepath = os.path.join(prompts_dir, filename)
|
# Construct path relative to this file's location
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
||||||
|
filepath = os.path.join(default_prompts_dir, filename)
|
||||||
|
|
||||||
with open(filepath, 'r', encoding='utf-8') as f:
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
|
|
@ -50,6 +54,7 @@ class DiplomacyAgent:
|
||||||
client: BaseModelClient,
|
client: BaseModelClient,
|
||||||
initial_goals: Optional[List[str]] = None,
|
initial_goals: Optional[List[str]] = None,
|
||||||
initial_relationships: Optional[Dict[str, str]] = None,
|
initial_relationships: Optional[Dict[str, str]] = None,
|
||||||
|
prompts_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the DiplomacyAgent.
|
Initializes the DiplomacyAgent.
|
||||||
|
|
@ -60,12 +65,14 @@ class DiplomacyAgent:
|
||||||
initial_goals: An optional list of initial strategic goals.
|
initial_goals: An optional list of initial strategic goals.
|
||||||
initial_relationships: An optional dictionary mapping other power names to
|
initial_relationships: An optional dictionary mapping other power names to
|
||||||
relationship statuses (e.g., 'ALLY', 'ENEMY', 'NEUTRAL').
|
relationship statuses (e.g., 'ALLY', 'ENEMY', 'NEUTRAL').
|
||||||
|
prompts_dir: Optional path to the prompts directory.
|
||||||
"""
|
"""
|
||||||
if power_name not in ALL_POWERS:
|
if power_name not in ALL_POWERS:
|
||||||
raise ValueError(f"Invalid power name: {power_name}. Must be one of {ALL_POWERS}")
|
raise ValueError(f"Invalid power name: {power_name}. Must be one of {ALL_POWERS}")
|
||||||
|
|
||||||
self.power_name: str = power_name
|
self.power_name: str = power_name
|
||||||
self.client: BaseModelClient = client
|
self.client: BaseModelClient = client
|
||||||
|
self.prompts_dir: Optional[str] = prompts_dir
|
||||||
# Initialize goals as empty list, will be populated by initialize_agent_state
|
# Initialize goals as empty list, will be populated by initialize_agent_state
|
||||||
self.goals: List[str] = initial_goals if initial_goals is not None else []
|
self.goals: List[str] = initial_goals if initial_goals is not None else []
|
||||||
# Initialize relationships to Neutral if not provided
|
# Initialize relationships to Neutral if not provided
|
||||||
|
|
@ -85,16 +92,21 @@ class DiplomacyAgent:
|
||||||
# Get the directory containing the current file (agent.py)
|
# Get the directory containing the current file (agent.py)
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
# Construct path relative to the current file's directory
|
# Construct path relative to the current file's directory
|
||||||
prompts_dir = os.path.join(current_dir, "prompts")
|
default_prompts_path = os.path.join(current_dir, "prompts")
|
||||||
power_prompt_filename = os.path.join(prompts_dir, f"{power_name.lower()}_system_prompt.txt")
|
power_prompt_filename = f"{power_name.lower()}_system_prompt.txt"
|
||||||
default_prompt_filename = os.path.join(prompts_dir, "system_prompt.txt")
|
default_prompt_filename = "system_prompt.txt"
|
||||||
|
|
||||||
system_prompt_content = load_prompt(power_prompt_filename)
|
# Use the provided prompts_dir if available, otherwise use the default
|
||||||
|
prompts_path_to_use = self.prompts_dir if self.prompts_dir else default_prompts_path
|
||||||
|
|
||||||
|
power_prompt_filepath = os.path.join(prompts_path_to_use, power_prompt_filename)
|
||||||
|
default_prompt_filepath = os.path.join(prompts_path_to_use, default_prompt_filename)
|
||||||
|
|
||||||
|
system_prompt_content = load_prompt(power_prompt_filepath, prompts_dir=self.prompts_dir)
|
||||||
|
|
||||||
if not system_prompt_content:
|
if not system_prompt_content:
|
||||||
logger.warning(f"Power-specific prompt '{power_prompt_filename}' not found or empty. Loading default system prompt.")
|
logger.warning(f"Power-specific prompt '{power_prompt_filepath}' not found or empty. Loading default system prompt.")
|
||||||
# system_prompt_content = load_prompt("system_prompt.txt")
|
system_prompt_content = load_prompt(default_prompt_filepath, prompts_dir=self.prompts_dir)
|
||||||
system_prompt_content = load_prompt(default_prompt_filename)
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"Loaded power-specific system prompt for {power_name}.")
|
logger.info(f"Loaded power-specific system prompt for {power_name}.")
|
||||||
# ----------------------------------------------------
|
# ----------------------------------------------------
|
||||||
|
|
@ -416,7 +428,7 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load the template file but safely preprocess it first
|
# Load the template file but safely preprocess it first
|
||||||
prompt_template_content = _load_prompt_file('negotiation_diary_prompt.txt')
|
prompt_template_content = _load_prompt_file('negotiation_diary_prompt.txt', prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template_content:
|
if not prompt_template_content:
|
||||||
logger.error(f"[{self.power_name}] Could not load negotiation_diary_prompt.txt. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load negotiation_diary_prompt.txt. Skipping diary entry.")
|
||||||
success_status = "Failure: Prompt file not loaded"
|
success_status = "Failure: Prompt file not loaded"
|
||||||
|
|
@ -611,7 +623,7 @@ class DiplomacyAgent:
|
||||||
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
|
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
|
||||||
|
|
||||||
# Load the template but we'll use it carefully with string interpolation
|
# Load the template but we'll use it carefully with string interpolation
|
||||||
prompt_template = _load_prompt_file('order_diary_prompt.txt')
|
prompt_template = _load_prompt_file('order_diary_prompt.txt', prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{self.power_name}] Could not load order_diary_prompt.txt. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load order_diary_prompt.txt. Skipping diary entry.")
|
||||||
return
|
return
|
||||||
|
|
@ -756,7 +768,7 @@ class DiplomacyAgent:
|
||||||
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
|
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
|
||||||
|
|
||||||
# Load the template
|
# Load the template
|
||||||
prompt_template = _load_prompt_file('phase_result_diary_prompt.txt')
|
prompt_template = _load_prompt_file('phase_result_diary_prompt.txt', prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
|
||||||
return
|
return
|
||||||
|
|
@ -859,7 +871,7 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Construct the prompt using the dedicated state update prompt file
|
# 1. Construct the prompt using the dedicated state update prompt file
|
||||||
prompt_template = _load_prompt_file('state_update_prompt.txt')
|
prompt_template = _load_prompt_file('state_update_prompt.txt', prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{power_name}] Could not load state_update_prompt.txt. Skipping state update.")
|
logger.error(f"[{power_name}] Could not load state_update_prompt.txt. Skipping state update.")
|
||||||
return
|
return
|
||||||
|
|
@ -893,6 +905,7 @@ class DiplomacyAgent:
|
||||||
agent_goals=self.goals,
|
agent_goals=self.goals,
|
||||||
agent_relationships=self.relationships,
|
agent_relationships=self.relationships,
|
||||||
agent_private_diary=formatted_diary, # Pass formatted diary
|
agent_private_diary=formatted_diary, # Pass formatted diary
|
||||||
|
prompts_dir=self.prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add previous phase summary to the information provided to the LLM
|
# Add previous phase summary to the information provided to the LLM
|
||||||
|
|
|
||||||
|
|
@ -44,10 +44,11 @@ class BaseModelClient:
|
||||||
- get_conversation_reply(power_name, conversation_so_far, game_phase) -> str
|
- get_conversation_reply(power_name, conversation_so_far, game_phase) -> str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
self.prompts_dir = prompts_dir
|
||||||
# Load a default initially, can be overwritten by set_system_prompt
|
# Load a default initially, can be overwritten by set_system_prompt
|
||||||
self.system_prompt = load_prompt("system_prompt.txt")
|
self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir)
|
||||||
self.max_tokens = 16000 # default unless overridden
|
self.max_tokens = 16000 # default unless overridden
|
||||||
|
|
||||||
def set_system_prompt(self, content: str):
|
def set_system_prompt(self, content: str):
|
||||||
|
|
@ -97,6 +98,7 @@ class BaseModelClient:
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary_str=agent_private_diary_str,
|
agent_private_diary_str=agent_private_diary_str,
|
||||||
|
prompts_dir=self.prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_response = ""
|
raw_response = ""
|
||||||
|
|
@ -423,7 +425,7 @@ class BaseModelClient:
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
instructions = load_prompt("planning_instructions.txt")
|
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
||||||
|
|
||||||
context = self.build_context_prompt(
|
context = self.build_context_prompt(
|
||||||
game,
|
game,
|
||||||
|
|
@ -434,6 +436,7 @@ class BaseModelClient:
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||||
|
prompts_dir=self.prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
return context + "\n\n" + instructions
|
return context + "\n\n" + instructions
|
||||||
|
|
@ -451,7 +454,7 @@ class BaseModelClient:
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> str:
|
) -> str:
|
||||||
instructions = load_prompt("conversation_instructions.txt")
|
instructions = load_prompt("conversation_instructions.txt", prompts_dir=self.prompts_dir)
|
||||||
|
|
||||||
context = build_context_prompt(
|
context = build_context_prompt(
|
||||||
game,
|
game,
|
||||||
|
|
@ -462,6 +465,7 @@ class BaseModelClient:
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||||
|
prompts_dir=self.prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get recent messages targeting this power to prioritize responses
|
# Get recent messages targeting this power to prioritize responses
|
||||||
|
|
@ -699,7 +703,7 @@ class BaseModelClient:
|
||||||
"""
|
"""
|
||||||
logger.info(f"Client generating strategic plan for {power_name}...")
|
logger.info(f"Client generating strategic plan for {power_name}...")
|
||||||
|
|
||||||
planning_instructions = load_prompt("planning_instructions.txt")
|
planning_instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
||||||
if not planning_instructions:
|
if not planning_instructions:
|
||||||
logger.error("Could not load planning_instructions.txt! Cannot generate plan.")
|
logger.error("Could not load planning_instructions.txt! Cannot generate plan.")
|
||||||
return "Error: Planning instructions not found."
|
return "Error: Planning instructions not found."
|
||||||
|
|
@ -718,6 +722,7 @@ class BaseModelClient:
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||||
|
prompts_dir=self.prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
full_prompt = f"{context_prompt}\n\n{planning_instructions}"
|
full_prompt = f"{context_prompt}\n\n{planning_instructions}"
|
||||||
|
|
@ -772,8 +777,8 @@ class OpenAIClient(BaseModelClient):
|
||||||
For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
|
For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
super().__init__(model_name)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||||
|
|
||||||
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
||||||
|
|
@ -819,8 +824,8 @@ class ClaudeClient(BaseModelClient):
|
||||||
For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc.
|
For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
super().__init__(model_name)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
self.client = AsyncAnthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
|
self.client = AsyncAnthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
|
||||||
|
|
||||||
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
||||||
|
|
@ -861,8 +866,8 @@ class GeminiClient(BaseModelClient):
|
||||||
For 'gemini-1.5-flash' or other Google Generative AI models.
|
For 'gemini-1.5-flash' or other Google Generative AI models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
super().__init__(model_name)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
# Configure and get the model (corrected initialization)
|
# Configure and get the model (corrected initialization)
|
||||||
api_key = os.environ.get("GEMINI_API_KEY")
|
api_key = os.environ.get("GEMINI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
|
|
@ -905,8 +910,8 @@ class DeepSeekClient(BaseModelClient):
|
||||||
For DeepSeek R1 'deepseek-reasoner'
|
For DeepSeek R1 'deepseek-reasoner'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
super().__init__(model_name)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
|
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
|
||||||
self.client = AsyncDeepSeekOpenAI(
|
self.client = AsyncDeepSeekOpenAI(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
|
|
@ -961,8 +966,8 @@ class OpenAIResponsesClient(BaseModelClient):
|
||||||
This client makes direct HTTP requests to the v1/responses endpoint.
|
This client makes direct HTTP requests to the v1/responses endpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
super().__init__(model_name)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
self.api_key = os.environ.get("OPENAI_API_KEY")
|
self.api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("OPENAI_API_KEY environment variable is required")
|
raise ValueError("OPENAI_API_KEY environment variable is required")
|
||||||
|
|
@ -1068,14 +1073,14 @@ class OpenRouterClient(BaseModelClient):
|
||||||
For OpenRouter models, with default being 'openrouter/quasar-alpha'
|
For OpenRouter models, with default being 'openrouter/quasar-alpha'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str = "openrouter/quasar-alpha"):
|
def __init__(self, model_name: str = "openrouter/quasar-alpha", prompts_dir: Optional[str] = None):
|
||||||
# Allow specifying just the model identifier or the full path
|
# Allow specifying just the model identifier or the full path
|
||||||
if not model_name.startswith("openrouter/") and "/" not in model_name:
|
if not model_name.startswith("openrouter/") and "/" not in model_name:
|
||||||
model_name = f"openrouter/{model_name}"
|
model_name = f"openrouter/{model_name}"
|
||||||
if model_name.startswith("openrouter-"):
|
if model_name.startswith("openrouter-"):
|
||||||
model_name = model_name.replace("openrouter-", "")
|
model_name = model_name.replace("openrouter-", "")
|
||||||
|
|
||||||
super().__init__(model_name)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
self.api_key = os.environ.get("OPENROUTER_API_KEY")
|
self.api_key = os.environ.get("OPENROUTER_API_KEY")
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||||
|
|
@ -1146,8 +1151,8 @@ class TogetherAIClient(BaseModelClient):
|
||||||
Model names should be passed without the 'together-' prefix.
|
Model names should be passed without the 'together-' prefix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
super().__init__(model_name) # model_name here is the actual Together AI model identifier
|
super().__init__(model_name, prompts_dir=prompts_dir) # model_name here is the actual Together AI model identifier
|
||||||
self.api_key = os.environ.get("TOGETHER_API_KEY")
|
self.api_key = os.environ.get("TOGETHER_API_KEY")
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("TOGETHER_API_KEY environment variable is required for TogetherAIClient")
|
raise ValueError("TOGETHER_API_KEY environment variable is required for TogetherAIClient")
|
||||||
|
|
@ -1198,12 +1203,13 @@ class TogetherAIClient(BaseModelClient):
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
||||||
|
|
||||||
def load_model_client(model_id: str) -> BaseModelClient:
|
def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient:
|
||||||
"""
|
"""
|
||||||
Returns the appropriate LLM client for a given model_id string.
|
Returns the appropriate LLM client for a given model_id string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id: The model identifier
|
model_id: The model identifier
|
||||||
|
prompts_dir: Optional path to the prompts directory.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
client = load_model_client("claude-3-5-sonnet-20241022")
|
client = load_model_client("claude-3-5-sonnet-20241022")
|
||||||
|
|
@ -1213,23 +1219,23 @@ def load_model_client(model_id: str) -> BaseModelClient:
|
||||||
|
|
||||||
# Check for o3-pro model specifically - it needs the Responses API
|
# Check for o3-pro model specifically - it needs the Responses API
|
||||||
if lower_id == "o3-pro":
|
if lower_id == "o3-pro":
|
||||||
return OpenAIResponsesClient(model_id)
|
return OpenAIResponsesClient(model_id, prompts_dir=prompts_dir)
|
||||||
# Check for OpenRouter first to handle prefixed models like openrouter-deepseek
|
# Check for OpenRouter first to handle prefixed models like openrouter-deepseek
|
||||||
elif model_id.startswith("together-"):
|
elif model_id.startswith("together-"):
|
||||||
actual_model_name = model_id.split("together-", 1)[1]
|
actual_model_name = model_id.split("together-", 1)[1]
|
||||||
logger.info(f"Loading TogetherAI client for model: {actual_model_name} (original ID: {model_id})")
|
logger.info(f"Loading TogetherAI client for model: {actual_model_name} (original ID: {model_id})")
|
||||||
return TogetherAIClient(actual_model_name)
|
return TogetherAIClient(actual_model_name, prompts_dir=prompts_dir)
|
||||||
elif "openrouter" in model_id.lower() or "/" in model_id: # More general check for OpenRouterClient(model_id)
|
elif "openrouter" in model_id.lower() or "/" in model_id: # More general check for OpenRouterClient(model_id)
|
||||||
return OpenRouterClient(model_id)
|
return OpenRouterClient(model_id, prompts_dir=prompts_dir)
|
||||||
elif "claude" in lower_id:
|
elif "claude" in lower_id:
|
||||||
return ClaudeClient(model_id)
|
return ClaudeClient(model_id, prompts_dir=prompts_dir)
|
||||||
elif "gemini" in lower_id:
|
elif "gemini" in lower_id:
|
||||||
return GeminiClient(model_id)
|
return GeminiClient(model_id, prompts_dir=prompts_dir)
|
||||||
elif "deepseek" in lower_id:
|
elif "deepseek" in lower_id:
|
||||||
return DeepSeekClient(model_id)
|
return DeepSeekClient(model_id, prompts_dir=prompts_dir)
|
||||||
else:
|
else:
|
||||||
# Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
|
# Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
|
||||||
return OpenAIClient(model_id)
|
return OpenAIClient(model_id, prompts_dir=prompts_dir)
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
@ -1249,4 +1255,4 @@ def get_visible_messages_for_power(conversation_messages, power_name):
|
||||||
or msg["recipient"] == power_name
|
or msg["recipient"] == power_name
|
||||||
):
|
):
|
||||||
visible.append(msg)
|
visible.append(msg)
|
||||||
return visible # already in chronological order if appended that way
|
return visible # already in chronological order if appended that way
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# ai_diplomacy/diary_logic.py
|
# ai_diplomacy/diary_logic.py
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from .utils import run_llm_and_log, log_llm_response
|
from .utils import run_llm_and_log, log_llm_response
|
||||||
|
|
||||||
|
|
@ -11,13 +11,17 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def _load_prompt_file(filename: str) -> str | None:
|
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||||
import os
|
import os
|
||||||
try:
|
try:
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
if prompts_dir:
|
||||||
prompts_dir = os.path.join(current_dir, 'prompts')
|
filepath = os.path.join(prompts_dir, filename)
|
||||||
filepath = os.path.join(prompts_dir, filename)
|
else:
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
||||||
|
filepath = os.path.join(default_prompts_dir, filename)
|
||||||
|
|
||||||
with open(filepath, 'r', encoding='utf-8') as f:
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -29,6 +33,7 @@ async def run_diary_consolidation(
|
||||||
game: "Game",
|
game: "Game",
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
entries_to_keep_unsummarized: int = 15,
|
entries_to_keep_unsummarized: int = 15,
|
||||||
|
prompts_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Consolidate older diary entries while keeping recent ones.
|
Consolidate older diary entries while keeping recent ones.
|
||||||
|
|
@ -93,7 +98,7 @@ async def run_diary_consolidation(
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt")
|
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
||||||
|
|
|
||||||
|
|
@ -31,16 +31,17 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
|
||||||
"private_diary": agent.private_diary,
|
"private_diary": agent.private_diary,
|
||||||
}
|
}
|
||||||
|
|
||||||
def deserialize_agent(agent_data: dict) -> DiplomacyAgent:
|
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None) -> DiplomacyAgent:
|
||||||
"""Recreates an agent object from a dictionary."""
|
"""Recreates an agent object from a dictionary."""
|
||||||
client = load_model_client(agent_data["model_id"])
|
client = load_model_client(agent_data["model_id"], prompts_dir=prompts_dir)
|
||||||
client.max_tokens = agent_data.get("max_tokens", 16000) # Default for older saves
|
client.max_tokens = agent_data.get("max_tokens", 16000) # Default for older saves
|
||||||
|
|
||||||
agent = DiplomacyAgent(
|
agent = DiplomacyAgent(
|
||||||
power_name=agent_data["power_name"],
|
power_name=agent_data["power_name"],
|
||||||
client=client,
|
client=client,
|
||||||
initial_goals=agent_data.get("goals", []),
|
initial_goals=agent_data.get("goals", []),
|
||||||
initial_relationships=agent_data.get("relationships", None)
|
initial_relationships=agent_data.get("relationships", None),
|
||||||
|
prompts_dir=prompts_dir
|
||||||
)
|
)
|
||||||
# Restore the diary.
|
# Restore the diary.
|
||||||
agent.full_private_diary = agent_data.get("full_private_diary", [])
|
agent.full_private_diary = agent_data.get("full_private_diary", [])
|
||||||
|
|
@ -179,7 +180,7 @@ def load_game_state(
|
||||||
run_dir: str,
|
run_dir: str,
|
||||||
game_file_name: str,
|
game_file_name: str,
|
||||||
resume_from_phase: Optional[str] = None
|
resume_from_phase: Optional[str] = None
|
||||||
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Namespace]:
|
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Optional[Namespace]]:
|
||||||
"""Loads and reconstructs the game state from a saved game file."""
|
"""Loads and reconstructs the game state from a saved game file."""
|
||||||
game_file_path = os.path.join(run_dir, game_file_name)
|
game_file_path = os.path.join(run_dir, game_file_name)
|
||||||
if not os.path.exists(game_file_path):
|
if not os.path.exists(game_file_path):
|
||||||
|
|
@ -189,6 +190,14 @@ def load_game_state(
|
||||||
with open(game_file_path, 'r') as f:
|
with open(game_file_path, 'r') as f:
|
||||||
saved_game_data = json.load(f)
|
saved_game_data = json.load(f)
|
||||||
|
|
||||||
|
# Find the latest config saved in the file
|
||||||
|
run_config = None
|
||||||
|
if saved_game_data.get("phases"):
|
||||||
|
for phase in reversed(saved_game_data["phases"]):
|
||||||
|
if "config" in phase:
|
||||||
|
run_config = Namespace(**phase["config"])
|
||||||
|
logger.info(f"Loaded run configuration from phase {phase['name']}.")
|
||||||
|
break
|
||||||
|
|
||||||
# If resuming, find the specified phase and truncate the data after it
|
# If resuming, find the specified phase and truncate the data after it
|
||||||
if resume_from_phase:
|
if resume_from_phase:
|
||||||
|
|
@ -231,8 +240,6 @@ def load_game_state(
|
||||||
logger.info("No previous phases found. Initializing fresh agents and history.")
|
logger.info("No previous phases found. Initializing fresh agents and history.")
|
||||||
agents = {} # Will be created by the main loop
|
agents = {} # Will be created by the main loop
|
||||||
game_history = GameHistory()
|
game_history = GameHistory()
|
||||||
# No config to load, will use current run's args
|
|
||||||
run_config = None
|
|
||||||
else:
|
else:
|
||||||
# We save the game state up to & including the current (uncompleted) phase.
|
# We save the game state up to & including the current (uncompleted) phase.
|
||||||
# So we need to grab the agent state from the previous (completed) phase.
|
# So we need to grab the agent state from the previous (completed) phase.
|
||||||
|
|
@ -245,8 +252,9 @@ def load_game_state(
|
||||||
agents = {}
|
agents = {}
|
||||||
if 'state_agents' in last_phase_data:
|
if 'state_agents' in last_phase_data:
|
||||||
logger.info("Rebuilding agents from saved state...")
|
logger.info("Rebuilding agents from saved state...")
|
||||||
|
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, 'prompts_dir') else None
|
||||||
for power_name, agent_data in last_phase_data['state_agents'].items():
|
for power_name, agent_data in last_phase_data['state_agents'].items():
|
||||||
agents[power_name] = deserialize_agent(agent_data)
|
agents[power_name] = deserialize_agent(agent_data, prompts_dir=prompts_dir_from_config)
|
||||||
logger.info(f"Rebuilt {len(agents)} agents.")
|
logger.info(f"Rebuilt {len(agents)} agents.")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
||||||
|
|
@ -267,7 +275,7 @@ def load_game_state(
|
||||||
logger.info("Game history rebuilt.")
|
logger.info("Game history rebuilt.")
|
||||||
|
|
||||||
|
|
||||||
return game, agents, game_history
|
return game, agents, game_history, run_config
|
||||||
|
|
||||||
|
|
||||||
async def initialize_new_game(
|
async def initialize_new_game(
|
||||||
|
|
@ -308,12 +316,12 @@ async def initialize_new_game(
|
||||||
for power_name, model_id in game.power_model_map.items():
|
for power_name, model_id in game.power_model_map.items():
|
||||||
if not game.powers[power_name].is_eliminated():
|
if not game.powers[power_name].is_eliminated():
|
||||||
try:
|
try:
|
||||||
client = load_model_client(model_id)
|
client = load_model_client(model_id, prompts_dir=args.prompts_dir)
|
||||||
client.max_tokens = model_max_tokens[power_name]
|
client.max_tokens = model_max_tokens[power_name]
|
||||||
agent = DiplomacyAgent(power_name=power_name, client=client)
|
agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir)
|
||||||
agents[power_name] = agent
|
agents[power_name] = agent
|
||||||
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
||||||
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path))
|
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
|
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# ai_diplomacy/initialization.py
|
# ai_diplomacy/initialization.py
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
# Forward declaration for type hinting, actual imports in function if complex
|
# Forward declaration for type hinting, actual imports in function if complex
|
||||||
if False: # TYPE_CHECKING
|
if False: # TYPE_CHECKING
|
||||||
|
|
@ -18,7 +19,8 @@ async def initialize_agent_state_ext(
|
||||||
agent: 'DiplomacyAgent',
|
agent: 'DiplomacyAgent',
|
||||||
game: 'Game',
|
game: 'Game',
|
||||||
game_history: 'GameHistory',
|
game_history: 'GameHistory',
|
||||||
log_file_path: str
|
log_file_path: str,
|
||||||
|
prompts_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Uses the LLM to set initial goals and relationships for the agent."""
|
"""Uses the LLM to set initial goals and relationships for the agent."""
|
||||||
power_name = agent.power_name
|
power_name = agent.power_name
|
||||||
|
|
@ -56,7 +58,8 @@ async def initialize_agent_state_ext(
|
||||||
game_history=game_history,
|
game_history=game_history,
|
||||||
agent_goals=None,
|
agent_goals=None,
|
||||||
agent_relationships=None,
|
agent_relationships=None,
|
||||||
agent_private_diary=formatted_diary,
|
agent_private_diary=formatted_diary,
|
||||||
|
prompts_dir=prompts_dir,
|
||||||
)
|
)
|
||||||
full_prompt = initial_prompt + "\n\n" + context
|
full_prompt = initial_prompt + "\n\n" + context
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ def build_context_prompt(
|
||||||
agent_goals: Optional[List[str]] = None,
|
agent_goals: Optional[List[str]] = None,
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary: Optional[str] = None,
|
agent_private_diary: Optional[str] = None,
|
||||||
|
prompts_dir: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Builds the detailed context part of the prompt.
|
"""Builds the detailed context part of the prompt.
|
||||||
|
|
||||||
|
|
@ -35,11 +36,12 @@ def build_context_prompt(
|
||||||
agent_goals: Optional list of agent's goals.
|
agent_goals: Optional list of agent's goals.
|
||||||
agent_relationships: Optional dictionary of agent's relationships with other powers.
|
agent_relationships: Optional dictionary of agent's relationships with other powers.
|
||||||
agent_private_diary: Optional string of agent's private diary.
|
agent_private_diary: Optional string of agent's private diary.
|
||||||
|
prompts_dir: Optional path to the prompts directory.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A string containing the formatted context.
|
A string containing the formatted context.
|
||||||
"""
|
"""
|
||||||
context_template = load_prompt("context_prompt.txt")
|
context_template = load_prompt("context_prompt.txt", prompts_dir=prompts_dir)
|
||||||
|
|
||||||
# === Agent State Debug Logging ===
|
# === Agent State Debug Logging ===
|
||||||
if agent_goals:
|
if agent_goals:
|
||||||
|
|
@ -112,6 +114,7 @@ def construct_order_generation_prompt(
|
||||||
agent_goals: Optional[List[str]] = None,
|
agent_goals: Optional[List[str]] = None,
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None,
|
agent_private_diary_str: Optional[str] = None,
|
||||||
|
prompts_dir: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Constructs the final prompt for order generation.
|
"""Constructs the final prompt for order generation.
|
||||||
|
|
||||||
|
|
@ -125,13 +128,14 @@ def construct_order_generation_prompt(
|
||||||
agent_goals: Optional list of agent's goals.
|
agent_goals: Optional list of agent's goals.
|
||||||
agent_relationships: Optional dictionary of agent's relationships with other powers.
|
agent_relationships: Optional dictionary of agent's relationships with other powers.
|
||||||
agent_private_diary_str: Optional string of agent's private diary.
|
agent_private_diary_str: Optional string of agent's private diary.
|
||||||
|
prompts_dir: Optional path to the prompts directory.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A string containing the complete prompt for the LLM.
|
A string containing the complete prompt for the LLM.
|
||||||
"""
|
"""
|
||||||
# Load prompts
|
# Load prompts
|
||||||
_ = load_prompt("few_shot_example.txt") # Loaded but not used, as per original logic
|
_ = load_prompt("few_shot_example.txt", prompts_dir=prompts_dir) # Loaded but not used, as per original logic
|
||||||
instructions = load_prompt("order_instructions.txt")
|
instructions = load_prompt("order_instructions.txt", prompts_dir=prompts_dir)
|
||||||
|
|
||||||
# Build the context prompt
|
# Build the context prompt
|
||||||
context = build_context_prompt(
|
context = build_context_prompt(
|
||||||
|
|
@ -143,7 +147,8 @@ def construct_order_generation_prompt(
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary=agent_private_diary_str,
|
agent_private_diary=agent_private_diary_str,
|
||||||
|
prompts_dir=prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
final_prompt = system_prompt + "\n\n" + context + "\n\n" + instructions
|
final_prompt = system_prompt + "\n\n" + context + "\n\n" + instructions
|
||||||
return final_prompt
|
return final_prompt
|
||||||
|
|
@ -293,11 +293,14 @@ def normalize_and_compare_orders(
|
||||||
|
|
||||||
|
|
||||||
# Helper to load prompt text from file relative to the expected 'prompts' dir
|
# Helper to load prompt text from file relative to the expected 'prompts' dir
|
||||||
def load_prompt(filename: str) -> str:
|
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
||||||
"""Helper to load prompt text from file"""
|
"""Helper to load prompt text from file"""
|
||||||
# Assuming execution from the root or that the path resolves correctly
|
if prompts_dir:
|
||||||
# Consider using absolute paths or pkg_resources if needed for robustness
|
prompt_path = os.path.join(prompts_dir, filename)
|
||||||
prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', filename)
|
else:
|
||||||
|
# Default behavior: relative to this file's location in the 'prompts' subdir
|
||||||
|
prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', filename)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(prompt_path, "r", encoding='utf-8') as f: # Added encoding
|
with open(prompt_path, "r", encoding='utf-8') as f: # Added encoding
|
||||||
return f.read().strip()
|
return f.read().strip()
|
||||||
|
|
|
||||||
474
experiment_runner.py
Normal file
474
experiment_runner.py
Normal file
|
|
@ -0,0 +1,474 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Experiment orchestration for Diplomacy self-play.
|
||||||
|
Launches many `lm_game` runs in parallel, captures their artefacts,
|
||||||
|
and executes a pluggable post-analysis pipeline.
|
||||||
|
|
||||||
|
Run `python experiment_runner.py --help` for CLI details.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import collections
|
||||||
|
import concurrent.futures
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import textwrap
|
||||||
|
import time
|
||||||
|
import multiprocessing as mp
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Iterable, List
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Logging #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
LOG_FMT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"
|
||||||
|
logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%H:%M:%S")
|
||||||
|
log = logging.getLogger("experiment_runner")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Flag definitions – full, un-shortened help strings #
|
||||||
|
# ────────────────────────────────────────────────────────────────────────────
|
||||||
|
def _add_experiment_flags(p: argparse.ArgumentParser) -> None:
|
||||||
|
p.add_argument(
|
||||||
|
"--experiment_dir",
|
||||||
|
type=Path,
|
||||||
|
required=True,
|
||||||
|
help=(
|
||||||
|
"Directory that will hold all experiment artefacts. "
|
||||||
|
"A 'runs/' sub-folder is created for individual game runs and an "
|
||||||
|
"'analysis/' folder for aggregated outputs. Must be writable."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--iterations",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help=(
|
||||||
|
"Number of lm_game instances to launch for this experiment. "
|
||||||
|
"Each instance gets its own sub-directory under runs/."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--parallel",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help=(
|
||||||
|
"Maximum number of game instances to run concurrently. "
|
||||||
|
"Uses a ProcessPoolExecutor under the hood."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--analysis_modules",
|
||||||
|
type=str,
|
||||||
|
default="summary",
|
||||||
|
help=(
|
||||||
|
"Comma-separated list of analysis module names to execute after all "
|
||||||
|
"runs finish. Modules are imported from "
|
||||||
|
"'experiment_runner.analysis.<name>' and must expose "
|
||||||
|
"run(experiment_dir: Path, ctx: dict)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--critical_state_base_run",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Path to an *existing* run directory produced by a previous lm_game "
|
||||||
|
"execution. When supplied, every iteration resumes from that "
|
||||||
|
"snapshot using lm_game's --critical_state_analysis_dir mechanism."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--seed_base",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help=(
|
||||||
|
"Base RNG seed. Run i will receive seed = seed_base + i. "
|
||||||
|
"Forwarded to lm_game via its --seed flag (you must have added that "
|
||||||
|
"flag to lm_game for deterministic behaviour)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_lm_game_flags(p: argparse.ArgumentParser) -> None:
|
||||||
|
# ---- all flags copied verbatim from lm_game.parse_arguments() ----
|
||||||
|
p.add_argument(
|
||||||
|
"--resume_from_phase",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help=(
|
||||||
|
"Phase to resume from (e.g., 'S1902M'). Requires --run_dir. "
|
||||||
|
"IMPORTANT: This option clears any existing phase results ahead of "
|
||||||
|
"& including the specified resume phase."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--end_at_phase",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Phase to end the simulation after (e.g., 'F1905M').",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--max_year",
|
||||||
|
type=int,
|
||||||
|
default=1910, # Increased default in lm_game
|
||||||
|
help="Maximum year to simulate. The game will stop once this year is reached.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--num_negotiation_rounds",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Number of negotiation rounds per phase.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--models",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help=(
|
||||||
|
"Comma-separated list of model names to assign to powers in order. "
|
||||||
|
"The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--planning_phase",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable the planning phase for each power to set strategic directives.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--max_tokens",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="Maximum number of new tokens to generate per LLM call (default: 16000).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--max_tokens_per_model",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help=(
|
||||||
|
"Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, "
|
||||||
|
"FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--prompts_dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Path to the directory containing prompt files. "
|
||||||
|
"Defaults to the packaged prompts directory."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ────────────────────────────────────────────────────────────────────────────
|
||||||
|
# One combined parser for banner printing #
|
||||||
|
# ────────────────────────────────────────────────────────────────────────────
|
||||||
|
def _build_full_parser() -> argparse.ArgumentParser:
|
||||||
|
fp = argparse.ArgumentParser(
|
||||||
|
prog="experiment_runner.py",
|
||||||
|
formatter_class=lambda prog: argparse.RawTextHelpFormatter(
|
||||||
|
prog, max_help_position=45
|
||||||
|
),
|
||||||
|
description=(
|
||||||
|
"Batch-runner for Diplomacy self-play experiments. "
|
||||||
|
"All lm_game flags are accepted here as-is; they are validated "
|
||||||
|
"before any game runs start."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
_add_experiment_flags(fp)
|
||||||
|
_add_lm_game_flags(fp)
|
||||||
|
return fp
|
||||||
|
|
||||||
|
|
||||||
|
# ────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Robust parsing that always shows *full* help on error #
|
||||||
|
# ────────────────────────────────────────────────────────────────────────────
|
||||||
|
def _parse_cli() -> tuple[argparse.Namespace, list[str], argparse.Namespace]:
|
||||||
|
full_parser = _build_full_parser()
|
||||||
|
|
||||||
|
# Show full banner when no args
|
||||||
|
if len(sys.argv) == 1:
|
||||||
|
full_parser.print_help(sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
|
# Show full banner on explicit help
|
||||||
|
if any(tok in ("-h", "--help") for tok in sys.argv[1:]):
|
||||||
|
full_parser.print_help(sys.stderr)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Sub-parsers for separating experiment vs game flags
|
||||||
|
class _ErrParser(argparse.ArgumentParser):
|
||||||
|
def error(self, msg):
|
||||||
|
full_parser.print_help(sys.stderr)
|
||||||
|
self.exit(2, f"{self.prog}: error: {msg}\n")
|
||||||
|
|
||||||
|
exp_parser = _ErrParser(add_help=False)
|
||||||
|
game_parser = _ErrParser(add_help=False)
|
||||||
|
_add_experiment_flags(exp_parser)
|
||||||
|
_add_lm_game_flags(game_parser)
|
||||||
|
|
||||||
|
# Split argv tokens by flag ownership
|
||||||
|
argv = sys.argv[1:]
|
||||||
|
exp_flag_set = {opt for a in exp_parser._actions for opt in a.option_strings}
|
||||||
|
|
||||||
|
exp_tok, game_tok, i = [], [], 0
|
||||||
|
while i < len(argv):
|
||||||
|
tok = argv[i]
|
||||||
|
if tok in exp_flag_set:
|
||||||
|
exp_tok.append(tok)
|
||||||
|
action = exp_parser._option_string_actions[tok]
|
||||||
|
needs_val = (
|
||||||
|
action.nargs is None # default: exactly one value
|
||||||
|
or (isinstance(action.nargs, int) and action.nargs > 0)
|
||||||
|
or action.nargs in ("+", "*", "?") # variable-length cases
|
||||||
|
)
|
||||||
|
if needs_val:
|
||||||
|
exp_tok.append(argv[i + 1])
|
||||||
|
i += 2
|
||||||
|
else: # store_true / store_false
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
game_tok.append(tok)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
exp_args = exp_parser.parse_args(exp_tok)
|
||||||
|
game_args = game_parser.parse_args(game_tok)
|
||||||
|
return exp_args, game_tok, game_args
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Helpers #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
_RunInfo = collections.namedtuple(
|
||||||
|
"_RunInfo", "index run_dir seed cmd_line returncode elapsed_s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_run_dir(exp_dir: Path, idx: int) -> Path:
|
||||||
|
run_dir = exp_dir / "runs" / f"run_{idx:05d}"
|
||||||
|
# Just ensure it exists; don't raise if it already does.
|
||||||
|
run_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
return run_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_seed(seed: int, run_dir: Path) -> None:
|
||||||
|
seed_file = run_dir / "seed.txt"
|
||||||
|
if not seed_file.exists():
|
||||||
|
seed_file.write_text(str(seed))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_cmd(
|
||||||
|
lm_game_script: Path,
|
||||||
|
base_cli: List[str],
|
||||||
|
run_dir: Path,
|
||||||
|
seed: int,
|
||||||
|
critical_base: Path | None,
|
||||||
|
resume_from_phase: str,
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Returns a list suitable for subprocess.run([...]).
|
||||||
|
"""
|
||||||
|
cmd = [sys.executable, str(lm_game_script)]
|
||||||
|
|
||||||
|
# Forward user CLI
|
||||||
|
cmd.extend(base_cli)
|
||||||
|
|
||||||
|
# Per-run mandatory overrides
|
||||||
|
cmd.extend(["--run_dir", str(run_dir)])
|
||||||
|
cmd.extend(["--seed", str(seed)]) # you may need to add a --seed flag to lm_game
|
||||||
|
|
||||||
|
# Critical-state mode
|
||||||
|
if critical_base:
|
||||||
|
cmd.extend([
|
||||||
|
"--critical_state_analysis_dir", str(run_dir),
|
||||||
|
"--run_dir", str(critical_base) # base run dir (already completed)
|
||||||
|
])
|
||||||
|
if resume_from_phase:
|
||||||
|
cmd.extend(["--resume_from_phase", resume_from_phase])
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def _launch_one(args) -> _RunInfo:
|
||||||
|
"""
|
||||||
|
Worker executed inside a ProcessPool; runs one game via subprocess.
|
||||||
|
"""
|
||||||
|
(
|
||||||
|
idx,
|
||||||
|
lm_game_script,
|
||||||
|
base_cli,
|
||||||
|
run_dir,
|
||||||
|
seed,
|
||||||
|
critical_base,
|
||||||
|
resume_phase,
|
||||||
|
) = args
|
||||||
|
|
||||||
|
cmd = _build_cmd(
|
||||||
|
lm_game_script, base_cli, run_dir, seed, critical_base, resume_phase
|
||||||
|
)
|
||||||
|
start = time.perf_counter()
|
||||||
|
log.debug("Run %05d: CMD = %s", idx, " ".join(cmd))
|
||||||
|
|
||||||
|
# Write out full command for traceability
|
||||||
|
(run_dir / "command.txt").write_text(" ".join(cmd))
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
(run_dir / "console.log").write_text(result.stdout)
|
||||||
|
rc = result.returncode
|
||||||
|
except Exception as exc: # noqa: broad-except
|
||||||
|
(run_dir / "console.log").write_text(f"Exception launching run:\n{exc}\n")
|
||||||
|
rc = 1
|
||||||
|
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
return _RunInfo(idx, run_dir, seed, " ".join(cmd), rc, elapsed)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_analysis_fns(module_names: Iterable[str]):
|
||||||
|
"""
|
||||||
|
Dynamically import analysis modules.
|
||||||
|
Each module must expose `run(experiment_dir: Path, cfg: dict)`.
|
||||||
|
"""
|
||||||
|
for name in module_names:
|
||||||
|
mod_name = f"experiment_runner.analysis.{name.strip()}"
|
||||||
|
try:
|
||||||
|
mod = importlib.import_module(mod_name)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
log.warning("Analysis module %s not found (%s) – skipping", mod_name, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not hasattr(mod, "run"):
|
||||||
|
log.warning("%s has no `run()` function – skipping", mod_name)
|
||||||
|
continue
|
||||||
|
yield mod.run
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Main driver #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
def main() -> None:
|
||||||
|
exp_args, leftover_cli, game_args = _parse_cli()
|
||||||
|
|
||||||
|
exp_dir: Path = exp_args.experiment_dir.expanduser().resolve()
|
||||||
|
if exp_dir.exists():
|
||||||
|
log.info("Appending to existing experiment: %s", exp_dir)
|
||||||
|
exp_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Persist experiment-level config
|
||||||
|
cfg_path = exp_dir / "config.json"
|
||||||
|
if not cfg_path.exists(): # ← new guard
|
||||||
|
with cfg_path.open("w", encoding="utf-8") as fh:
|
||||||
|
json.dump(
|
||||||
|
{"experiment": vars(exp_args),
|
||||||
|
"lm_game": vars(game_args),
|
||||||
|
"forwarded_cli": leftover_cli},
|
||||||
|
fh, indent=2, default=str,
|
||||||
|
)
|
||||||
|
log.info("Config saved to %s", cfg_path)
|
||||||
|
else:
|
||||||
|
log.info("Config already exists – leaving unchanged")
|
||||||
|
|
||||||
|
|
||||||
|
log.info("Config saved to %s", cfg_path)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# Launch games #
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
lm_game_script = Path(__file__).parent / "lm_game.py"
|
||||||
|
if not lm_game_script.exists():
|
||||||
|
log.error("lm_game.py not found at %s – abort", lm_game_script)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
run_args = []
|
||||||
|
for i in range(exp_args.iterations):
|
||||||
|
run_dir = _mk_run_dir(exp_dir, i)
|
||||||
|
seed = exp_args.seed_base + i
|
||||||
|
_dump_seed(seed, run_dir)
|
||||||
|
|
||||||
|
run_args.append(
|
||||||
|
(
|
||||||
|
i, lm_game_script, leftover_cli, run_dir, seed,
|
||||||
|
exp_args.critical_state_base_run,
|
||||||
|
game_args.resume_from_phase,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"Launching %d runs (max %d parallel, critical_state=%s)",
|
||||||
|
exp_args.iterations,
|
||||||
|
exp_args.parallel,
|
||||||
|
bool(exp_args.critical_state_base_run),
|
||||||
|
)
|
||||||
|
|
||||||
|
runs_meta: list[_RunInfo] = []
|
||||||
|
with concurrent.futures.ProcessPoolExecutor(
|
||||||
|
max_workers=exp_args.parallel,
|
||||||
|
mp_context=mp.get_context("spawn"),
|
||||||
|
) as pool:
|
||||||
|
for res in pool.map(_launch_one, run_args):
|
||||||
|
runs_meta.append(res)
|
||||||
|
status = "OK" if res.returncode == 0 else f"RC={res.returncode}"
|
||||||
|
log.info(
|
||||||
|
"run_%05d finished in %.1fs %s", res.index, res.elapsed_s, status
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist per-run status summary
|
||||||
|
summary_path = exp_dir / "runs_summary.json"
|
||||||
|
with open(summary_path, "w", encoding="utf-8") as fh:
|
||||||
|
json.dump([res._asdict() for res in runs_meta], fh, indent=2, default=str)
|
||||||
|
log.info("Run summary written → %s", summary_path)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# Post-analysis pipeline #
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
mods = list(_load_analysis_fns(exp_args.analysis_modules.split(",")))
|
||||||
|
if not mods:
|
||||||
|
log.warning("No analysis modules loaded – done.")
|
||||||
|
return
|
||||||
|
|
||||||
|
analysis_root = exp_dir / "analysis"
|
||||||
|
if analysis_root.exists():
|
||||||
|
shutil.rmtree(analysis_root) # ← wipes old outputs
|
||||||
|
analysis_root.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Collect common context
|
||||||
|
ctx: dict = {
|
||||||
|
"exp_dir": str(exp_dir),
|
||||||
|
"runs": [str(r.run_dir) for r in runs_meta],
|
||||||
|
"critical_state_base": str(exp_args.critical_state_base_run or ""),
|
||||||
|
"resume_from_phase": game_args.resume_from_phase,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fn in mods:
|
||||||
|
name = fn.__module__.rsplit(".", 1)[-1]
|
||||||
|
log.info("Running analysis module: %s", name)
|
||||||
|
try:
|
||||||
|
fn(exp_dir, ctx)
|
||||||
|
log.info("✓ %s complete", name)
|
||||||
|
except Exception as exc: # noqa: broad-except
|
||||||
|
log.exception("Analysis module %s failed: %s", name, exc)
|
||||||
|
|
||||||
|
log.info("Experiment finished – artefacts in %s", exp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
experiment_runner/__init__.py
Normal file
0
experiment_runner/__init__.py
Normal file
0
experiment_runner/analysis/__init__.py
Normal file
0
experiment_runner/analysis/__init__.py
Normal file
60
experiment_runner/analysis/critical_state.py
Normal file
60
experiment_runner/analysis/critical_state.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""
|
||||||
|
Extracts the board state *before* and *after* a critical phase
|
||||||
|
for every critical-analysis run produced by experiment_runner.
|
||||||
|
|
||||||
|
Each run must have:
|
||||||
|
• lmvsgame.json – containing a phase named ctx["resume_from_phase"]
|
||||||
|
|
||||||
|
Outputs live in analysis/critical_state/<run_name>_{before|after}.json
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _phase_by_name(phases: List[dict], name: str) -> dict | None:
|
||||||
|
for ph in phases:
|
||||||
|
if ph["state"]["name"] == name:
|
||||||
|
return ph
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def run(exp_dir: Path, ctx: Dict) -> None:
|
||||||
|
resume_phase = ctx.get("resume_from_phase")
|
||||||
|
if not resume_phase:
|
||||||
|
log.info("critical_state: --resume_from_phase not supplied – skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
out_dir = exp_dir / "analysis" / "critical_state"
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
for run_dir in (exp_dir / "runs").iterdir():
|
||||||
|
game_json = run_dir / "lmvsgame.json"
|
||||||
|
if not game_json.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
with game_json.open("r") as fh:
|
||||||
|
game = json.load(fh)
|
||||||
|
phases = game.get("phases", [])
|
||||||
|
if not phases:
|
||||||
|
continue
|
||||||
|
|
||||||
|
before = _phase_by_name(phases, resume_phase)
|
||||||
|
after = phases[-1] if phases else None
|
||||||
|
if before is None or after is None:
|
||||||
|
log.warning("Run %s missing expected phases – skipped", run_dir.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
(out_dir / f"{run_dir.name}_before.json").write_text(
|
||||||
|
json.dumps(before, indent=2)
|
||||||
|
)
|
||||||
|
(out_dir / f"{run_dir.name}_after.json").write_text(
|
||||||
|
json.dumps(after, indent=2)
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info("critical_state: snapshots written to %s", out_dir)
|
||||||
177
experiment_runner/analysis/summary.py
Normal file
177
experiment_runner/analysis/summary.py
Normal file
|
|
@ -0,0 +1,177 @@
|
||||||
|
"""
|
||||||
|
Aggregates results across all runs and writes:
|
||||||
|
|
||||||
|
• analysis/aggregated_results.csv
|
||||||
|
• analysis/score_summary_by_power.csv
|
||||||
|
• analysis/results_summary.png (if matplotlib available)
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Helpers copied (and lightly cleaned) from run_games.py #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
def _extract_diplomacy_results(game_json: Path) -> Dict[str, Dict[str, Any]]:
|
||||||
|
with game_json.open("r", encoding="utf-8") as fh:
|
||||||
|
gd = json.load(fh)
|
||||||
|
|
||||||
|
phases = gd.get("phases", [])
|
||||||
|
if not phases:
|
||||||
|
raise ValueError("no phases")
|
||||||
|
|
||||||
|
first_state = phases[0]["state"]
|
||||||
|
last_state = phases[-1]["state"]
|
||||||
|
|
||||||
|
powers = (
|
||||||
|
list(first_state.get("homes", {}).keys())
|
||||||
|
or list(first_state.get("centers", {}).keys())
|
||||||
|
)
|
||||||
|
if not powers:
|
||||||
|
raise ValueError("cannot determine powers")
|
||||||
|
|
||||||
|
scs_to_win = 18
|
||||||
|
solo_winner = next(
|
||||||
|
(p for p, sc in last_state["centers"].items() if len(sc) >= scs_to_win), None
|
||||||
|
)
|
||||||
|
|
||||||
|
results: Dict[str, Dict[str, Any]] = {}
|
||||||
|
for p in powers:
|
||||||
|
sc_count = len(last_state["centers"].get(p, []))
|
||||||
|
units = len(last_state["units"].get(p, []))
|
||||||
|
|
||||||
|
if solo_winner:
|
||||||
|
if p == solo_winner:
|
||||||
|
outcome, cat = f"Won solo", "Solo Win"
|
||||||
|
else:
|
||||||
|
outcome, cat = f"Lost to {solo_winner}", "Loss"
|
||||||
|
elif sc_count == 0 and units == 0:
|
||||||
|
outcome, cat = "Eliminated", "Eliminated"
|
||||||
|
else:
|
||||||
|
outcome, cat = "Ongoing/Draw", "Ongoing/Abandoned/Draw"
|
||||||
|
|
||||||
|
results[p] = {
|
||||||
|
"OutcomeCategory": cat,
|
||||||
|
"StatusDetail": outcome,
|
||||||
|
"SupplyCenters": sc_count,
|
||||||
|
"LastPhase": last_state["name"],
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# simplistic "Diplobench" scoring from previous discussion ------------------ #
|
||||||
|
def _year(name: str) -> int | None:
|
||||||
|
m = re.search(r"(\d{4})", name)
|
||||||
|
return int(m.group(1)) if m else None
|
||||||
|
|
||||||
|
|
||||||
|
def _score_game(game_json: Path) -> Dict[str, int]:
|
||||||
|
with open(game_json, "r") as fh:
|
||||||
|
game = json.load(fh)
|
||||||
|
phases = game["phases"]
|
||||||
|
if not phases:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
start = _year(phases[0]["state"]["name"])
|
||||||
|
end_year = _year(phases[-1]["state"]["name"]) or start
|
||||||
|
max_turns = (end_year - start + 1) if start is not None else len(phases)
|
||||||
|
|
||||||
|
last_state = phases[-1]["state"]
|
||||||
|
solo_winner = next(
|
||||||
|
(p for p, scs in last_state["centers"].items() if len(scs) >= 18), None
|
||||||
|
)
|
||||||
|
|
||||||
|
elim_turn: Dict[str, int | None] = {}
|
||||||
|
for p in last_state["centers"].keys():
|
||||||
|
e_turn = None
|
||||||
|
for idx, ph in enumerate(phases):
|
||||||
|
if not ph["state"]["centers"].get(p, []):
|
||||||
|
yr = _year(ph["state"]["name"]) or 0
|
||||||
|
e_turn = (yr - start + 1) if start is not None else idx + 1
|
||||||
|
break
|
||||||
|
elim_turn[p] = e_turn
|
||||||
|
|
||||||
|
scores: Dict[str, int] = {}
|
||||||
|
for p, scs in last_state["centers"].items():
|
||||||
|
if p == solo_winner:
|
||||||
|
win_turn = (end_year - start + 1) if start is not None else len(phases)
|
||||||
|
scores[p] = max_turns + 17 + (max_turns - win_turn)
|
||||||
|
elif solo_winner:
|
||||||
|
# losers in a solo game
|
||||||
|
yr_win = _year(phases[-1]["state"]["name"]) or end_year
|
||||||
|
turn_win = (yr_win - start + 1) if start is not None else len(phases)
|
||||||
|
scores[p] = turn_win
|
||||||
|
else:
|
||||||
|
if elim_turn[p] is None:
|
||||||
|
scores[p] = max_turns + len(scs)
|
||||||
|
else:
|
||||||
|
scores[p] = elim_turn[p]
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Public entry point #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
def run(exp_dir: Path, ctx: dict): # pylint: disable=unused-argument
|
||||||
|
analysis_dir = exp_dir / "analysis"
|
||||||
|
analysis_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
rows: List[Dict[str, Any]] = []
|
||||||
|
for run_dir in (exp_dir / "runs").iterdir():
|
||||||
|
game_json = run_dir / "lmvsgame.json"
|
||||||
|
if not game_json.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
gid = run_dir.name
|
||||||
|
try:
|
||||||
|
res = _extract_diplomacy_results(game_json)
|
||||||
|
except Exception as e: # noqa: broad-except
|
||||||
|
log.warning("Could not parse %s (%s)", game_json, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
scores = _score_game(game_json)
|
||||||
|
for pwr, info in res.items():
|
||||||
|
out = {"GameID": gid, "Power": pwr, **info, "Score": scores.get(pwr, None)}
|
||||||
|
rows.append(out)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
log.warning("summary: no parsable runs found")
|
||||||
|
return
|
||||||
|
|
||||||
|
df = pd.DataFrame(rows)
|
||||||
|
out_csv = analysis_dir / "aggregated_results.csv"
|
||||||
|
df.to_csv(out_csv, index=False)
|
||||||
|
|
||||||
|
summary = (
|
||||||
|
df.groupby("Power")["Score"]
|
||||||
|
.agg(["mean", "median", "count"])
|
||||||
|
.reset_index()
|
||||||
|
.rename(columns={"count": "n"})
|
||||||
|
)
|
||||||
|
summary.to_csv(analysis_dir / "score_summary_by_power.csv", index=False)
|
||||||
|
|
||||||
|
log.info("summary: wrote %s rows to %s", len(df), out_csv)
|
||||||
|
|
||||||
|
# Optional charts
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
sns.set_style("whitegrid")
|
||||||
|
plt.figure(figsize=(10, 7))
|
||||||
|
sns.boxplot(x="Power", y="SupplyCenters", data=df, palette="pastel")
|
||||||
|
plt.title("Supply-center distribution")
|
||||||
|
plt.savefig(analysis_dir / "results_summary.png", dpi=150)
|
||||||
|
plt.close()
|
||||||
|
log.info("summary: chart saved")
|
||||||
|
except Exception as e: # noqa: broad-except
|
||||||
|
log.debug("Chart generation skipped (%s)", e)
|
||||||
35
lm_game.py
35
lm_game.py
|
|
@ -55,6 +55,12 @@ def parse_arguments():
|
||||||
default="",
|
default="",
|
||||||
help="Directory for results. If it exists, the game resumes. If not, it's created. Defaults to a new timestamped directory.",
|
help="Directory for results. If it exists, the game resumes. If not, it's created. Defaults to a new timestamped directory.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", # alias for back compatibility
|
||||||
|
dest="run_dir", # write to the same variable as --run_dir
|
||||||
|
type=str,
|
||||||
|
help=argparse.SUPPRESS # hides it from `--help`
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--critical_state_analysis_dir",
|
"--critical_state_analysis_dir",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -105,12 +111,24 @@ def parse_arguments():
|
||||||
default=16000,
|
default=16000,
|
||||||
help="Maximum number of new tokens to generate per LLM call (default: 16000)."
|
help="Maximum number of new tokens to generate per LLM call (default: 16000)."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed_base",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="RNG seed placeholder for compatibility with experiment_runner. Currently unused."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_tokens_per_model",
|
"--max_tokens_per_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
help="Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens."
|
help="Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompts_dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the directory containing prompt files. Defaults to the packaged prompts directory."
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -179,8 +197,17 @@ async def main():
|
||||||
if is_resuming:
|
if is_resuming:
|
||||||
try:
|
try:
|
||||||
# When resuming, we load the state and also the config from the last saved phase.
|
# When resuming, we load the state and also the config from the last saved phase.
|
||||||
# We will continue to use THAT config, ignoring the current script's args.
|
game, agents, game_history, loaded_run_config = load_game_state(run_dir, game_file_name, args.resume_from_phase)
|
||||||
game, agents, game_history = load_game_state(run_dir, game_file_name, args.resume_from_phase)
|
if loaded_run_config:
|
||||||
|
# Use the saved config, but allow current CLI args to override control-flow parameters
|
||||||
|
run_config = loaded_run_config
|
||||||
|
run_config.run_dir = args.run_dir
|
||||||
|
run_config.critical_state_analysis_dir = args.critical_state_analysis_dir
|
||||||
|
run_config.resume_from_phase = args.resume_from_phase
|
||||||
|
run_config.end_at_phase = args.end_at_phase
|
||||||
|
# If prompts_dir is specified now, it overrides the saved one.
|
||||||
|
if args.prompts_dir is not None:
|
||||||
|
run_config.prompts_dir = args.prompts_dir
|
||||||
logger.info(f"Successfully resumed game from phase: {game.get_current_phase()}.")
|
logger.info(f"Successfully resumed game from phase: {game.get_current_phase()}.")
|
||||||
except (FileNotFoundError, ValueError) as e:
|
except (FileNotFoundError, ValueError) as e:
|
||||||
logger.error(f"Could not resume game: {e}. Starting a new game instead.")
|
logger.error(f"Could not resume game: {e}. Starting a new game instead.")
|
||||||
|
|
@ -204,7 +231,7 @@ async def main():
|
||||||
if year_int > run_config.max_year:
|
if year_int > run_config.max_year:
|
||||||
logger.info(f"Reached max year {run_config.max_year}, stopping simulation.")
|
logger.info(f"Reached max year {run_config.max_year}, stopping simulation.")
|
||||||
break
|
break
|
||||||
if run_config.end_at_phase and game.phases and game.phases[-1].name == run_config.end_at_phase:
|
if run_config.end_at_phase and current_phase == run_config.end_at_phase:
|
||||||
logger.info(f"Reached end phase {run_config.end_at_phase}, stopping simulation.")
|
logger.info(f"Reached end phase {run_config.end_at_phase}, stopping simulation.")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -289,7 +316,7 @@ async def main():
|
||||||
# Diary Consolidation
|
# Diary Consolidation
|
||||||
if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
|
if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
|
||||||
consolidation_tasks = [
|
consolidation_tasks = [
|
||||||
run_diary_consolidation(agent, game, llm_log_file_path)
|
run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir)
|
||||||
for agent in agents.values() if not game.powers[agent.power_name].is_eliminated()
|
for agent in agents.values() if not game.powers[agent.power_name].is_eliminated()
|
||||||
]
|
]
|
||||||
if consolidation_tasks:
|
if consolidation_tasks:
|
||||||
|
|
|
||||||
|
|
@ -17,3 +17,5 @@ anthropic
|
||||||
google-genai
|
google-genai
|
||||||
json-repair
|
json-repair
|
||||||
together
|
together
|
||||||
|
matplotlib
|
||||||
|
seaborn
|
||||||
5
run.sh
5
run.sh
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
python3 lm_game.py \
|
python3 lm_game.py \
|
||||||
--max_year 1901 \
|
--max_year 1901 \
|
||||||
--num_negotiation_rounds 0 \
|
--num_negotiation_rounds 1 \
|
||||||
--models "openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17" \
|
--models "openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17, openrouter-google/gemini-2.5-flash-lite-preview-06-17" \
|
||||||
--max_tokens_per_model 16000,16000,16000,16000,16000,16000,16000
|
--max_tokens_per_model 16000,16000,16000,16000,16000,16000,16000 \
|
||||||
|
--prompts_dir "ai_diplomacy/prompts"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue