#!/usr/bin/env python3 """ Experiment orchestration for Diplomacy self-play. Launches many `lm_game` runs in parallel, captures their artefacts, and executes a pluggable post-analysis pipeline. Run `python experiment_runner.py --help` for CLI details. """ from __future__ import annotations import argparse import collections import concurrent.futures import importlib import json import logging import math import os import shutil import subprocess import sys import textwrap import time import multiprocessing as mp from datetime import datetime from pathlib import Path from types import SimpleNamespace from typing import Iterable, List # --------------------------------------------------------------------------- # # Logging # # --------------------------------------------------------------------------- # LOG_FMT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s" logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%H:%M:%S") log = logging.getLogger("experiment_runner") # ──────────────────────────────────────────────────────────────────────────── # Flag definitions – full, un-shortened help strings # # ──────────────────────────────────────────────────────────────────────────── def _add_experiment_flags(p: argparse.ArgumentParser) -> None: p.add_argument( "--experiment_dir", type=Path, required=True, help=( "Directory that will hold all experiment artefacts. " "A 'runs/' sub-folder is created for individual game runs and an " "'analysis/' folder for aggregated outputs. Must be writable." ), ) p.add_argument( "--iterations", type=int, default=1, help=( "Number of lm_game instances to launch for this experiment. " "Each instance gets its own sub-directory under runs/." ), ) p.add_argument( "--parallel", type=int, default=1, help=( "Maximum number of game instances to run concurrently. " "Uses a ProcessPoolExecutor under the hood." ), ) p.add_argument( "--analysis_modules", type=str, default="summary", help=( "Comma-separated list of analysis module names to execute after all " "runs finish. Modules are imported from " "'experiment_runner.analysis.' and must expose " "run(experiment_dir: Path, ctx: dict)." ), ) p.add_argument( "--critical_state_base_run", type=Path, default=None, help=( "Path to an *existing* run directory produced by a previous lm_game " "execution. When supplied, every iteration resumes from that " "snapshot using lm_game's --critical_state_analysis_dir mechanism." ), ) p.add_argument( "--seed_base", type=int, default=42, help=( "Base RNG seed. Run i will receive seed = seed_base + i. " "Forwarded to lm_game via its --seed flag (you must have added that " "flag to lm_game for deterministic behaviour)." ), ) def _add_lm_game_flags(p: argparse.ArgumentParser) -> None: # ---- all flags copied verbatim from lm_game.parse_arguments() ---- p.add_argument( "--resume_from_phase", type=str, default="", help=( "Phase to resume from (e.g., 'S1902M'). Requires --run_dir. " "IMPORTANT: This option clears any existing phase results ahead of " "& including the specified resume phase." ), ) p.add_argument( "--end_at_phase", type=str, default="", help="Phase to end the simulation after (e.g., 'F1905M').", ) p.add_argument( "--max_year", type=int, default=1910, # Increased default in lm_game help="Maximum year to simulate. The game will stop once this year is reached.", ) p.add_argument( "--num_negotiation_rounds", type=int, default=0, help="Number of negotiation rounds per phase.", ) p.add_argument( "--models", type=str, default="", help=( "Comma-separated list of model names to assign to powers in order. " "The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY." ), ) p.add_argument( "--planning_phase", action="store_true", help="Enable the planning phase for each power to set strategic directives.", ) p.add_argument( "--max_tokens", type=int, default=16000, help="Maximum number of new tokens to generate per LLM call (default: 16000).", ) p.add_argument( "--max_tokens_per_model", type=str, default="", help=( "Comma-separated list of 7 token limits (in order: AUSTRIA, ENGLAND, " "FRANCE, GERMANY, ITALY, RUSSIA, TURKEY). Overrides --max_tokens." ), ) p.add_argument( "--prompts_dir", type=str, default=None, help=( "Path to the directory containing prompt files. " "Defaults to the packaged prompts directory." ), ) # ──────────────────────────────────────────────────────────────────────────── # One combined parser for banner printing # # ──────────────────────────────────────────────────────────────────────────── def _build_full_parser() -> argparse.ArgumentParser: fp = argparse.ArgumentParser( prog="experiment_runner.py", formatter_class=lambda prog: argparse.RawTextHelpFormatter( prog, max_help_position=45 ), description=( "Batch-runner for Diplomacy self-play experiments. " "All lm_game flags are accepted here as-is; they are validated " "before any game runs start." ), ) _add_experiment_flags(fp) _add_lm_game_flags(fp) return fp # ──────────────────────────────────────────────────────────────────────────── # Robust parsing that always shows *full* help on error # # ──────────────────────────────────────────────────────────────────────────── def _parse_cli() -> tuple[argparse.Namespace, list[str], argparse.Namespace]: full_parser = _build_full_parser() # Show full banner when no args if len(sys.argv) == 1: full_parser.print_help(sys.stderr) sys.exit(2) # Show full banner on explicit help if any(tok in ("-h", "--help") for tok in sys.argv[1:]): full_parser.print_help(sys.stderr) sys.exit(0) # Sub-parsers for separating experiment vs game flags class _ErrParser(argparse.ArgumentParser): def error(self, msg): full_parser.print_help(sys.stderr) self.exit(2, f"{self.prog}: error: {msg}\n") exp_parser = _ErrParser(add_help=False) game_parser = _ErrParser(add_help=False) _add_experiment_flags(exp_parser) _add_lm_game_flags(game_parser) # Split argv tokens by flag ownership argv = sys.argv[1:] exp_flag_set = {opt for a in exp_parser._actions for opt in a.option_strings} exp_tok, game_tok, i = [], [], 0 while i < len(argv): tok = argv[i] if tok in exp_flag_set: exp_tok.append(tok) action = exp_parser._option_string_actions[tok] needs_val = ( action.nargs is None # default: exactly one value or (isinstance(action.nargs, int) and action.nargs > 0) or action.nargs in ("+", "*", "?") # variable-length cases ) if needs_val: exp_tok.append(argv[i + 1]) i += 2 else: # store_true / store_false i += 1 else: game_tok.append(tok) i += 1 exp_args = exp_parser.parse_args(exp_tok) game_args = game_parser.parse_args(game_tok) return exp_args, game_tok, game_args # --------------------------------------------------------------------------- # # Helpers # # --------------------------------------------------------------------------- # _RunInfo = collections.namedtuple( "_RunInfo", "index run_dir seed cmd_line returncode elapsed_s" ) def _mk_run_dir(exp_dir: Path, idx: int) -> Path: run_dir = exp_dir / "runs" / f"run_{idx:05d}" # Just ensure it exists; don't raise if it already does. run_dir.mkdir(parents=True, exist_ok=True) return run_dir def _dump_seed(seed: int, run_dir: Path) -> None: seed_file = run_dir / "seed.txt" if not seed_file.exists(): seed_file.write_text(str(seed)) def _build_cmd( lm_game_script: Path, base_cli: List[str], run_dir: Path, seed: int, critical_base: Path | None, resume_from_phase: str, ) -> List[str]: """ Returns a list suitable for subprocess.run([...]). """ cmd = [sys.executable, str(lm_game_script)] # Forward user CLI cmd.extend(base_cli) # Per-run mandatory overrides cmd.extend(["--run_dir", str(run_dir)]) cmd.extend(["--seed", str(seed)]) # you may need to add a --seed flag to lm_game # Critical-state mode if critical_base: cmd.extend([ "--critical_state_analysis_dir", str(run_dir), "--run_dir", str(critical_base) # base run dir (already completed) ]) if resume_from_phase: cmd.extend(["--resume_from_phase", resume_from_phase]) return cmd def _launch_one(args) -> _RunInfo: """ Worker executed inside a ProcessPool; runs one game via subprocess. """ ( idx, lm_game_script, base_cli, run_dir, seed, critical_base, resume_phase, ) = args cmd = _build_cmd( lm_game_script, base_cli, run_dir, seed, critical_base, resume_phase ) start = time.perf_counter() log.debug("Run %05d: CMD = %s", idx, " ".join(cmd)) # Write out full command for traceability (run_dir / "command.txt").write_text(" ".join(cmd)) try: result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, check=False, ) (run_dir / "console.log").write_text(result.stdout) rc = result.returncode except Exception as exc: # noqa: broad-except (run_dir / "console.log").write_text(f"Exception launching run:\n{exc}\n") rc = 1 elapsed = time.perf_counter() - start return _RunInfo(idx, run_dir, seed, " ".join(cmd), rc, elapsed) def _load_analysis_fns(module_names: Iterable[str]): """ Dynamically import analysis modules. Each module must expose `run(experiment_dir: Path, cfg: dict)`. """ for name in module_names: mod_name = f"experiment_runner.analysis.{name.strip()}" try: mod = importlib.import_module(mod_name) except ModuleNotFoundError as e: log.warning("Analysis module %s not found (%s) – skipping", mod_name, e) continue if not hasattr(mod, "run"): log.warning("%s has no `run()` function – skipping", mod_name) continue yield mod.run # --------------------------------------------------------------------------- # # Main driver # # --------------------------------------------------------------------------- # def main() -> None: exp_args, leftover_cli, game_args = _parse_cli() exp_dir: Path = exp_args.experiment_dir.expanduser().resolve() if exp_dir.exists(): log.info("Appending to existing experiment: %s", exp_dir) exp_dir.mkdir(parents=True, exist_ok=True) # Persist experiment-level config cfg_path = exp_dir / "config.json" if not cfg_path.exists(): # ← new guard with cfg_path.open("w", encoding="utf-8") as fh: json.dump( {"experiment": vars(exp_args), "lm_game": vars(game_args), "forwarded_cli": leftover_cli}, fh, indent=2, default=str, ) log.info("Config saved to %s", cfg_path) else: log.info("Config already exists – leaving unchanged") log.info("Config saved to %s", cfg_path) # ------------------------------------------------------------------ # # Launch games # # ------------------------------------------------------------------ # lm_game_script = Path(__file__).parent / "lm_game.py" if not lm_game_script.exists(): log.error("lm_game.py not found at %s – abort", lm_game_script) sys.exit(1) run_args = [] for i in range(exp_args.iterations): run_dir = _mk_run_dir(exp_dir, i) seed = exp_args.seed_base + i _dump_seed(seed, run_dir) run_args.append( ( i, lm_game_script, leftover_cli, run_dir, seed, exp_args.critical_state_base_run, game_args.resume_from_phase, ) ) log.info( "Launching %d runs (max %d parallel, critical_state=%s)", exp_args.iterations, exp_args.parallel, bool(exp_args.critical_state_base_run), ) runs_meta: list[_RunInfo] = [] with concurrent.futures.ProcessPoolExecutor( max_workers=exp_args.parallel, mp_context=mp.get_context("spawn"), ) as pool: for res in pool.map(_launch_one, run_args): runs_meta.append(res) status = "OK" if res.returncode == 0 else f"RC={res.returncode}" log.info( "run_%05d finished in %.1fs %s", res.index, res.elapsed_s, status ) # Persist per-run status summary summary_path = exp_dir / "runs_summary.json" with open(summary_path, "w", encoding="utf-8") as fh: json.dump([res._asdict() for res in runs_meta], fh, indent=2, default=str) log.info("Run summary written → %s", summary_path) # ------------------------------------------------------------------ # # Post-analysis pipeline # # ------------------------------------------------------------------ # mods = list(_load_analysis_fns(exp_args.analysis_modules.split(","))) if not mods: log.warning("No analysis modules loaded – done.") return analysis_root = exp_dir / "analysis" if analysis_root.exists(): shutil.rmtree(analysis_root) # ← wipes old outputs analysis_root.mkdir(exist_ok=True) # Collect common context ctx: dict = { "exp_dir": str(exp_dir), "runs": [str(r.run_dir) for r in runs_meta], "critical_state_base": str(exp_args.critical_state_base_run or ""), "resume_from_phase": game_args.resume_from_phase, } for fn in mods: name = fn.__module__.rsplit(".", 1)[-1] log.info("Running analysis module: %s", name) try: fn(exp_dir, ctx) log.info("✓ %s complete", name) except Exception as exc: # noqa: broad-except log.exception("Analysis module %s failed: %s", name, exc) log.info("Experiment finished – artefacts in %s", exp_dir) if __name__ == "__main__": main()