mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
fix to respect model ids passed as args when resuming
This commit is contained in:
parent
4fc1f370be
commit
a241e34496
2 changed files with 64 additions and 28 deletions
|
|
@ -31,24 +31,40 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
|
|||
"private_diary": agent.private_diary,
|
||||
}
|
||||
|
||||
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None) -> DiplomacyAgent:
|
||||
"""Recreates an agent object from a dictionary."""
|
||||
client = load_model_client(agent_data["model_id"], prompts_dir=prompts_dir)
|
||||
client.max_tokens = agent_data.get("max_tokens", 16000) # Default for older saves
|
||||
def deserialize_agent(
|
||||
agent_data: dict,
|
||||
prompts_dir: Optional[str] = None,
|
||||
*,
|
||||
override_model_id: Optional[str] = None
|
||||
) -> DiplomacyAgent:
|
||||
"""
|
||||
Recreates an agent object from a dictionary.
|
||||
|
||||
If *override_model_id* is provided (e.g. because the CLI argument
|
||||
``--models`` was used when resuming a game), that model is loaded
|
||||
instead of the one stored in the save file.
|
||||
"""
|
||||
model_id = override_model_id or agent_data["model_id"]
|
||||
client = load_model_client(model_id, prompts_dir=prompts_dir)
|
||||
|
||||
# Keep the original or fallback token limit exactly as before.
|
||||
client.max_tokens = agent_data.get("max_tokens", 16000)
|
||||
|
||||
agent = DiplomacyAgent(
|
||||
power_name=agent_data["power_name"],
|
||||
client=client,
|
||||
initial_goals=agent_data.get("goals", []),
|
||||
initial_relationships=agent_data.get("relationships", None),
|
||||
prompts_dir=prompts_dir
|
||||
prompts_dir=prompts_dir,
|
||||
)
|
||||
# Restore the diary.
|
||||
|
||||
# Restore diary state
|
||||
agent.full_private_diary = agent_data.get("full_private_diary", [])
|
||||
agent.private_diary = agent_data.get("private_diary", [])
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
# --- State Management ---
|
||||
|
||||
# game_logic.py
|
||||
|
|
@ -240,16 +256,45 @@ def load_game_state(
|
|||
else:
|
||||
last_phase_data = saved_game_data['phases'][-2]
|
||||
|
||||
# Rebuild agents
|
||||
# -------------------- Rebuild agents -------------------- #
|
||||
agents = {}
|
||||
if 'state_agents' in last_phase_data:
|
||||
|
||||
# Build a power→model map from the CLI argument --models, if present.
|
||||
power_model_map: Dict[str, str] = {}
|
||||
if run_config and getattr(run_config, "models", None):
|
||||
provided = [m.strip() for m in run_config.models.split(",")]
|
||||
powers_order = sorted(list(ALL_POWERS))
|
||||
if len(provided) == len(powers_order):
|
||||
power_model_map = dict(zip(powers_order, provided))
|
||||
elif len(provided) == 1:
|
||||
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid --models argument: expected 1 or {len(powers_order)} "
|
||||
f"items, got {len(provided)}."
|
||||
)
|
||||
|
||||
if "state_agents" in last_phase_data:
|
||||
logger.info("Rebuilding agents from saved state...")
|
||||
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, 'prompts_dir') else None
|
||||
for power_name, agent_data in last_phase_data['state_agents'].items():
|
||||
agents[power_name] = deserialize_agent(agent_data, prompts_dir=prompts_dir_from_config)
|
||||
prompts_dir_from_config = (
|
||||
run_config.prompts_dir
|
||||
if run_config and hasattr(run_config, "prompts_dir")
|
||||
else None
|
||||
)
|
||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||
override_id = power_model_map.get(power_name)
|
||||
agents[power_name] = deserialize_agent(
|
||||
agent_data,
|
||||
prompts_dir=prompts_dir_from_config,
|
||||
override_model_id=override_id,
|
||||
)
|
||||
logger.info(f"Rebuilt {len(agents)} agents.")
|
||||
else:
|
||||
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
||||
raise ValueError(
|
||||
"Cannot resume: 'state_agents' key not found in the last phase "
|
||||
"of the save file."
|
||||
)
|
||||
|
||||
|
||||
# Rebuild GameHistory
|
||||
game_history = GameHistory()
|
||||
|
|
@ -292,6 +337,7 @@ async def initialize_new_game(
|
|||
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
|
||||
|
||||
# Handle power model mapping
|
||||
|
||||
if args.models:
|
||||
provided_models = [name.strip() for name in args.models.split(",")]
|
||||
if len(provided_models) == len(powers_order):
|
||||
|
|
|
|||
14
lm_game.py
14
lm_game.py
|
|
@ -262,19 +262,9 @@ async def main():
|
|||
|
||||
if is_resuming:
|
||||
try:
|
||||
# When resuming, we load the state and also the config from the last saved phase.
|
||||
game, agents, game_history, loaded_run_config = load_game_state(run_dir, game_file_name, run_config, args.resume_from_phase)
|
||||
# When resuming, we always use the provided params (they will override the params used in the saved state)
|
||||
game, agents, game_history, _ = load_game_state(run_dir, game_file_name, run_config, args.resume_from_phase)
|
||||
|
||||
if loaded_run_config:
|
||||
# Use the saved config, but allow current CLI args to override control-flow parameters
|
||||
run_config = loaded_run_config
|
||||
run_config.run_dir = args.run_dir
|
||||
run_config.critical_state_analysis_dir = args.critical_state_analysis_dir
|
||||
run_config.resume_from_phase = args.resume_from_phase
|
||||
run_config.end_at_phase = args.end_at_phase
|
||||
# If prompts_dir is specified now, it overrides the saved one.
|
||||
if args.prompts_dir is not None:
|
||||
run_config.prompts_dir = args.prompts_dir
|
||||
logger.info(f"Successfully resumed game from phase: {game.get_current_phase()}.")
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
logger.error(f"Could not resume game: {e}. Starting a new game instead.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue