fix to respect model ids passed as args when resuming

This commit is contained in:
sam-paech 2025-07-03 09:59:47 +10:00
parent 4fc1f370be
commit a241e34496
2 changed files with 64 additions and 28 deletions

View file

@ -31,24 +31,40 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
"private_diary": agent.private_diary,
}
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None) -> DiplomacyAgent:
"""Recreates an agent object from a dictionary."""
client = load_model_client(agent_data["model_id"], prompts_dir=prompts_dir)
client.max_tokens = agent_data.get("max_tokens", 16000) # Default for older saves
def deserialize_agent(
agent_data: dict,
prompts_dir: Optional[str] = None,
*,
override_model_id: Optional[str] = None
) -> DiplomacyAgent:
"""
Recreates an agent object from a dictionary.
If *override_model_id* is provided (e.g. because the CLI argument
``--models`` was used when resuming a game), that model is loaded
instead of the one stored in the save file.
"""
model_id = override_model_id or agent_data["model_id"]
client = load_model_client(model_id, prompts_dir=prompts_dir)
# Keep the original or fallback token limit exactly as before.
client.max_tokens = agent_data.get("max_tokens", 16000)
agent = DiplomacyAgent(
power_name=agent_data["power_name"],
client=client,
initial_goals=agent_data.get("goals", []),
initial_relationships=agent_data.get("relationships", None),
prompts_dir=prompts_dir
prompts_dir=prompts_dir,
)
# Restore the diary.
# Restore diary state
agent.full_private_diary = agent_data.get("full_private_diary", [])
agent.private_diary = agent_data.get("private_diary", [])
return agent
# --- State Management ---
# game_logic.py
@ -240,16 +256,45 @@ def load_game_state(
else:
last_phase_data = saved_game_data['phases'][-2]
# Rebuild agents
# -------------------- Rebuild agents -------------------- #
agents = {}
if 'state_agents' in last_phase_data:
# Build a power→model map from the CLI argument --models, if present.
power_model_map: Dict[str, str] = {}
if run_config and getattr(run_config, "models", None):
provided = [m.strip() for m in run_config.models.split(",")]
powers_order = sorted(list(ALL_POWERS))
if len(provided) == len(powers_order):
power_model_map = dict(zip(powers_order, provided))
elif len(provided) == 1:
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
else:
raise ValueError(
f"Invalid --models argument: expected 1 or {len(powers_order)} "
f"items, got {len(provided)}."
)
if "state_agents" in last_phase_data:
logger.info("Rebuilding agents from saved state...")
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, 'prompts_dir') else None
for power_name, agent_data in last_phase_data['state_agents'].items():
agents[power_name] = deserialize_agent(agent_data, prompts_dir=prompts_dir_from_config)
prompts_dir_from_config = (
run_config.prompts_dir
if run_config and hasattr(run_config, "prompts_dir")
else None
)
for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name)
agents[power_name] = deserialize_agent(
agent_data,
prompts_dir=prompts_dir_from_config,
override_model_id=override_id,
)
logger.info(f"Rebuilt {len(agents)} agents.")
else:
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
raise ValueError(
"Cannot resume: 'state_agents' key not found in the last phase "
"of the save file."
)
# Rebuild GameHistory
game_history = GameHistory()
@ -292,8 +337,9 @@ async def initialize_new_game(
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
# Handle power model mapping
if args.models:
provided_models = [name.strip() for name in args.models.split(",")]
provided_models = [name.strip() for name in args.models.split(",")]
if len(provided_models) == len(powers_order):
game.power_model_map = dict(zip(powers_order, provided_models))
elif len(provided_models) == 1:

View file

@ -262,19 +262,9 @@ async def main():
if is_resuming:
try:
# When resuming, we load the state and also the config from the last saved phase.
game, agents, game_history, loaded_run_config = load_game_state(run_dir, game_file_name, run_config, args.resume_from_phase)
if loaded_run_config:
# Use the saved config, but allow current CLI args to override control-flow parameters
run_config = loaded_run_config
run_config.run_dir = args.run_dir
run_config.critical_state_analysis_dir = args.critical_state_analysis_dir
run_config.resume_from_phase = args.resume_from_phase
run_config.end_at_phase = args.end_at_phase
# If prompts_dir is specified now, it overrides the saved one.
if args.prompts_dir is not None:
run_config.prompts_dir = args.prompts_dir
# When resuming, we always use the provided params (they will override the params used in the saved state)
game, agents, game_history, _ = load_game_state(run_dir, game_file_name, run_config, args.resume_from_phase)
logger.info(f"Successfully resumed game from phase: {game.get_current_phase()}.")
except (FileNotFoundError, ValueError) as e:
logger.error(f"Could not resume game: {e}. Starting a new game instead.")