diff --git a/ai_diplomacy/agent.py b/ai_diplomacy/agent.py index 3790240..2c94258 100644 --- a/ai_diplomacy/agent.py +++ b/ai_diplomacy/agent.py @@ -999,6 +999,10 @@ class DiplomacyAgent: # Extract year from the phase name (e.g., "S1901M" -> "1901") current_year = last_phase_name[1:5] if len(last_phase_name) >= 5 else "unknown" + # Format current goals and relationships for the prompt + current_goals_str = json.dumps(self.goals, indent=2) if self.goals else "[]" + current_relationships_str = json.dumps(self.relationships, indent=2) if self.relationships else "{}" + prompt = prompt_template.format( power_name=power_name, current_year=current_year, @@ -1006,6 +1010,8 @@ class DiplomacyAgent: board_state_str=context, phase_summary=last_phase_summary, # Use provided phase_summary other_powers=str(other_powers), # Pass as string representation + current_goals=current_goals_str, + current_relationships=current_relationships_str, ) logger.debug(f"[{power_name}] State update prompt:\n{prompt}") diff --git a/ai_diplomacy/game_logic.py b/ai_diplomacy/game_logic.py index df0997d..3648b96 100644 --- a/ai_diplomacy/game_logic.py +++ b/ai_diplomacy/game_logic.py @@ -172,11 +172,15 @@ def load_game_state( game_file_name: str, run_config, resume_from_phase: Optional[str] = None, -) -> Tuple["Game", Dict[str, "DiplomacyAgent"], "GameHistory", Optional[Any]]: +) -> Tuple["Game", Dict[str, "DiplomacyAgent"], "GameHistory", Optional[Any], Optional[Dict[str, Dict[str, str]]]]: """ Load and fully re-hydrate the game, agents and GameHistory – including `orders_by_power`, `results_by_power`, `submitted_orders_by_power`, and per-power `phase_summaries`. + + Returns a 5-tuple: (game, agents, game_history, run_config, saved_relationships) + where saved_relationships is extracted from agent_relationships field (old format) + when state_agents is not available. """ from collections import defaultdict # local to avoid new global import @@ -235,24 +239,37 @@ def load_game_state( else: raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.") + # Extract relationships from previous phase for agent initialization + saved_relationships = None if saved_game_data.get("phases"): last_phase_data = saved_game_data["phases"][-2] if len(saved_game_data["phases"]) > 1 else {} if "state_agents" not in last_phase_data: - raise ValueError("Cannot resume: 'state_agents' key missing in last completed phase.") - - for power_name, agent_data in last_phase_data["state_agents"].items(): - override_id = power_model_map.get(power_name) - prompts_dir_from_config = ( - run_config.prompts_dir_map.get(power_name) - if getattr(run_config, "prompts_dir_map", None) - else run_config.prompts_dir - ) - agents[power_name] = deserialize_agent( - agent_data, - prompts_dir=prompts_dir_from_config, - override_model_id=override_id, - override_max_tokens=model_max_tokens.get(power_name), - ) + # Try to load relationships from agent_relationships field (older format) + if "agent_relationships" in last_phase_data: + saved_relationships = last_phase_data["agent_relationships"] + logger.info( + "Loaded agent_relationships from previous phase. " + "Agents will be initialized fresh but with historical relationships." + ) + else: + logger.warning( + "Cannot resume agents: 'state_agents' key missing in last completed phase. " + "Agents will be initialized fresh (losing prior context/relationships)." + ) + else: + for power_name, agent_data in last_phase_data["state_agents"].items(): + override_id = power_model_map.get(power_name) + prompts_dir_from_config = ( + run_config.prompts_dir_map.get(power_name) + if getattr(run_config, "prompts_dir_map", None) + else run_config.prompts_dir + ) + agents[power_name] = deserialize_agent( + agent_data, + prompts_dir=prompts_dir_from_config, + override_model_id=override_id, + override_max_tokens=model_max_tokens.get(power_name), + ) # --- Rebuild GameHistory -------------------------------------------------- game_history = GameHistory() @@ -297,7 +314,7 @@ def load_game_state( submitted[pwr].append(order_str) ph_obj.submitted_orders_by_power = submitted - return game, agents, game_history, run_config + return game, agents, game_history, run_config, saved_relationships # ai_diplomacy/game_logic.py @@ -306,8 +323,15 @@ async def initialize_new_game( game: Game, game_history: GameHistory, llm_log_file_path: str, + saved_relationships: Optional[Dict[str, Dict[str, str]]] = None, ) -> Dict[str, DiplomacyAgent]: - """Initializes agents for a new game (supports per-power prompt directories).""" + """ + Initializes agents for a new game (supports per-power prompt directories). + + Args: + saved_relationships: Optional historical relationships to restore from old game format. + Dict mapping power_name -> {other_power: relationship_status} + """ powers_order = sorted(list(ALL_POWERS)) @@ -357,9 +381,17 @@ async def initialize_new_game( try: client = load_model_client(model_id, prompts_dir=prompts_dir_for_power) client.max_tokens = model_max_tokens[power_name] + + # Extract historical relationships if available + initial_relationships = None + if saved_relationships and power_name in saved_relationships: + initial_relationships = saved_relationships[power_name] + logger.info(f"[{power_name}] Restoring historical relationships: {initial_relationships}") + agent = DiplomacyAgent( power_name=power_name, client=client, + initial_relationships=initial_relationships, prompts_dir=prompts_dir_for_power, ) agents[power_name] = agent diff --git a/ai_diplomacy/initialization.py b/ai_diplomacy/initialization.py index 637458b..ad6f509 100644 --- a/ai_diplomacy/initialization.py +++ b/ai_diplomacy/initialization.py @@ -59,6 +59,8 @@ async def initialize_agent_state_ext( formatted_diary = agent.format_private_diary_for_prompt() + # Pass agent's current relationships to provide historical context to the LLM + # This ensures the LLM sees any pre-existing relationship state when setting initial goals context = build_context_prompt( game=game, board_state=board_state, @@ -66,7 +68,7 @@ async def initialize_agent_state_ext( possible_orders=None, # Don't include orders for initial state setup game_history=game_history, agent_goals=None, - agent_relationships=None, + agent_relationships=agent.relationships, # Pass agent's relationships for LLM context agent_private_diary=formatted_diary, prompts_dir=effective_prompts_dir, ) diff --git a/ai_diplomacy/utils.py b/ai_diplomacy/utils.py index d0d9180..a51ea47 100644 --- a/ai_diplomacy/utils.py +++ b/ai_diplomacy/utils.py @@ -63,7 +63,8 @@ def assign_models_to_powers() -> Dict[str, str]: Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini, claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219 gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25, - deepseek-chat, deepseek-reasoner + deepseek-chat, deepseek-reasoner, + kimi-k2-0905-preview, openrouter-meta-llama/llama-3.3-70b-instruct, openrouter-qwen/qwen3-235b-a22b, openrouter-microsoft/phi-4-reasoning-plus:free, openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free, openrouter-google/gemma-3-12b-it:free, openrouter-google/gemini-2.5-flash-preview-05-20 @@ -72,13 +73,13 @@ def assign_models_to_powers() -> Dict[str, str]: # POWER MODELS return { - "AUSTRIA": "o4-mini", - "ENGLAND": "o3", - "FRANCE": "gpt-5-reasoning-alpha-2025-07-19", - "GERMANY": "gpt-4.1", - "ITALY": "o4-mini", - "RUSSIA": "gpt-5-reasoning-alpha-2025-07-19", - "TURKEY": "o4-mini", + "AUSTRIA": "gpt-5-mini", + "ENGLAND": "gpt-5-mini", + "FRANCE": "gpt-5-mini", + "GERMANY": "gpt-5-mini", + "ITALY": "gemini-2.5-flash", + "RUSSIA": "gemini-2.5-flash", + "TURKEY": "gemini-2.5-flash", } # TEST MODELS diff --git a/experiment_runner.py b/experiment_runner.py index 62aa880..f022206 100644 --- a/experiment_runner.py +++ b/experiment_runner.py @@ -154,7 +154,7 @@ def _add_lm_game_flags(p: argparse.ArgumentParser) -> None: p.add_argument( "--num_negotiation_rounds", type=int, - default=0, + default=3, help="Number of negotiation rounds per phase.", ) p.add_argument( diff --git a/lm_game.py b/lm_game.py index 7c0bcb2..83795e6 100644 --- a/lm_game.py +++ b/lm_game.py @@ -100,7 +100,7 @@ def parse_arguments(): parser.add_argument( "--num_negotiation_rounds", type=int, - default=0, + default=3, help="Number of negotiation rounds per phase.", ) parser.add_argument( @@ -311,9 +311,14 @@ async def main(): if is_resuming: try: # When resuming, we always use the provided params (they will override the params used in the saved state) - game, agents, game_history, _ = load_game_state(run_dir, game_file_name, run_config, args.resume_from_phase) + game, agents, game_history, _, saved_relationships = load_game_state(run_dir, game_file_name, run_config, args.resume_from_phase) logger.info(f"Successfully resumed game from phase: {game.get_current_phase()}.") + + # If agents is empty (state_agents missing from old game format), initialize fresh agents + if not agents: + logger.warning("No agents loaded from game state. Initializing fresh agents for all powers.") + agents = await initialize_new_game(run_config, game, game_history, llm_log_file_path, saved_relationships=saved_relationships) except (FileNotFoundError, ValueError) as e: logger.error(f"Could not resume game: {e}. Starting a new game instead.") is_resuming = False # Fallback to new game