Ruff format and check

This commit is contained in:
Tyler Marques 2025-07-03 14:27:50 -07:00
parent aadedd6512
commit ad817695ef
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
13 changed files with 738 additions and 973 deletions

View file

@ -10,7 +10,7 @@ import ast
from ..config import config from ..config import config
# Assuming BaseModelClient is importable from clients.py in the same directory # Assuming BaseModelClient is importable from clients.py in the same directory
from .clients import BaseModelClient, load_model_client from .clients import BaseModelClient
# Import load_prompt and the new logging wrapper from utils # Import load_prompt and the new logging wrapper from utils
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
@ -336,7 +336,7 @@ class DiplomacyAgent:
) )
# If even the last resort doesn't give a dict, return empty dict # If even the last resort doesn't give a dict, return empty dict
return {} return {}
except Exception as e: except Exception:
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...") logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
return {} return {}
@ -965,7 +965,7 @@ class DiplomacyAgent:
other_powers = [p for p in game.powers if p != power_name] other_powers = [p for p in game.powers if p != power_name]
# Create a readable board state string from the board_state dict # Create a readable board state string from the board_state dict
board_state_str = f"Board State:\n" board_state_str = "Board State:\n"
for p_name, power_data in board_state.get("powers", {}).items(): for p_name, power_data in board_state.get("powers", {}).items():
# Get units and centers from the board state # Get units and centers from the board state
units = power_data.get("units", []) units = power_data.get("units", [])

View file

@ -1,12 +1,11 @@
import os import os
import json import json
from json import JSONDecodeError
import re import re
import logging import logging
import ast # For literal_eval in JSON fallback parsing import ast # For literal_eval in JSON fallback parsing
import aiohttp # For direct HTTP requests to Responses API import aiohttp # For direct HTTP requests to Responses API
from typing import List, Dict, Optional, Any, Tuple, NamedTuple from typing import List, Dict, Optional, Tuple, NamedTuple
from dotenv import load_dotenv from dotenv import load_dotenv
# Use Async versions of clients # Use Async versions of clients
@ -21,9 +20,9 @@ from together import AsyncTogether
from together.error import APIError as TogetherAPIError # For specific error handling from together.error import APIError as TogetherAPIError # For specific error handling
from ..config import config from ..config import config
from diplomacy.engine.message import GLOBAL
from .game_history import GameHistory from .game_history import GameHistory
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible # Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
# Moved formatter imports to avoid circular import - imported locally where needed # Moved formatter imports to avoid circular import - imported locally where needed
@ -36,6 +35,7 @@ logger.setLevel(logging.DEBUG) # Keep debug for now during async changes
load_dotenv() load_dotenv()
############################################################################## ##############################################################################
# 1) Base Interface # 1) Base Interface
############################################################################## ##############################################################################
@ -118,24 +118,19 @@ class BaseModelClient:
log_file_path=log_file_path, log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=phase, phase=phase,
response_type='order', # Context for run_llm_and_log's own error logging response_type="order", # Context for run_llm_and_log's own error logging
temperature=0 temperature=0,
)
logger.debug(
f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}"
) )
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}")
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS # Conditionally format the response based on USE_UNFORMATTED_PROMPTS
if config.USE_UNFORMATTED_PROMPTS: if config.USE_UNFORMATTED_PROMPTS:
# Local import to avoid circular dependency # Local import to avoid circular dependency
from .formatter import format_with_gemini_flash, FORMAT_ORDERS from .formatter import format_with_gemini_flash, FORMAT_ORDERS
# Format the natural language response into structured format # Format the natural language response into structured format
formatted_response = await format_with_gemini_flash( formatted_response = await format_with_gemini_flash(
raw_response, raw_response, FORMAT_ORDERS, power_name=power_name, phase=phase, log_file_path=log_file_path
FORMAT_ORDERS,
power_name=power_name,
phase=phase,
log_file_path=log_file_path
) )
else: else:
# Use the raw response directly (already formatted) # Use the raw response directly (already formatted)
@ -145,9 +140,7 @@ class BaseModelClient:
move_list = self._extract_moves(formatted_response, power_name) move_list = self._extract_moves(formatted_response, power_name)
if not move_list: if not move_list:
logger.warning( logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.")
f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback."
)
if model_error_stats is not None and self.model_name in model_error_stats: if model_error_stats is not None and self.model_name in model_error_stats:
model_error_stats[self.model_name].setdefault("order_decoding_errors", 0) model_error_stats[self.model_name].setdefault("order_decoding_errors", 0)
model_error_stats[self.model_name]["order_decoding_errors"] += 1 model_error_stats[self.model_name]["order_decoding_errors"] += 1
@ -195,7 +188,7 @@ class BaseModelClient:
response_type="order_generation", # Specific type for CSV logging response_type="order_generation", # Specific type for CSV logging
raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
raw_response=raw_response, raw_response=raw_response,
success=success_status success=success_status,
# token_usage and cost can be added later if available and if log_llm_response supports them # token_usage and cost can be added later if available and if log_llm_response supports them
) )
return parsed_orders_for_return return parsed_orders_for_return
@ -216,9 +209,7 @@ class BaseModelClient:
if not matches: if not matches:
# Some LLMs might not put the colon or might have triple backtick fences. # Some LLMs might not put the colon or might have triple backtick fences.
logger.debug( logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.")
f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns."
)
# 1b) Check for inline JSON after "PARSABLE OUTPUT" # 1b) Check for inline JSON after "PARSABLE OUTPUT"
pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$" pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
@ -226,47 +217,35 @@ class BaseModelClient:
if not matches: if not matches:
# 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks) # 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks)
logger.debug( logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern.")
f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern."
)
pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})" pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})"
matches = re.search(pattern_asterisk, raw_response, re.DOTALL) matches = re.search(pattern_asterisk, raw_response, re.DOTALL)
if not matches: if not matches:
logger.debug( logger.debug(f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences.")
f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences."
)
# 2) If still no match, check for triple-backtick code fences containing JSON # 2) If still no match, check for triple-backtick code fences containing JSON
if not matches: if not matches:
code_fence_pattern = r"```json\n(.*?)\n```" code_fence_pattern = r"```json\n(.*?)\n```"
matches = re.search(code_fence_pattern, raw_response, re.DOTALL) matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
if matches: if matches:
logger.debug( logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.")
f"[{self.model_name}] Found triple-backtick JSON block for {power_name}."
)
# 2b) Also try plain ``` code fences without json marker # 2b) Also try plain ``` code fences without json marker
if not matches: if not matches:
code_fence_plain = r"```\n(.*?)\n```" code_fence_plain = r"```\n(.*?)\n```"
matches = re.search(code_fence_plain, raw_response, re.DOTALL) matches = re.search(code_fence_plain, raw_response, re.DOTALL)
if matches: if matches:
logger.debug( logger.debug(f"[{self.model_name}] Found plain triple-backtick block for {power_name}.")
f"[{self.model_name}] Found plain triple-backtick block for {power_name}."
)
# 2c) Try to find bare JSON object anywhere in the response # 2c) Try to find bare JSON object anywhere in the response
if not matches: if not matches:
logger.debug( logger.debug(f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON.")
f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON."
)
# Look for a JSON object that contains "orders" key # Look for a JSON object that contains "orders" key
bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})' bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})'
matches = re.search(bare_json_pattern, raw_response, re.DOTALL) matches = re.search(bare_json_pattern, raw_response, re.DOTALL)
if matches: if matches:
logger.debug( logger.debug(f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}.")
f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}."
)
# 3) Attempt to parse JSON if we found anything # 3) Attempt to parse JSON if we found anything
json_text = None json_text = None
@ -283,9 +262,7 @@ class BaseModelClient:
json_text = json_text.strip() json_text = json_text.strip()
if not json_text: if not json_text:
logger.debug( logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.")
f"[{self.model_name}] No JSON text found in LLM response for {power_name}."
)
return None return None
# 3a) Try JSON loading # 3a) Try JSON loading
@ -293,14 +270,12 @@ class BaseModelClient:
data = json.loads(json_text) data = json.loads(json_text)
return data.get("orders", None) return data.get("orders", None)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning( logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues.")
f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues."
)
# Try to fix common JSON issues # Try to fix common JSON issues
try: try:
# Remove trailing commas # Remove trailing commas
fixed_json = re.sub(r',\s*([\}\]])', r'\1', json_text) fixed_json = re.sub(r",\s*([\}\]])", r"\1", json_text)
# Fix single quotes to double quotes # Fix single quotes to double quotes
fixed_json = fixed_json.replace("'", '"') fixed_json = fixed_json.replace("'", '"')
# Try parsing again # Try parsing again
@ -308,14 +283,12 @@ class BaseModelClient:
logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}") logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}")
return data.get("orders", None) return data.get("orders", None)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning( logger.warning(f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments.")
f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments."
)
# Try to remove inline comments (// style) # Try to remove inline comments (// style)
try: try:
# Remove // comments from each line # Remove // comments from each line
lines = json_text.split('\n') lines = json_text.split("\n")
cleaned_lines = [] cleaned_lines = []
for line in lines: for line in lines:
# Find // that's not inside quotes # Find // that's not inside quotes
@ -326,12 +299,12 @@ class BaseModelClient:
if escape_next: if escape_next:
escape_next = False escape_next = False
continue continue
if char == '\\': if char == "\\":
escape_next = True escape_next = True
continue continue
if char == '"' and not escape_next: if char == '"' and not escape_next:
in_quotes = not in_quotes in_quotes = not in_quotes
if not in_quotes and line[i:i+2] == '//': if not in_quotes and line[i : i + 2] == "//":
comment_pos = i comment_pos = i
break break
@ -342,17 +315,15 @@ class BaseModelClient:
cleaned_line = line cleaned_line = line
cleaned_lines.append(cleaned_line) cleaned_lines.append(cleaned_line)
comment_free_json = '\n'.join(cleaned_lines) comment_free_json = "\n".join(cleaned_lines)
# Also remove trailing commas after comment removal # Also remove trailing commas after comment removal
comment_free_json = re.sub(r',\s*([\}\]])', r'\1', comment_free_json) comment_free_json = re.sub(r",\s*([\}\]])", r"\1", comment_free_json)
data = json.loads(comment_free_json) data = json.loads(comment_free_json)
logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}") logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}")
return data.get("orders", None) return data.get("orders", None)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning( logger.warning(f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback.")
f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback."
)
# 3b) Attempt bracket fallback: we look for the substring after "orders" # 3b) Attempt bracket fallback: we look for the substring after "orders"
# E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors. # E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors.
@ -366,16 +337,12 @@ class BaseModelClient:
if isinstance(moves, list): if isinstance(moves, list):
return moves return moves
except Exception as e2: except Exception as e2:
logger.warning( logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}")
f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}"
)
# If all attempts failed # If all attempts failed
return None return None
def _validate_orders( def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
self, moves: List[str], possible_orders: Dict[str, List[str]]
) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
""" """
Filter out invalid moves, fill missing with HOLD, else fallback. Filter out invalid moves, fill missing with HOLD, else fallback.
Returns a tuple: (validated_moves, invalid_moves_found) Returns a tuple: (validated_moves, invalid_moves_found)
@ -405,15 +372,15 @@ class BaseModelClient:
for loc, orders_list in possible_orders.items(): for loc, orders_list in possible_orders.items():
if loc not in used_locs and orders_list: if loc not in used_locs and orders_list:
hold_candidates = [o for o in orders_list if o.endswith("H")] hold_candidates = [o for o in orders_list if o.endswith("H")]
validated.append( validated.append(hold_candidates[0] if hold_candidates else orders_list[0])
hold_candidates[0] if hold_candidates else orders_list[0]
)
if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM) if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.") logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.")
return self.fallback_orders(possible_orders), [] return self.fallback_orders(possible_orders), []
elif not validated and invalid_moves_found: # All LLM moves were invalid elif not validated and invalid_moves_found: # All LLM moves were invalid
logger.warning(f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}") logger.warning(
f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}"
)
# We return empty list for validated, but the invalid_moves_found list is populated # We return empty list for validated, but the invalid_moves_found list is populated
return self.fallback_orders(possible_orders), invalid_moves_found return self.fallback_orders(possible_orders), invalid_moves_found
@ -444,7 +411,6 @@ class BaseModelClient:
agent_relationships: Optional[Dict[str, str]] = None, agent_relationships: Optional[Dict[str, str]] = None,
agent_private_diary_str: Optional[str] = None, # Added agent_private_diary_str: Optional[str] = None, # Added
) -> str: ) -> str:
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir) instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
context = self.build_context_prompt( context = self.build_context_prompt(
@ -508,7 +474,15 @@ class BaseModelClient:
unanswered_messages += "\nNo urgent messages requiring direct responses.\n" unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
final_prompt = context + unanswered_messages + "\n\n" + instructions final_prompt = context + unanswered_messages + "\n\n" + instructions
final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey') final_prompt = (
final_prompt.replace("AUSTRIA", "Austria")
.replace("ENGLAND", "England")
.replace("FRANCE", "France")
.replace("GERMANY", "Germany")
.replace("ITALY", "Italy")
.replace("RUSSIA", "Russia")
.replace("TURKEY", "Turkey")
)
return final_prompt return final_prompt
async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
@ -524,7 +498,6 @@ class BaseModelClient:
agent_relationships: Optional[Dict[str, str]] = None, agent_relationships: Optional[Dict[str, str]] = None,
agent_private_diary_str: Optional[str] = None, # Added agent_private_diary_str: Optional[str] = None, # Added
) -> str: ) -> str:
prompt = self.build_planning_prompt( prompt = self.build_planning_prompt(
game, game,
board_state, board_state,
@ -545,7 +518,7 @@ class BaseModelClient:
log_file_path=log_file_path, log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=game_phase, # Use game_phase for logging phase=game_phase, # Use game_phase for logging
response_type='plan_reply', # Changed from 'plan' to avoid confusion response_type="plan_reply", # Changed from 'plan' to avoid confusion
) )
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}") logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}")
return raw_response return raw_response
@ -592,7 +565,7 @@ class BaseModelClient:
log_file_path=log_file_path, log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=game_phase, phase=game_phase,
response_type='negotiation', # For run_llm_and_log's internal context response_type="negotiation", # For run_llm_and_log's internal context
) )
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}") logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")
@ -600,13 +573,10 @@ class BaseModelClient:
if config.USE_UNFORMATTED_PROMPTS: if config.USE_UNFORMATTED_PROMPTS:
# Local import to avoid circular dependency # Local import to avoid circular dependency
from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION
# Format the natural language response into structured JSON # Format the natural language response into structured JSON
formatted_response = await format_with_gemini_flash( formatted_response = await format_with_gemini_flash(
raw_response, raw_response, FORMAT_CONVERSATION, power_name=power_name, phase=game_phase, log_file_path=log_file_path
FORMAT_CONVERSATION,
power_name=power_name,
phase=game_phase,
log_file_path=log_file_path
) )
else: else:
# Use the raw response directly (already formatted) # Use the raw response directly (already formatted)
@ -632,10 +602,10 @@ class BaseModelClient:
# Original parsing logic as fallback # Original parsing logic as fallback
if not parsed_messages: if not parsed_messages:
# Attempt to find blocks enclosed in {{...}} # Attempt to find blocks enclosed in {{...}}
double_brace_blocks = re.findall(r'\{\{(.*?)\}\}', raw_response, re.DOTALL) double_brace_blocks = re.findall(r"\{\{(.*?)\}\}", raw_response, re.DOTALL)
if double_brace_blocks: if double_brace_blocks:
# If {{...}} blocks are found, assume each is a self-contained JSON object # If {{...}} blocks are found, assume each is a self-contained JSON object
json_blocks.extend(['{' + block.strip() + '}' for block in double_brace_blocks]) json_blocks.extend(["{" + block.strip() + "}" for block in double_brace_blocks])
else: else:
# If no {{...}} blocks, look for ```json ... ``` markdown blocks # If no {{...}} blocks, look for ```json ... ``` markdown blocks
code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL) code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL)
@ -650,10 +620,10 @@ class BaseModelClient:
json_blocks = [json.dumps(data)] json_blocks = [json.dumps(data)]
except json.JSONDecodeError: except json.JSONDecodeError:
# If parsing the whole block fails, fall back to regex for individual objects # If parsing the whole block fails, fall back to regex for individual objects
json_blocks = re.findall(r'\{.*?\}', potential_json_array_or_objects, re.DOTALL) json_blocks = re.findall(r"\{.*?\}", potential_json_array_or_objects, re.DOTALL)
else: else:
# If no markdown block, fall back to regex for any JSON object in the response # If no markdown block, fall back to regex for any JSON object in the response
json_blocks = re.findall(r'\{.*?\}', raw_response, re.DOTALL) json_blocks = re.findall(r"\{.*?\}", raw_response, re.DOTALL)
# Process json_blocks if we have them from fallback parsing # Process json_blocks if we have them from fallback parsing
if not parsed_messages and json_blocks: if not parsed_messages and json_blocks:
@ -661,7 +631,7 @@ class BaseModelClient:
try: try:
cleaned_block = block.strip() cleaned_block = block.strip()
# Attempt to fix common JSON issues like trailing commas before parsing # Attempt to fix common JSON issues like trailing commas before parsing
cleaned_block = re.sub(r',\s*([\}\]])', r'\1', cleaned_block) cleaned_block = re.sub(r",\s*([\}\]])", r"\1", cleaned_block)
parsed_message = json.loads(cleaned_block) parsed_message = json.loads(cleaned_block)
parsed_messages.append(parsed_message) parsed_messages.append(parsed_message)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
@ -710,7 +680,7 @@ class BaseModelClient:
response_type="negotiation_message", response_type="negotiation_message",
raw_input_prompt=raw_input_prompt, raw_input_prompt=raw_input_prompt,
raw_response=raw_response, raw_response=raw_response,
success=success_status success=success_status,
) )
return messages_to_return return messages_to_return
@ -770,7 +740,7 @@ class BaseModelClient:
log_file_path=log_file_path, log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type='plan_generation', # More specific type for run_llm_and_log context response_type="plan_generation", # More specific type for run_llm_and_log context
) )
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}") logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}")
# No parsing needed for the plan, return the raw string # No parsing needed for the plan, return the raw string
@ -790,7 +760,7 @@ class BaseModelClient:
response_type="plan_generation", # Specific type for CSV logging response_type="plan_generation", # Specific type for CSV logging
raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
raw_response=raw_plan_response, raw_response=raw_plan_response,
success=success_status success=success_status,
# token_usage and cost can be added later # token_usage and cost can be added later
) )
return plan_to_return return plan_to_return
@ -813,11 +783,7 @@ class OpenAIClient(BaseModelClient):
): ):
super().__init__(model_name, prompts_dir=prompts_dir) super().__init__(model_name, prompts_dir=prompts_dir)
self.base_url = ( self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
base_url
or os.environ.get("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
)
self.api_key = api_key or os.environ.get("OPENAI_API_KEY") self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not self.api_key: if not self.api_key:
@ -832,11 +798,7 @@ class OpenAIClient(BaseModelClient):
inject_random_seed: bool = True, inject_random_seed: bool = True,
) -> str: ) -> str:
try: try:
system_prompt_content = ( system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
f"{generate_random_seed()}\n\n{self.system_prompt}"
if inject_random_seed
else self.system_prompt
)
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:" prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
response = await self.client.chat.completions.create( response = await self.client.chat.completions.create(
@ -888,20 +850,14 @@ class ClaudeClient(BaseModelClient):
temperature=temperature, temperature=temperature,
) )
if not response.content: if not response.content:
logger.warning( logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.")
f"[{self.model_name}] Empty content in Claude generate_response. Returning empty."
)
return "" return ""
return response.content[0].text.strip() if response.content else "" return response.content[0].text.strip() if response.content else ""
except json.JSONDecodeError as json_err: except json.JSONDecodeError as json_err:
logger.error( logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}")
f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
)
return "" return ""
except Exception as e: except Exception as e:
logger.error( logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
f"[{self.model_name}] Unexpected error in generate_response: {e}"
)
return "" return ""
@ -929,19 +885,14 @@ class GeminiClient(BaseModelClient):
full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:" full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
try: try:
generation_config = genai.types.GenerationConfig( generation_config = genai.types.GenerationConfig(temperature=temperature, max_output_tokens=self.max_tokens)
temperature=temperature,
max_output_tokens=self.max_tokens
)
response = await self.client.generate_content_async( response = await self.client.generate_content_async(
contents=full_prompt, contents=full_prompt,
generation_config=generation_config, generation_config=generation_config,
) )
if not response or not response.text: if not response or not response.text:
logger.warning( logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.")
f"[{self.model_name}] Empty Gemini generate_response. Returning empty."
)
return "" return ""
return response.text.strip() return response.text.strip()
except Exception as e: except Exception as e:
@ -957,10 +908,7 @@ class DeepSeekClient(BaseModelClient):
def __init__(self, model_name: str, prompts_dir: Optional[str] = None): def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir) super().__init__(model_name, prompts_dir=prompts_dir)
self.api_key = os.environ.get("DEEPSEEK_API_KEY") self.api_key = os.environ.get("DEEPSEEK_API_KEY")
self.client = AsyncDeepSeekOpenAI( self.client = AsyncDeepSeekOpenAI(api_key=self.api_key, base_url="https://api.deepseek.com/")
api_key=self.api_key,
base_url="https://api.deepseek.com/"
)
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str: async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
try: try:
@ -986,9 +934,7 @@ class DeepSeekClient(BaseModelClient):
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}") logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
if not response or not response.choices: if not response or not response.choices:
logger.warning( logger.warning(f"[{self.model_name}] No valid response in generate_response.")
f"[{self.model_name}] No valid response in generate_response."
)
return "" return ""
content = response.choices[0].message.content.strip() content = response.choices[0].message.content.strip()
@ -998,9 +944,7 @@ class DeepSeekClient(BaseModelClient):
return content return content
except Exception as e: except Exception as e:
logger.error( logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
f"[{self.model_name}] Unexpected error in generate_response: {e}"
)
return "" return ""
@ -1037,19 +981,14 @@ class OpenAIResponsesClient(BaseModelClient):
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
} }
headers = { headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
# Make the API call using aiohttp # Make the API call using aiohttp
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(self.base_url, json=payload, headers=headers) as response: async with session.post(self.base_url, json=payload, headers=headers) as response:
if response.status != 200: if response.status != 200:
error_text = await response.text() error_text = await response.text()
logger.error( logger.error(f"[{self.model_name}] API error (status {response.status}): {error_text}")
f"[{self.model_name}] API error (status {response.status}): {error_text}"
)
return "" return ""
response_data = await response.json() response_data = await response.json()
@ -1059,24 +998,18 @@ class OpenAIResponsesClient(BaseModelClient):
try: try:
outputs = response_data.get("output", []) outputs = response_data.get("output", [])
if len(outputs) < 2: if len(outputs) < 2:
logger.warning( logger.warning(f"[{self.model_name}] Unexpected output structure. Full response: {response_data}")
f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
)
return "" return ""
# The message is typically in the second output item # The message is typically in the second output item
message_output = outputs[1] message_output = outputs[1]
if message_output.get("type") != "message": if message_output.get("type") != "message":
logger.warning( logger.warning(f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}")
f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}"
)
return "" return ""
content_list = message_output.get("content", []) content_list = message_output.get("content", [])
if not content_list: if not content_list:
logger.warning( logger.warning(f"[{self.model_name}] Empty content list in message output")
f"[{self.model_name}] Empty content list in message output"
)
return "" return ""
# Look for the content item with type 'output_text' # Look for the content item with type 'output_text'
@ -1087,28 +1020,20 @@ class OpenAIResponsesClient(BaseModelClient):
break break
if not text_content: if not text_content:
logger.warning( logger.warning(f"[{self.model_name}] No output_text found in content. Full content: {content_list}")
f"[{self.model_name}] No output_text found in content. Full content: {content_list}"
)
return "" return ""
return text_content.strip() return text_content.strip()
except (KeyError, IndexError, TypeError) as e: except (KeyError, IndexError, TypeError) as e:
logger.error( logger.error(f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}")
f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}"
)
return "" return ""
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.error( logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
f"[{self.model_name}] HTTP client error in generate_response: {e}"
)
return "" return ""
except Exception as e: except Exception as e:
logger.error( logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
f"[{self.model_name}] Unexpected error in generate_response: {e}"
)
return "" return ""
@ -1129,10 +1054,7 @@ class OpenRouterClient(BaseModelClient):
if not self.api_key: if not self.api_key:
raise ValueError("OPENROUTER_API_KEY environment variable is required") raise ValueError("OPENROUTER_API_KEY environment variable is required")
self.client = AsyncOpenAI( self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self.api_key)
base_url="https://openrouter.ai/api/v1",
api_key=self.api_key
)
logger.debug(f"[{self.model_name}] Initialized OpenRouter client") logger.debug(f"[{self.model_name}] Initialized OpenRouter client")
@ -1150,10 +1072,7 @@ class OpenRouterClient(BaseModelClient):
# Prepare standard OpenAI-compatible request # Prepare standard OpenAI-compatible request
response = await self.client.chat.completions.create( response = await self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
messages=[ messages=[{"role": "system", "content": system_prompt_content}, {"role": "user", "content": prompt_with_cta}],
{"role": "system", "content": system_prompt_content},
{"role": "user", "content": prompt_with_cta}
],
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
temperature=temperature, temperature=temperature,
) )
@ -1245,7 +1164,7 @@ class TogetherAIClient(BaseModelClient):
############################################################################## ##############################################################################
# RequestsOpenAIClient sync requests, wrapped async (original + api_key) # RequestsOpenAIClient sync requests, wrapped async (original + api_key)
############################################################################## ##############################################################################
import requests, asyncio
class RequestsOpenAIClient(BaseModelClient): class RequestsOpenAIClient(BaseModelClient):
""" """
@ -1266,11 +1185,7 @@ class RequestsOpenAIClient(BaseModelClient):
if not self.api_key: if not self.api_key:
raise ValueError("OPENAI_API_KEY missing and no inline key provided") raise ValueError("OPENAI_API_KEY missing and no inline key provided")
self.base_url = ( self.base_url = (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1").rstrip("/")
base_url
or os.environ.get("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
).rstrip("/")
self.endpoint = f"{self.base_url}/chat/completions" self.endpoint = f"{self.base_url}/chat/completions"
@ -1291,11 +1206,7 @@ class RequestsOpenAIClient(BaseModelClient):
temperature: float = 0.0, temperature: float = 0.0,
inject_random_seed: bool = True, inject_random_seed: bool = True,
) -> str: ) -> str:
system_prompt_content = ( system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
f"{generate_random_seed()}\n\n{self.system_prompt}"
if inject_random_seed
else self.system_prompt
)
payload = { payload = {
"model": self.model_name, "model": self.model_name,
@ -1322,7 +1233,6 @@ class RequestsOpenAIClient(BaseModelClient):
return "" return ""
############################################################################## ##############################################################################
# 3) Factory to Load Model Client # 3) Factory to Load Model Client
############################################################################## ##############################################################################
@ -1332,6 +1242,7 @@ class ModelSpec(NamedTuple):
base: Optional[str] # 'https://proxy.foo' base: Optional[str] # 'https://proxy.foo'
key: Optional[str] # 'sk-…' (may be None) key: Optional[str] # 'sk-…' (may be None)
def _parse_model_spec(raw: str) -> ModelSpec: def _parse_model_spec(raw: str) -> ModelSpec:
""" """
Splits once on '#' (API key) and once on '@' (base URL). A leading Splits once on '#' (API key) and once on '@' (base URL). A leading
@ -1441,8 +1352,6 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
return OpenAIClient(model_name, prompts_dir, base_url) return OpenAIClient(model_name, prompts_dir, base_url)
############################################################################## ##############################################################################
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient) # 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
############################################################################## ##############################################################################
@ -1453,11 +1362,6 @@ def get_visible_messages_for_power(conversation_messages, power_name):
visible = [] visible = []
for msg in conversation_messages: for msg in conversation_messages:
# GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage # GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
if ( if msg["recipient"] == "ALL" or msg["recipient"] == "GLOBAL" or msg["sender"] == power_name or msg["recipient"] == power_name:
msg["recipient"] == "ALL"
or msg["recipient"] == "GLOBAL"
or msg["sender"] == power_name
or msg["recipient"] == power_name
):
visible.append(msg) visible.append(msg)
return visible # already in chronological order if appended that way return visible # already in chronological order if appended that way

View file

@ -11,25 +11,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None: def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports.""" """A local copy of the helper from agent.py to avoid circular imports."""
import os import os
try: try:
if prompts_dir: if prompts_dir:
filepath = os.path.join(prompts_dir, filename) filepath = os.path.join(prompts_dir, filename)
else: else:
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, 'prompts') default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename) filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
return f.read() return f.read()
except Exception as e: except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}") logger.error(f"Error loading prompt file {filepath}: {e}")
return None return None
async def run_diary_consolidation( async def run_diary_consolidation(
agent: 'DiplomacyAgent', agent: "DiplomacyAgent",
game: "Game", game: "Game",
log_file_path: str, log_file_path: str,
entries_to_keep_unsummarized: int = 6, entries_to_keep_unsummarized: int = 6,
@ -39,70 +42,42 @@ async def run_diary_consolidation(
Consolidate older diary entries while keeping recent ones. Consolidate older diary entries while keeping recent ones.
This is the logic moved from the DiplomacyAgent class. This is the logic moved from the DiplomacyAgent class.
""" """
logger.info( logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
f"[{agent.power_name}] CONSOLIDATION START — "
f"{len(agent.full_private_diary)} total full entries"
)
full_entries = [ full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
e for e in agent.full_private_diary
if not e.startswith("[CONSOLIDATED HISTORY]")
]
if len(full_entries) <= entries_to_keep_unsummarized: if len(full_entries) <= entries_to_keep_unsummarized:
agent.private_diary = list(agent.full_private_diary) agent.private_diary = list(agent.full_private_diary)
logger.info( logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
"skipping consolidation"
)
return return
boundary_entry = full_entries[-entries_to_keep_unsummarized] boundary_entry = full_entries[-entries_to_keep_unsummarized]
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry) match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
if not match: if not match:
logger.error( logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
f"[{agent.power_name}] Could not parse year from boundary entry; "
"aborting consolidation"
)
agent.private_diary = list(agent.full_private_diary) agent.private_diary = list(agent.full_private_diary)
return return
cutoff_year = int(match.group(1)) cutoff_year = int(match.group(1))
logger.info( logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}")
f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}"
)
def _entry_year(entry: str) -> int | None: def _entry_year(entry: str) -> int | None:
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry) m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
return int(m.group(1)) if m else None return int(m.group(1)) if m else None
entries_to_summarize = [ entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
e for e in full_entries entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
]
entries_to_keep = [
e for e in full_entries
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
]
logger.info( logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
f"keeping {len(entries_to_keep)} recent entries verbatim"
)
if not entries_to_summarize: if not entries_to_summarize:
agent.private_diary = list(agent.full_private_diary) agent.private_diary = list(agent.full_private_diary)
logger.warning( logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
f"[{agent.power_name}] No eligible entries to summarise; "
"context diary left unchanged"
)
return return
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir) prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
if not prompt_template: if not prompt_template:
logger.error( logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
)
return return
prompt = prompt_template.format( prompt = prompt_template.format(
@ -132,23 +107,14 @@ async def run_diary_consolidation(
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}" new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
agent.private_diary = [new_summary_entry] + entries_to_keep agent.private_diary = [new_summary_entry] + entries_to_keep
success_flag = "TRUE" success_flag = "TRUE"
logger.info( logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
f"[{agent.power_name}] Consolidation complete — "
f"{len(agent.private_diary)} context entries now"
)
except Exception as exc: except Exception as exc:
logger.error( logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
)
finally: finally:
log_llm_response( log_llm_response(
log_file_path=log_file_path, log_file_path=log_file_path,
model_name=( model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
consolidation_client.model_name
if consolidation_client is not None
else agent.client.model_name
),
power_name=agent.power_name, power_name=agent.power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type="diary_consolidation", response_type="diary_consolidation",

View file

@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses. Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
""" """
import json
import logging import logging
import os from typing import Optional
from typing import Dict, Any, Optional
from pathlib import Path from pathlib import Path
# Import logging function and model configuration # Import logging function and model configuration
from .utils import log_llm_response, get_special_models from .utils import log_llm_response, get_special_models
# Import client loading function # Import client loading function
from .clients import load_model_client from .clients import load_model_client
@ -26,11 +25,7 @@ FORMAT_ORDER_DIARY = "order_diary"
async def format_with_gemini_flash( async def format_with_gemini_flash(
raw_response: str, raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None
format_type: str,
power_name: Optional[str] = None,
phase: Optional[str] = None,
log_file_path: Optional[str] = None
) -> str: ) -> str:
""" """
Format a natural language response into required JSON structure using Gemini Flash. Format a natural language response into required JSON structure using Gemini Flash.
@ -52,7 +47,7 @@ async def format_with_gemini_flash(
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt", FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
FORMAT_ORDERS: "formatting/format_orders.txt", FORMAT_ORDERS: "formatting/format_orders.txt",
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt", FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt" FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt",
} }
if format_type not in format_prompts: if format_type not in format_prompts:
@ -63,7 +58,7 @@ async def format_with_gemini_flash(
if not prompt_file.exists(): if not prompt_file.exists():
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}") raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
with open(prompt_file, 'r') as f: with open(prompt_file, "r") as f:
format_prompt = f.read() format_prompt = f.read()
# Replace placeholder with actual response # Replace placeholder with actual response
@ -87,7 +82,7 @@ async def format_with_gemini_flash(
formatted_response = await formatter_client.generate_response( formatted_response = await formatter_client.generate_response(
prompt=format_prompt, prompt=format_prompt,
temperature=0, # Deterministic formatting temperature=0, # Deterministic formatting
inject_random_seed=False # No need for random seed in formatting inject_random_seed=False, # No need for random seed in formatting
) )
if not formatted_response: if not formatted_response:
@ -116,7 +111,7 @@ async def format_with_gemini_flash(
response_type=f"format_{format_type}", response_type=f"format_{format_type}",
raw_input_prompt=format_prompt, raw_input_prompt=format_prompt,
raw_response=formatted_response, raw_response=formatted_response,
success="Success" success="Success",
) )
return formatted_response return formatted_response
@ -134,7 +129,7 @@ async def format_with_gemini_flash(
response_type=f"format_{format_type}", response_type=f"format_{format_type}",
raw_input_prompt=format_prompt, raw_input_prompt=format_prompt,
raw_response=f"ERROR: {str(e)}", raw_response=f"ERROR: {str(e)}",
success=f"Failure: {type(e).__name__}" success=f"Failure: {type(e).__name__}",
) )
# Return empty structure based on format type # Return empty structure based on format type

View file

@ -3,7 +3,6 @@ import logging
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pprint import pformat
logger = logging.getLogger("utils") logger = logging.getLogger("utils")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -24,12 +23,8 @@ class Phase:
plans: Dict[str, str] = field(default_factory=dict) plans: Dict[str, str] = field(default_factory=dict)
messages: List[Message] = field(default_factory=list) messages: List[Message] = field(default_factory=list)
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict) submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
orders_by_power: Dict[str, List[str]] = field( orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list))
default_factory=lambda: defaultdict(list) results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list))
)
results_by_power: Dict[str, List[List[str]]] = field(
default_factory=lambda: defaultdict(list)
)
# NEW: Store phase-end summaries provided by each power # NEW: Store phase-end summaries provided by each power
phase_summaries: Dict[str, str] = field(default_factory=dict) phase_summaries: Dict[str, str] = field(default_factory=dict)
# NEW: Store experience/journal updates from each power for this phase # NEW: Store experience/journal updates from each power for this phase
@ -39,9 +34,7 @@ class Phase:
self.plans[power_name] = plan self.plans[power_name] = plan
def add_message(self, sender: str, recipient: str, content: str): def add_message(self, sender: str, recipient: str, content: str):
self.messages.append( self.messages.append(Message(sender=sender, recipient=recipient, content=content))
Message(sender=sender, recipient=recipient, content=content)
)
def add_orders(self, power: str, orders: List[str], results: List[List[str]]): def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
self.orders_by_power[power].extend(orders) self.orders_by_power[power].extend(orders)
@ -110,14 +103,10 @@ class GameHistory:
phase.plans[power_name] = plan phase.plans[power_name] = plan
logger.debug(f"Added plan for {power_name} in {phase_name}") logger.debug(f"Added plan for {power_name} in {phase_name}")
def add_message( def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str):
self, phase_name: str, sender: str, recipient: str, message_content: str
):
phase = self._get_phase(phase_name) phase = self._get_phase(phase_name)
if phase: if phase:
message = Message( message = Message(sender=sender, recipient=recipient, content=message_content)
sender=sender, recipient=recipient, content=message_content
)
phase.messages.append(message) phase.messages.append(message)
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}") logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
@ -153,10 +142,6 @@ class GameHistory:
return {} return {}
return self.phases[-1].plans return self.phases[-1].plans
def get_order_history_for_prompt( def get_order_history_for_prompt(
self, self,
game: "Game", game: "Game",
@ -194,10 +179,7 @@ class GameHistory:
return str(tag).strip().lower() return str(tag).strip().lower()
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()} engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
eng2code = { eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"}
"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
"GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
}
norm = game.map.norm norm = game.map.norm
out_lines = ["**ORDER HISTORY (Recent Rounds)**"] out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
@ -264,13 +246,6 @@ class GameHistory:
return "\n(No orders were issued in recent history)\n" return "\n(No orders were issued in recent history)\n"
return "\n".join(out_lines) return "\n".join(out_lines)
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str: def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
current_phase: Optional[Phase] = None current_phase: Optional[Phase] = None
for phase_obj in self.phases: for phase_obj in self.phases:
@ -325,11 +300,7 @@ class GameHistory:
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name): if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
# Skip if sender is this power (don't need to respond to own messages) # Skip if sender is this power (don't need to respond to own messages)
if msg.sender != power_name: if msg.sender != power_name:
messages_to_power.append({ messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name})
'sender': msg.sender,
'content': msg.content,
'phase': phase.name
})
# Add debug logging # Add debug logging
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases") logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
@ -361,10 +332,10 @@ class GameHistory:
for msg in phase.messages: for msg in phase.messages:
# Handle both Message objects and dict objects # Handle both Message objects and dict objects
if isinstance(msg, Message): if isinstance(msg, Message):
if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']: if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]:
sender_messages.append(msg) sender_messages.append(msg)
else: # Assume dict else: # Assume dict
if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']: if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]:
sender_messages.append(msg) sender_messages.append(msg)
# Check for responses in this and next phases # Check for responses in this and next phases
@ -374,8 +345,8 @@ class GameHistory:
recipient = msg.recipient recipient = msg.recipient
msg_content = msg.content msg_content = msg.content
else: else:
recipient = msg['recipient'] recipient = msg["recipient"]
msg_content = msg['content'] msg_content = msg["content"]
# Look for responses in current phase and next phases # Look for responses in current phase and next phases
found_response = False found_response = False
@ -387,12 +358,14 @@ class GameHistory:
response_msgs = [] response_msgs = []
for m in check_phase.messages: for m in check_phase.messages:
if isinstance(m, Message): if isinstance(m, Message):
if m.sender == recipient and (m.recipient == sender_name or if m.sender == recipient and (
(m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)): m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content)
):
response_msgs.append(m) response_msgs.append(m)
else: # Assume dict else: # Assume dict
if m['sender'] == recipient and (m['recipient'] == sender_name or if m["sender"] == recipient and (
(m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))): m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", ""))
):
response_msgs.append(m) response_msgs.append(m)
if response_msgs: if response_msgs:
@ -402,9 +375,6 @@ class GameHistory:
if not found_response: if not found_response:
if recipient not in ignored_by_power: if recipient not in ignored_by_power:
ignored_by_power[recipient] = [] ignored_by_power[recipient] = []
ignored_by_power[recipient].append({ ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content})
'phase': phase.name,
'content': msg_content
})
return ignored_by_power return ignored_by_power

View file

@ -3,7 +3,7 @@ import logging
import os import os
import json import json
import asyncio import asyncio
from typing import Dict, List, Tuple, Optional, Any from typing import Dict, Tuple, Optional, Any
from argparse import Namespace from argparse import Namespace
from diplomacy import Game from diplomacy import Game
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
# --- Serialization / Deserialization --- # --- Serialization / Deserialization ---
def serialize_agent(agent: DiplomacyAgent) -> dict: def serialize_agent(agent: DiplomacyAgent) -> dict:
"""Converts an agent object to a JSON-serializable dictionary.""" """Converts an agent object to a JSON-serializable dictionary."""
return { return {
@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
"private_diary": agent.private_diary, "private_diary": agent.private_diary,
} }
def deserialize_agent(
agent_data: dict, def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
prompts_dir: Optional[str] = None,
*,
override_model_id: Optional[str] = None
) -> DiplomacyAgent:
""" """
Recreates an agent object from a dictionary. Recreates an agent object from a dictionary.
@ -70,6 +67,7 @@ def deserialize_agent(
# game_logic.py # game_logic.py
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments _PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
def _next_phase_name(short: str) -> str: def _next_phase_name(short: str) -> str:
""" """
Return the Diplomacy phase string that chronologically follows *short*. Return the Diplomacy phase string that chronologically follows *short*.
@ -89,13 +87,9 @@ def _next_phase_name(short: str) -> str:
else: # winter→ spring, next year else: # winter→ spring, next year
return f"S{year + 1}M" return f"S{year + 1}M"
def save_game_state( def save_game_state(
game: Game, game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
agents: Dict[str, DiplomacyAgent],
game_history: GameHistory,
output_path: str,
run_config: Namespace,
completed_phase_name: str
): ):
""" """
Serialise the entire game to JSON, preserving per-phase custom metadata Serialise the entire game to JSON, preserving per-phase custom metadata
@ -130,9 +124,7 @@ def save_game_state(
if extras: if extras:
previous_phase_extras[phase["name"]] = extras previous_phase_extras[phase["name"]] = extras
except Exception as exc: except Exception as exc:
logger.warning( logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
"Could not load previous save to retain metadata: %s", exc, exc_info=True
)
# -------------------------------------------------------------- # # -------------------------------------------------------------- #
# 2. Build the fresh base structure from the diplomacy library. # # 2. Build the fresh base structure from the diplomacy library. #
@ -143,11 +135,7 @@ def save_game_state(
# 3. Walk every phase and merge the metadata back in. # # 3. Walk every phase and merge the metadata back in. #
# -------------------------------------------------------------- # # -------------------------------------------------------------- #
# Capture the *current* snapshot of every live agent exactly once. # Capture the *current* snapshot of every live agent exactly once.
current_state_agents = { current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
p_name: serialize_agent(p_agent)
for p_name, p_agent in agents.items()
if not game.powers[p_name].is_eliminated()
}
for phase_block in saved_game.get("phases", []): for phase_block in saved_game.get("phases", []):
if int(phase_block["name"][1:5]) > run_config.max_year: if int(phase_block["name"][1:5]) > run_config.max_year:
@ -167,19 +155,13 @@ def save_game_state(
# Plans for this phase may be empty in non-movement phases. # Plans for this phase may be empty in non-movement phases.
phase_obj = game_history._get_phase(phase_name) phase_obj = game_history._get_phase(phase_name)
phase_block["state_history_plans"] = ( phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
phase_obj.plans if phase_obj else {}
)
# -------------------------------------------------------------- # # -------------------------------------------------------------- #
# 4. Attach top-level metadata and write atomically. # # 4. Attach top-level metadata and write atomically. #
# -------------------------------------------------------------- # # -------------------------------------------------------------- #
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {}) saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
saved_game["final_agent_states"] = { saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
p_name: {"relationships": a.relationships, "goals": a.goals}
for p_name, a in agents.items()
}
# Filter out phases > max_year # Filter out phases > max_year
# saved_game["phases"] = [ # saved_game["phases"] = [
@ -191,7 +173,6 @@ def save_game_state(
logger.info("Game state saved successfully.") logger.info("Game state saved successfully.")
def load_game_state( def load_game_state(
run_dir: str, run_dir: str,
game_file_name: str, game_file_name: str,
@ -204,7 +185,7 @@ def load_game_state(
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}") raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
logger.info(f"Loading game state from: {game_file_path}") logger.info(f"Loading game state from: {game_file_path}")
with open(game_file_path, 'r') as f: with open(game_file_path, "r") as f:
saved_game_data = json.load(f) saved_game_data = json.load(f)
# If resuming, find the specified phase and truncate the data after it # If resuming, find the specified phase and truncate the data after it
@ -213,37 +194,37 @@ def load_game_state(
try: try:
# Find the index of the phase *before* the one we want to resume from. # Find the index of the phase *before* the one we want to resume from.
# We will start the simulation *at* the resume_from_phase. # We will start the simulation *at* the resume_from_phase.
resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase) resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
# Truncate the list to exclude everything after the resume phase # Truncate the list to exclude everything after the resume phase
# Note: the state saved for a given phase represents the state at the beginning of that phase. # Note: the state saved for a given phase represents the state at the beginning of that phase.
saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1] saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
# Wipe any data that must be regenerated. # Wipe any data that must be regenerated.
for key in ("orders", "results", "messages"): for key in ("orders", "results", "messages"):
saved_game_data['phases'][-1].pop(key, None) saved_game_data["phases"][-1].pop(key, None)
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.") logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
except StopIteration: except StopIteration:
# If the phase is not found, maybe it's the first phase (S1901M) # If the phase is not found, maybe it's the first phase (S1901M)
if resume_from_phase == "S1901M": if resume_from_phase == "S1901M":
saved_game_data['phases'] = [] saved_game_data["phases"] = []
logger.info("Resuming from S1901M. Starting with a clean history.") logger.info("Resuming from S1901M. Starting with a clean history.")
else: else:
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.") raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
# Reconstruct the Game object # Reconstruct the Game object
last_phase = saved_game_data['phases'][-1] last_phase = saved_game_data["phases"][-1]
# Wipe the data that must be regenerated **but preserve the keys** # Wipe the data that must be regenerated **but preserve the keys**
last_phase['orders'] = {} # was dict last_phase["orders"] = {} # was dict
last_phase['results'] = {} # was dict last_phase["results"] = {} # was dict
last_phase['messages'] = [] last_phase["messages"] = []
game = from_saved_game_format(saved_game_data) game = from_saved_game_format(saved_game_data)
game.phase_summaries = saved_game_data.get('phase_summaries', {}) game.phase_summaries = saved_game_data.get("phase_summaries", {})
# Reconstruct agents and game history from the *last* valid phase in the data # Reconstruct agents and game history from the *last* valid phase in the data
if not saved_game_data['phases']: if not saved_game_data["phases"]:
# This happens if we are resuming from the very beginning (S1901M) # This happens if we are resuming from the very beginning (S1901M)
logger.info("No previous phases found. Initializing fresh agents and history.") logger.info("No previous phases found. Initializing fresh agents and history.")
agents = {} # Will be created by the main loop agents = {} # Will be created by the main loop
@ -251,10 +232,10 @@ def load_game_state(
else: else:
# We save the game state up to & including the current (uncompleted) phase. # We save the game state up to & including the current (uncompleted) phase.
# So we need to grab the agent state from the previous (completed) phase. # So we need to grab the agent state from the previous (completed) phase.
if len(saved_game_data['phases']) <= 1: if len(saved_game_data["phases"]) <= 1:
last_phase_data = {} last_phase_data = {}
else: else:
last_phase_data = saved_game_data['phases'][-2] last_phase_data = saved_game_data["phases"][-2]
# -------------------- Rebuild agents -------------------- # # -------------------- Rebuild agents -------------------- #
agents = {} agents = {}
@ -269,18 +250,11 @@ def load_game_state(
elif len(provided) == 1: elif len(provided) == 1:
power_model_map = dict(zip(powers_order, provided * len(powers_order))) power_model_map = dict(zip(powers_order, provided * len(powers_order)))
else: else:
raise ValueError( raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
f"Invalid --models argument: expected 1 or {len(powers_order)} "
f"items, got {len(provided)}."
)
if "state_agents" in last_phase_data: if "state_agents" in last_phase_data:
logger.info("Rebuilding agents from saved state...") logger.info("Rebuilding agents from saved state...")
prompts_dir_from_config = ( prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
run_config.prompts_dir
if run_config and hasattr(run_config, "prompts_dir")
else None
)
for power_name, agent_data in last_phase_data["state_agents"].items(): for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name) override_id = power_model_map.get(power_name)
agents[power_name] = deserialize_agent( agents[power_name] = deserialize_agent(
@ -290,37 +264,27 @@ def load_game_state(
) )
logger.info(f"Rebuilt {len(agents)} agents.") logger.info(f"Rebuilt {len(agents)} agents.")
else: else:
raise ValueError( raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
"Cannot resume: 'state_agents' key not found in the last phase "
"of the save file."
)
# Rebuild GameHistory # Rebuild GameHistory
game_history = GameHistory() game_history = GameHistory()
logger.info("Rebuilding game history...") logger.info("Rebuilding game history...")
for phase_data in saved_game_data['phases'][:-1]: for phase_data in saved_game_data["phases"][:-1]:
phase_name = phase_data['name'] phase_name = phase_data["name"]
game_history.add_phase(phase_name) game_history.add_phase(phase_name)
# Add messages # Add messages
for msg in phase_data.get('messages', []): for msg in phase_data.get("messages", []):
game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message']) game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
# Add plans # Add plans
if 'state_history_plans' in phase_data: if "state_history_plans" in phase_data:
for p_name, plan in phase_data['state_history_plans'].items(): for p_name, plan in phase_data["state_history_plans"].items():
game_history.add_plan(phase_name, p_name, plan) game_history.add_plan(phase_name, p_name, plan)
logger.info("Game history rebuilt.") logger.info("Game history rebuilt.")
return game, agents, game_history, run_config return game, agents, game_history, run_config
async def initialize_new_game( async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
args: Namespace,
game: Game,
game_history: GameHistory,
llm_log_file_path: str
) -> Dict[str, DiplomacyAgent]:
"""Initializes agents for a new game.""" """Initializes agents for a new game."""
powers_order = sorted(list(ALL_POWERS)) powers_order = sorted(list(ALL_POWERS))

View file

@ -1,7 +1,6 @@
# ai_diplomacy/initialization.py # ai_diplomacy/initialization.py
import logging import logging
import json import json
import os
from typing import Optional from typing import Optional
from ..config import config from ..config import config
@ -18,10 +17,11 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def initialize_agent_state_ext( async def initialize_agent_state_ext(
agent: 'DiplomacyAgent', agent: "DiplomacyAgent",
game: 'Game', game: "Game",
game_history: 'GameHistory', game_history: "GameHistory",
log_file_path: str, log_file_path: str,
prompts_dir: Optional[str] = None, prompts_dir: Optional[str] = None,
): ):
@ -37,21 +37,20 @@ async def initialize_agent_state_ext(
try: try:
# Load the prompt template # Load the prompt template
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS) allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir) initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
if not initial_prompt_template: if not initial_prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.") logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
return return
# Format the prompt with variables # Format the prompt with variables
initial_prompt = initial_prompt_template.format( initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
power_name=power_name,
allowed_labels_str=allowed_labels_str
)
board_state = game.get_state() if game else {} board_state = game.get_state() if game else {}
possible_orders = game.get_all_possible_orders() if game else {} possible_orders = game.get_all_possible_orders() if game else {}
logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}") logger.debug(
f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}"
)
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility # Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
# For initialization, game should always be present. # For initialization, game should always be present.
@ -76,7 +75,7 @@ async def initialize_agent_state_ext(
log_file_path=log_file_path, log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=current_phase, phase=current_phase,
response_type='initialization', # Context for run_llm_and_log internal error logging response_type="initialization", # Context for run_llm_and_log internal error logging
) )
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
@ -86,11 +85,7 @@ async def initialize_agent_state_ext(
if config.USE_UNFORMATTED_PROMPTS: if config.USE_UNFORMATTED_PROMPTS:
# Format the natural language response into JSON # Format the natural language response into JSON
formatted_response = await format_with_gemini_flash( formatted_response = await format_with_gemini_flash(
response, response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
FORMAT_INITIAL_STATE,
power_name=power_name,
phase=current_phase,
log_file_path=log_file_path
) )
else: else:
# Use the raw response directly (already formatted) # Use the raw response directly (already formatted)
@ -108,12 +103,16 @@ async def initialize_agent_state_ext(
# Defensive check for update_data type if parsing was initially considered successful # Defensive check for update_data type if parsing was initially considered successful
if parsed_successfully: if parsed_successfully:
if isinstance(update_data, str): if isinstance(update_data, str):
logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...") logger.error(
f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..."
)
update_data = {} # Treat as parsing failure update_data = {} # Treat as parsing failure
parsed_successfully = False parsed_successfully = False
success_status = "Failure: ParsedAsStr" success_status = "Failure: ParsedAsStr"
elif not isinstance(update_data, dict): # Expecting a dict from JSON object elif not isinstance(update_data, dict): # Expecting a dict from JSON object
logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}") logger.error(
f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}"
)
update_data = {} # Treat as parsing failure update_data = {} # Treat as parsing failure
parsed_successfully = False parsed_successfully = False
success_status = "Failure: NotADict" success_status = "Failure: NotADict"
@ -122,8 +121,8 @@ async def initialize_agent_state_ext(
initial_relationships_applied = False initial_relationships_applied = False
if parsed_successfully: if parsed_successfully:
initial_goals = update_data.get('initial_goals') or update_data.get('goals') initial_goals = update_data.get("initial_goals") or update_data.get("goals")
initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships') initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships")
if isinstance(initial_goals, list) and initial_goals: if isinstance(initial_goals, list) and initial_goals:
agent.goals = initial_goals agent.goals = initial_goals
@ -200,7 +199,7 @@ async def initialize_agent_state_ext(
response_type="initial_state_setup", # Specific type for CSV logging response_type="initial_state_setup", # Specific type for CSV logging
raw_input_prompt=full_prompt, raw_input_prompt=full_prompt,
raw_response=response, raw_response=response,
success=success_status success=success_status,
) )
# Final log of state after initialization attempt # Final log of state after initialization attempt

View file

@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment] Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
LOGGER.info("Game._generate_phase_summary patched with narrative generation.") LOGGER.info("Game._generate_phase_summary patched with narrative generation.")

View file

@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING
from diplomacy.engine.message import Message, GLOBAL from diplomacy.engine.message import Message, GLOBAL
from .agent import DiplomacyAgent from .agent import DiplomacyAgent
from .clients import load_model_client from .utils import gather_possible_orders, normalize_recipient_name
from .utils import gather_possible_orders, load_prompt, normalize_recipient_name
if TYPE_CHECKING: if TYPE_CHECKING:
from .game_history import GameHistory from .game_history import GameHistory
@ -21,9 +20,9 @@ load_dotenv()
async def conduct_negotiations( async def conduct_negotiations(
game: 'Game', game: "Game",
agents: Dict[str, DiplomacyAgent], agents: Dict[str, DiplomacyAgent],
game_history: 'GameHistory', game_history: "GameHistory",
model_error_stats: Dict[str, Dict[str, int]], model_error_stats: Dict[str, Dict[str, int]],
log_file_path: str, log_file_path: str,
max_rounds: int = 3, max_rounds: int = 3,
@ -35,12 +34,8 @@ async def conduct_negotiations(
""" """
logger.info("Starting negotiation phase.") logger.info("Starting negotiation phase.")
active_powers = [ active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated() eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
]
eliminated_powers = [
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
]
logger.info(f"Active powers for negotiations: {active_powers}") logger.info(f"Active powers for negotiations: {active_powers}")
if eliminated_powers: if eliminated_powers:

View file

@ -1,14 +1,13 @@
from dotenv import load_dotenv
import logging import logging
import concurrent.futures import concurrent.futures
from typing import Dict from typing import Dict
from .clients import load_model_client
from .game_history import GameHistory from .game_history import GameHistory
from .agent import DiplomacyAgent from .agent import DiplomacyAgent
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def planning_phase( async def planning_phase(
game, game,
agents: Dict[str, DiplomacyAgent], agents: Dict[str, DiplomacyAgent],
@ -20,12 +19,8 @@ async def planning_phase(
Lets each power generate a strategic plan using their DiplomacyAgent. Lets each power generate a strategic plan using their DiplomacyAgent.
""" """
logger.info(f"Starting planning phase for {game.current_short_phase}...") logger.info(f"Starting planning phase for {game.current_short_phase}...")
active_powers = [ active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated() eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
]
eliminated_powers = [
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
]
logger.info(f"Active powers for planning: {active_powers}") logger.info(f"Active powers for planning: {active_powers}")
if eliminated_powers: if eliminated_powers:
@ -35,9 +30,7 @@ async def planning_phase(
board_state = game.get_state() board_state = game.get_state()
with concurrent.futures.ThreadPoolExecutor( with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor:
max_workers=len(active_powers)
) as executor:
futures = {} futures = {}
for power_name in active_powers: for power_name in active_powers:
if power_name not in agents: if power_name not in agents:
@ -71,16 +64,14 @@ async def planning_phase(
if plan_result.startswith("Error:"): if plan_result.startswith("Error:"):
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}") logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
if power_name in model_error_stats: if power_name in model_error_stats:
model_error_stats[power_name].setdefault('planning_generation_errors', 0) model_error_stats[power_name].setdefault("planning_generation_errors", 0)
model_error_stats[power_name]['planning_generation_errors'] += 1 model_error_stats[power_name]["planning_generation_errors"] += 1
else: else:
model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0) model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0)
model_error_stats[f'{power_name}_planning_generation_errors'] += 1 model_error_stats[f"{power_name}_planning_generation_errors"] += 1
elif plan_result: elif plan_result:
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...") agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
game_history.add_plan( game_history.add_plan(game.current_short_phase, power_name, plan_result)
game.current_short_phase, power_name, plan_result
)
logger.debug(f"Added plan for {power_name} to history.") logger.debug(f"Added plan for {power_name} to history.")
else: else:
logger.warning(f"Agent {power_name} returned an empty plan.") logger.warning(f"Agent {power_name} returned an empty plan.")
@ -88,11 +79,11 @@ async def planning_phase(
except Exception as e: except Exception as e:
logger.error(f"Exception during planning result processing for {power_name}: {e}") logger.error(f"Exception during planning result processing for {power_name}: {e}")
if power_name in model_error_stats: if power_name in model_error_stats:
model_error_stats[power_name].setdefault('planning_execution_errors', 0) model_error_stats[power_name].setdefault("planning_execution_errors", 0)
model_error_stats[power_name]['planning_execution_errors'] += 1 model_error_stats[power_name]["planning_execution_errors"] += 1
else: else:
model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0) model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0)
model_error_stats[f'{power_name}_planning_execution_errors'] += 1 model_error_stats[f"{power_name}_planning_execution_errors"] += 1
logger.info("Planning phase processing complete.") logger.info("Planning phase processing complete.")
return game_history return game_history

View file

@ -14,6 +14,7 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]: def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
""" """
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces. Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
@ -27,9 +28,7 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
if len(loc.split("/")[0]) == 3 if len(loc.split("/")[0]) == 3
} }
graph: Dict[str, Dict[str, List[str]]] = { graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs}
p: {"ARMY": [], "FLEET": []} for p in provs
}
# ── helper: list every concrete variant of a province ───────────────── # ── helper: list every concrete variant of a province ─────────────────
def variants(code: str) -> List[str]: def variants(code: str) -> List[str]:
@ -55,11 +54,7 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
graph[src]["ARMY"].append(dest) graph[src]["ARMY"].append(dest)
# FLEETany src variant that can host a fleet is valid # FLEETany src variant that can host a fleet is valid
if any( if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers):
game_map.abuts("F", sv, "-", dv)
for sv in src_vers
for dv in dest_vers
):
graph[src]["FLEET"].append(dest) graph[src]["FLEET"].append(dest)
# ── tidy up duplicates / order --------------------------------------- # ── tidy up duplicates / order ---------------------------------------
@ -70,24 +65,23 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
return graph return graph
def bfs_shortest_path( def bfs_shortest_path(
graph: Dict[str, Dict[str, List[str]]], graph: Dict[str, Dict[str, List[str]]],
board_state: BoardState, board_state: BoardState,
game_map: GameMap, # Added game_map game_map: GameMap, # Added game_map
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC' start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
unit_type: str, unit_type: str,
is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc
) -> Optional[List[str]]: # Returns path of SHORT names ) -> Optional[List[str]]: # Returns path of SHORT names
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func.""" """Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
# Convert full start location to short province name # Convert full start location to short province name
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full) start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE. if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
start_loc_short = start_loc_short[:3] start_loc_short = start_loc_short[:3]
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key. # If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
# A simpler way for non-coastal full (like 'VIE') or already short: # A simpler way for non-coastal full (like 'VIE') or already short:
if '/' not in start_loc_full: if "/" not in start_loc_full:
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR' start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
else: # Has '/', e.g. 'STP/SC' else: # Has '/', e.g. 'STP/SC'
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP' start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
@ -119,10 +113,11 @@ def bfs_shortest_path(
queue.append((next_loc_short, new_path)) queue.append((next_loc_short, new_path))
return None return None
# --- Helper functions for context generation --- # --- Helper functions for context generation ---
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]: def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None.""" """Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
for power, unit_list in board_state.get('units', {}).items(): for power, unit_list in board_state.get("units", {}).items():
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC" for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
parts = unit_str.split(" ") parts = unit_str.split(" ")
if len(parts) == 2: if len(parts) == 2:
@ -131,24 +126,26 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str
return f"{parts[0]} {location} ({power})" return f"{parts[0]} {location} ({power})"
return None return None
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]: def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
"""Returns the controlling power's name if the location is an SC, else None.""" """Returns the controlling power's name if the location is an SC, else None."""
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts # Normalize location to base province name, as SCs are tied to provinces, not specific coasts
loc_province_name = game_map.loc_name.get(location, location).upper()[:3] loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
if loc_province_name not in game_map.scs: if loc_province_name not in game_map.scs:
return None return None
for power, sc_list in board_state.get('centers', {}).items(): for power, sc_list in board_state.get("centers", {}).items():
if loc_province_name in sc_list: if loc_province_name in sc_list:
return power return power
return None # Unowned SC return None # Unowned SC
def get_shortest_path_to_friendly_unit( def get_shortest_path_to_friendly_unit(
board_state: BoardState, board_state: BoardState,
graph: Dict[str, Dict[str, List[str]]], graph: Dict[str, Dict[str, List[str]]],
game_map: GameMap, # Added game_map game_map: GameMap, # Added game_map
power_name: str, power_name: str,
start_unit_loc_full: str, start_unit_loc_full: str,
start_unit_type: str start_unit_type: str,
) -> Optional[Tuple[str, List[str]]]: ) -> Optional[Tuple[str, List[str]]]:
"""Finds the shortest path to any friendly unit of the same power.""" """Finds the shortest path to any friendly unit of the same power."""
@ -183,7 +180,7 @@ def get_nearest_enemy_units(
power_name: str, power_name: str,
start_unit_loc_full: str, start_unit_loc_full: str,
start_unit_type: str, start_unit_type: str,
n: int = 3 n: int = 3,
) -> List[Tuple[str, List[str]]]: ) -> List[Tuple[str, List[str]]]:
"""Finds up to N nearest enemy units, sorted by path length.""" """Finds up to N nearest enemy units, sorted by path length."""
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names) enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
@ -205,9 +202,9 @@ def get_nearest_enemy_units(
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full: for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full) target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
if '/' in target_enemy_loc_short: if "/" in target_enemy_loc_short:
target_enemy_loc_short = target_enemy_loc_short[:3] target_enemy_loc_short = target_enemy_loc_short[:3]
if '/' not in target_enemy_loc_full: if "/" not in target_enemy_loc_full:
target_enemy_loc_short = target_enemy_loc_full[:3] target_enemy_loc_short = target_enemy_loc_full[:3]
else: else:
target_enemy_loc_short = target_enemy_loc_full[:3] target_enemy_loc_short = target_enemy_loc_full[:3]
@ -281,7 +278,7 @@ def get_adjacent_territory_details(
board_state: BoardState, board_state: BoardState,
unit_loc_full: str, # The location of the unit whose adjacencies we're checking unit_loc_full: str, # The location of the unit whose adjacencies we're checking
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
graph: Dict[str, Dict[str, List[str]]] graph: Dict[str, Dict[str, List[str]]],
) -> str: ) -> str:
"""Generates a string describing adjacent territories and units that can interact with them.""" """Generates a string describing adjacent territories and units that can interact with them."""
output_lines: List[str] = [] output_lines: List[str] = []
@ -291,9 +288,9 @@ def get_adjacent_territory_details(
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts) # For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values. # THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full) unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
if '/' in unit_loc_short: if "/" in unit_loc_short:
unit_loc_short = unit_loc_short[:3] unit_loc_short = unit_loc_short[:3]
if '/' not in unit_loc_full: if "/" not in unit_loc_full:
unit_loc_short = unit_loc_full[:3] unit_loc_short = unit_loc_full[:3]
else: else:
unit_loc_short = unit_loc_full[:3] unit_loc_short = unit_loc_full[:3]
@ -308,11 +305,11 @@ def get_adjacent_territory_details(
continue continue
processed_adj_provinces.add(adj_loc_short) processed_adj_provinces.add(adj_loc_short)
adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper() adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper()
if adj_loc_type == 'COAST' or adj_loc_type == 'LAND': if adj_loc_type == "COAST" or adj_loc_type == "LAND":
adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST' adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST"
elif adj_loc_type == 'WATER': elif adj_loc_type == "WATER":
adj_loc_type_display = 'WATER' adj_loc_type_display = "WATER"
else: # SHUT etc. else: # SHUT etc.
adj_loc_type_display = adj_loc_type adj_loc_type_display = adj_loc_type
@ -330,8 +327,7 @@ def get_adjacent_territory_details(
# "Can support/move to" - Simplified: list units in *further* adjacent provinces # "Can support/move to" - Simplified: list units in *further* adjacent provinces
# A true "can support/move to" would require checking possible orders of those further units. # A true "can support/move to" would require checking possible orders of those further units.
# further_adj_provinces are short names from the graph # further_adj_provinces are short names from the graph
further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \ further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", [])
graph.get(adj_loc_short, {}).get('FLEET', [])
supporting_units_info = [] supporting_units_info = []
processed_further_provinces = set() processed_further_provinces = set()
@ -388,12 +384,12 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
continue continue
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F' unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET" unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET"
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3] loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper() loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
if loc_type_short == 'COAST' or loc_type_short == 'LAND': if loc_type_short == "COAST" or loc_type_short == "LAND":
loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST' loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST"
else: else:
loc_type_display = loc_type_short loc_type_display = loc_type_short
@ -401,7 +397,7 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">') current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
# Unit Information section # Unit Information section
current_unit_lines.append(' <UnitInformation>') current_unit_lines.append(" <UnitInformation>")
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full) sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})" header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
if sc_owner_at_loc == power_name: if sc_owner_at_loc == power_name:
@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
header_content += f" (SC controlled by {sc_owner_at_loc})" header_content += f" (SC controlled by {sc_owner_at_loc})"
current_unit_lines.append(f" {header_content}") current_unit_lines.append(f" {header_content}")
current_unit_lines.append(f" Units present: {unit_str_full}") current_unit_lines.append(f" Units present: {unit_str_full}")
current_unit_lines.append(' </UnitInformation>') current_unit_lines.append(" </UnitInformation>")
# Possible moves section # Possible moves section
current_unit_lines.append(' <PossibleMoves>') current_unit_lines.append(" <PossibleMoves>")
current_unit_lines.append(" Possible moves:") current_unit_lines.append(" Possible moves:")
for order_str in unit_specific_possible_orders: for order_str in unit_specific_possible_orders:
current_unit_lines.append(f" {order_str}") current_unit_lines.append(f" {order_str}")
current_unit_lines.append(' </PossibleMoves>') current_unit_lines.append(" </PossibleMoves>")
# Nearest enemy units section # Nearest enemy units section
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3) enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
current_unit_lines.append(' <NearestEnemyUnits>') current_unit_lines.append(" <NearestEnemyUnits>")
if enemy_units_info: if enemy_units_info:
current_unit_lines.append(" Nearest units (not ours):") current_unit_lines.append(" Nearest units (not ours):")
for enemy_unit_str, enemy_path_short in enemy_units_info: for enemy_unit_str, enemy_path_short in enemy_units_info:
current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}{(''.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]") current_unit_lines.append(
f" {enemy_unit_str}, path=[{unit_loc_full}{(''.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]"
)
else: else:
current_unit_lines.append(" Nearest units (not ours): None found") current_unit_lines.append(" Nearest units (not ours): None found")
current_unit_lines.append(' </NearestEnemyUnits>') current_unit_lines.append(" </NearestEnemyUnits>")
# Nearest supply centers (not controlled by us) section # Nearest supply centers (not controlled by us) section
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3) uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
current_unit_lines.append(' <NearestUncontrolledSupplyCenters>') current_unit_lines.append(" <NearestUncontrolledSupplyCenters>")
if uncontrolled_scs_info: if uncontrolled_scs_info:
current_unit_lines.append(" Nearest supply centers (not controlled by us):") current_unit_lines.append(" Nearest supply centers (not controlled by us):")
for sc_str, dist, sc_path_short in uncontrolled_scs_info: for sc_str, dist, sc_path_short in uncontrolled_scs_info:
current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}{(''.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]") current_unit_lines.append(
f" {sc_str}, dist={dist}, path=[{unit_loc_full}{(''.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]"
)
else: else:
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found") current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
current_unit_lines.append(' </NearestUncontrolledSupplyCenters>') current_unit_lines.append(" </NearestUncontrolledSupplyCenters>")
# Adjacent territories details section # Adjacent territories details section
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph) adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
current_unit_lines.append(' <AdjacentTerritories>') current_unit_lines.append(" <AdjacentTerritories>")
if adj_details_str: if adj_details_str:
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):") current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
# Assuming adj_details_str is already formatted with newlines and indentation for its content # Assuming adj_details_str is already formatted with newlines and indentation for its content
# We might need to indent adj_details_str if it's a single block of text # We might need to indent adj_details_str if it's a single block of text
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines # For now, let's add a standard indent to each line of adj_details_str if it contains newlines
if '\n' in adj_details_str: if "\n" in adj_details_str:
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')]) indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")])
current_unit_lines.append(indented_adj_details) current_unit_lines.append(indented_adj_details)
else: else:
current_unit_lines.append(f" {adj_details_str}") current_unit_lines.append(f" {adj_details_str}")
else: else:
current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else current_unit_lines.append(
current_unit_lines.append(' </AdjacentTerritories>') " Adjacent territories: None relevant or all are empty/uncontested by direct threats."
) # Added more descriptive else
current_unit_lines.append(" </AdjacentTerritories>")
current_unit_lines.append(' </UnitContext>') current_unit_lines.append(" </UnitContext>")
final_context_lines.extend(current_unit_lines) final_context_lines.extend(current_unit_lines)
final_context_lines.append("</PossibleOrdersContext>") final_context_lines.append("</PossibleOrdersContext>")
@ -469,7 +471,6 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
# Regex and tiny helpers # Regex and tiny helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
import re
from typing import Tuple, List, Dict, Optional, Any from typing import Tuple, List, Dict, Optional, Any
# ── order-syntax matchers ───────────────────────────────────────────────── # ── order-syntax matchers ─────────────────────────────────────────────────
@ -478,6 +479,7 @@ _HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$") _RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband _ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
def _is_hold_order(order: str) -> bool: # NEW def _is_hold_order(order: str) -> bool: # NEW
return bool(_HOLD_RE.match(order.strip())) return bool(_HOLD_RE.match(order.strip()))
@ -486,15 +488,19 @@ def _norm_power(name: str) -> str:
"""Trim & uppercase for reliable comparisons.""" """Trim & uppercase for reliable comparisons."""
return name.strip().upper() return name.strip().upper()
def _is_simple_move(order: str) -> bool: def _is_simple_move(order: str) -> bool:
return bool(_SIMPLE_MOVE_RE.match(order.strip())) return bool(_SIMPLE_MOVE_RE.match(order.strip()))
def _is_retreat_order(order: str) -> bool: def _is_retreat_order(order: str) -> bool:
return bool(_RETREAT_RE.match(order.strip())) return bool(_RETREAT_RE.match(order.strip()))
def _is_adjust_order(order: str) -> bool: def _is_adjust_order(order: str) -> bool:
return bool(_ADJUST_RE.match(order.strip())) return bool(_ADJUST_RE.match(order.strip()))
def _split_move(order: str) -> Tuple[str, str]: def _split_move(order: str) -> Tuple[str, str]:
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only).""" """Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
unit_part, dest = order.split(" - ") unit_part, dest = order.split(" - ")
@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]:
# Gather *all* friendly support orders for a given move # Gather *all* friendly support orders for a given move
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _all_support_examples( def _all_support_examples(
mover: str, mover: str,
dest: str, dest: str,
@ -526,6 +533,7 @@ def _all_support_examples(
return supports return supports
def _all_support_hold_examples( def _all_support_hold_examples(
holder: str, holder: str,
all_orders: Dict[str, List[str]], all_orders: Dict[str, List[str]],
@ -550,6 +558,7 @@ def _all_support_hold_examples(
# Province-type resolver (handles short codes, coasts, seas) # Province-type resolver (handles short codes, coasts, seas)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _province_type_display(game_map, prov_short: str) -> str: def _province_type_display(game_map, prov_short: str) -> str:
""" """
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code. Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
@ -593,6 +602,7 @@ def _dest_occupancy_desc(
# Adjacent-territory lines (used by movement-phase builder) # Adjacent-territory lines (used by movement-phase builder)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _adjacent_territory_lines( def _adjacent_territory_lines(
graph, graph,
game_map, game_map,
@ -633,13 +643,9 @@ def _adjacent_territory_lines(
pwr = unit_here.split(" ")[-1].strip("()") pwr = unit_here.split(" ")[-1].strip("()")
if pwr == our_power: if pwr == our_power:
friend_descr = unit_here.split(" (")[0] friend_descr = unit_here.split(" (")[0]
lines.append( lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}")
f"{indent2}Support hold: {mover_descr} S {friend_descr}"
)
else: else:
lines.append( lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}s moves and vice-versa")
f"{indent2}-> {unit_here} can support or contest {mover_descr}s moves and vice-versa"
)
return lines return lines
@ -648,6 +654,7 @@ def _adjacent_territory_lines(
# Movement-phase generator (UNCHANGED LOGIC) # Movement-phase generator (UNCHANGED LOGIC)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _generate_rich_order_context_movement( def _generate_rich_order_context_movement(
game: Any, game: Any,
power_name: str, power_name: str,
@ -672,18 +679,14 @@ def _generate_rich_order_context_movement(
if _norm_power(unit_power) != me: if _norm_power(unit_power) != me:
continue # Skip units that arent ours continue # Skip units that arent ours
mover_descr, _ = _split_move( mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}")
f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}"
)
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3] prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
prov_type_disp = _province_type_display(game_map, prov_short) prov_type_disp = _province_type_display(game_map, prov_short)
sc_tag = " (SC)" if prov_short in game_map.scs else "" sc_tag = " (SC)" if prov_short in game_map.scs else ""
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None" owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
owner_line = ( owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
)
ind = " " ind = " "
block: List[str] = [f"<Territory {prov_short}>"] block: List[str] = [f"<Territory {prov_short}>"]
@ -693,18 +696,16 @@ def _generate_rich_order_context_movement(
# ----- Adjacent territories ----- # ----- Adjacent territories -----
block.append("# Adjacent territories:") block.append("# Adjacent territories:")
block.extend( block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name))
_adjacent_territory_lines(
graph, game_map, board_state,
unit_loc_full, mover_descr, power_name
)
)
# ----- Nearest enemy units ----- # ----- Nearest enemy units -----
block.append("# Nearest units (not ours):") block.append("# Nearest units (not ours):")
enemies = get_nearest_enemy_units( enemies = get_nearest_enemy_units(
board_state, graph, game_map, board_state,
power_name, unit_loc_full, graph,
game_map,
power_name,
unit_loc_full,
"ARMY" if mover_descr.startswith("A") else "FLEET", "ARMY" if mover_descr.startswith("A") else "FLEET",
n=3, n=3,
) )
@ -715,8 +716,11 @@ def _generate_rich_order_context_movement(
# ----- Nearest uncontrolled SCs ----- # ----- Nearest uncontrolled SCs -----
block.append("# Nearest supply centers (not controlled by us):") block.append("# Nearest supply centers (not controlled by us):")
scs = get_nearest_uncontrolled_scs( scs = get_nearest_uncontrolled_scs(
game_map, board_state, graph, game_map,
power_name, unit_loc_full, board_state,
graph,
power_name,
unit_loc_full,
"ARMY" if mover_descr.startswith("A") else "FLEET", "ARMY" if mover_descr.startswith("A") else "FLEET",
n=3, n=3,
) )
@ -737,9 +741,7 @@ def _generate_rich_order_context_movement(
# ---- Moves (same behaviour as before) ---- # ---- Moves (same behaviour as before) ----
for mv in simple_moves: for mv in simple_moves:
mover, dest = _split_move(mv) mover, dest = _split_move(mv)
occ = _dest_occupancy_desc( occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name)
dest.split("/")[0][:3], game_map, board_state, power_name
)
block.append(f"{ind}{mv} {occ}") block.append(f"{ind}{mv} {occ}")
for s in _all_support_examples(mover, dest, possible_orders_for_power): for s in _all_support_examples(mover, dest, possible_orders_for_power):
@ -753,7 +755,6 @@ def _generate_rich_order_context_movement(
for s in _all_support_hold_examples(holder, possible_orders_for_power): for s in _all_support_hold_examples(holder, possible_orders_for_power):
block.append(f"{ind * 2}Available Support: {s}") block.append(f"{ind * 2}Available Support: {s}")
block.append(f"</Territory {prov_short}>") block.append(f"</Territory {prov_short}>")
blocks.append("\n".join(block)) blocks.append("\n".join(block))
@ -764,6 +765,7 @@ def _generate_rich_order_context_movement(
# Retreat-phase builder echo orders verbatim, no tags # Retreat-phase builder echo orders verbatim, no tags
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _generate_rich_order_context_retreat( def _generate_rich_order_context_retreat(
game: Any, game: Any,
power_name: str, power_name: str,
@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat(
return "\n".join(lines) if lines else "(No dislodged units)" return "\n".join(lines) if lines else "(No dislodged units)"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Adjustment-phase builder summary line + orders, no WAIVEs, no tags # Adjustment-phase builder summary line + orders, no WAIVEs, no tags
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _generate_rich_order_context_adjustment( def _generate_rich_order_context_adjustment(
game: Any, game: Any,
power_name: str, power_name: str,
@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment(
return "\n".join(lines) if len(lines) > 1 else summary return "\n".join(lines) if len(lines) > 1 else summary
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Phase-dispatch wrapper (public entry point) # Phase-dispatch wrapper (public entry point)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def generate_rich_order_context( def generate_rich_order_context(
game: Any, game: Any,
power_name: str, power_name: str,
@ -847,21 +847,13 @@ def generate_rich_order_context(
phase_type = game.current_short_phase[-1] phase_type = game.current_short_phase[-1]
if phase_type == "M": # Movement if phase_type == "M": # Movement
return _generate_rich_order_context_movement( return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
game, power_name, possible_orders_for_power
)
if phase_type == "R": # Retreat if phase_type == "R": # Retreat
return _generate_rich_order_context_retreat( return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power)
game, power_name, possible_orders_for_power
)
if phase_type == "A": # Adjustment (build / disband) if phase_type == "A": # Adjustment (build / disband)
return _generate_rich_order_context_adjustment( return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power)
game, power_name, possible_orders_for_power
)
# Fallback treat unknown formats as movement # Fallback treat unknown formats as movement
return _generate_rich_order_context_movement( return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
game, power_name, possible_orders_for_power
)

View file

@ -11,7 +11,6 @@ from .possible_order_context import (
generate_rich_order_context, generate_rich_order_context,
generate_rich_order_context_xml, generate_rich_order_context_xml,
) )
import os
from .game_history import GameHistory # Assuming GameHistory is correctly importable from .game_history import GameHistory # Assuming GameHistory is correctly importable
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only # placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
@ -230,4 +229,3 @@ def construct_order_generation_prompt(
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...") logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
return final_prompt return final_prompt

View file

@ -35,7 +35,7 @@ def atomic_write_json(data: dict, filepath: str):
# Write to a temporary file in the same directory # Write to a temporary file in the same directory
temp_filepath = f"{filepath}.tmp.{os.getpid()}" temp_filepath = f"{filepath}.tmp.{os.getpid()}"
with open(temp_filepath, 'w', encoding='utf-8') as f: with open(temp_filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
# Atomically rename the temporary file to the final destination # Atomically rename the temporary file to the final destination
@ -110,11 +110,7 @@ def get_special_models() -> Dict[str, str]:
# Use a different Gemini model for formatting # Use a different Gemini model for formatting
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash" export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
""" """
return { return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
"formatter": config.AI_DIPLOMACY_FORMATTER_MODEL
}
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]: def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
@ -175,15 +171,19 @@ async def get_valid_orders(
# ── 2. Type check ────────────────────────────────────────── # ── 2. Type check ──────────────────────────────────────────
if not isinstance(raw_orders, list): if not isinstance(raw_orders, list):
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
power_name, raw_orders)
model_error_stats[client.model_name]["order_decoding_errors"] += 1 model_error_stats[client.model_name]["order_decoding_errors"] += 1
return {"valid": client.fallback_orders(possible_orders), "invalid": []} return {"valid": client.fallback_orders(possible_orders), "invalid": []}
# ── 3. Round-trip validation with engine ─────────────────── # ── 3. Round-trip validation with engine ───────────────────
CODE_TO_ENGINE = { CODE_TO_ENGINE = {
"AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE", "AUT": "AUSTRIA",
"GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY", "ENG": "ENGLAND",
"FRA": "FRANCE",
"GER": "GERMANY",
"ITA": "ITALY",
"RUS": "RUSSIA",
"TUR": "TURKEY",
} }
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name] engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
@ -226,8 +226,6 @@ async def get_valid_orders(
return {"valid": valid, "invalid": invalid} return {"valid": valid, "invalid": invalid}
def normalize_and_compare_orders( def normalize_and_compare_orders(
issued_orders: Dict[str, List[str]], issued_orders: Dict[str, List[str]],
accepted_orders_dict: Dict[str, List[str]], accepted_orders_dict: Dict[str, List[str]],
@ -348,8 +346,6 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
return "" return ""
# == New LLM Response Logging Function == # == New LLM Response Logging Function ==
def log_llm_response( def log_llm_response(
log_file_path: str, log_file_path: str,
@ -374,14 +370,18 @@ def log_llm_response(
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile: with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
# Added "raw_input" to fieldnames # Added "raw_input" to fieldnames
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"] fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, writer = csv.DictWriter(
csvfile,
fieldnames=fieldnames,
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
escapechar='\\') # Use backslash for escaping escapechar="\\",
) # Use backslash for escaping
if not file_exists: if not file_exists:
writer.writeheader() # Write header only if file is new writer.writeheader() # Write header only if file is new
writer.writerow({ writer.writerow(
{
"model": model_name, "model": model_name,
"power": power_name if power_name else "game", # Use 'game' if no specific power "power": power_name if power_name else "game", # Use 'game' if no specific power
"phase": phase, "phase": phase,
@ -389,12 +389,12 @@ def log_llm_response(
"raw_input": raw_input_prompt, # Added raw_input to the row "raw_input": raw_input_prompt, # Added raw_input to the row
"raw_response": raw_response, "raw_response": raw_response,
"success": success, "success": success,
}) }
)
except Exception as e: except Exception as e:
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True) logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
async def run_llm_and_log( async def run_llm_and_log(
client: "BaseModelClient", client: "BaseModelClient",
prompt: str, prompt: str,
@ -428,8 +428,7 @@ async def run_llm_and_log(
except Exception as e: except Exception as e:
if attempt == attempts - 1: if attempt == attempts - 1:
logger.error( logger.error(
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} " f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
f"in phase {phase}: {e}",
exc_info=True, exc_info=True,
) )
# Back-off before the next attempt (unless this was the last) # Back-off before the next attempt (unless this was the last)
@ -439,7 +438,6 @@ async def run_llm_and_log(
return raw_response return raw_response
# This generates a few lines of random alphanum chars to inject into the # This generates a few lines of random alphanum chars to inject into the
# system prompt. This lets us use temp=0 while still getting variation # system prompt. This lets us use temp=0 while still getting variation
# between trials. # between trials.
@ -447,15 +445,8 @@ async def run_llm_and_log(
# ensure valid json outputs. # ensure valid json outputs.
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80): def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
# Generate x lines of y random alphanumeric characters # Generate x lines of y random alphanumeric characters
seed_lines = [ seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)]
''.join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) random_seed_block = "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(seed_lines) + "\n</RANDOM SEED>"
for _ in range(n_lines)
]
random_seed_block = (
"<RANDOM SEED PLEASE IGNORE>\n" +
"\n".join(seed_lines) +
"\n</RANDOM SEED>"
)
return random_seed_block return random_seed_block
@ -473,6 +464,7 @@ def get_prompt_path(prompt_name: str) -> str:
else: else:
return prompt_name return prompt_name
def normalize_recipient_name(recipient: str) -> str: def normalize_recipient_name(recipient: str) -> str:
"""Normalize recipient names to handle LLM typos and abbreviations.""" """Normalize recipient names to handle LLM typos and abbreviations."""
if not recipient: if not recipient:
@ -482,24 +474,24 @@ def normalize_recipient_name(recipient: str) -> str:
# Handle common LLM typos and abbreviations found in data # Handle common LLM typos and abbreviations found in data
name_mapping = { name_mapping = {
'EGMANY': 'GERMANY', "EGMANY": "GERMANY",
'GERMAN': 'GERMANY', "GERMAN": "GERMANY",
'UK': 'ENGLAND', "UK": "ENGLAND",
'BRIT': 'ENGLAND', "BRIT": "ENGLAND",
'ENGLAND': 'ENGLAND', # Keep as-is "ENGLAND": "ENGLAND", # Keep as-is
'FRANCE': 'FRANCE', # Keep as-is "FRANCE": "FRANCE", # Keep as-is
'GERMANY': 'GERMANY', # Keep as-is "GERMANY": "GERMANY", # Keep as-is
'ITALY': 'ITALY', # Keep as-is "ITALY": "ITALY", # Keep as-is
'AUSTRIA': 'AUSTRIA', # Keep as-is "AUSTRIA": "AUSTRIA", # Keep as-is
'RUSSIA': 'RUSSIA', # Keep as-is "RUSSIA": "RUSSIA", # Keep as-is
'TURKEY': 'TURKEY', # Keep as-is "TURKEY": "TURKEY", # Keep as-is
'Germany': 'GERMANY', "Germany": "GERMANY",
'England': 'ENGLAND', "England": "ENGLAND",
'France': 'FRANCE', "France": "FRANCE",
'Italy': 'ITALY', "Italy": "ITALY",
'Russia': 'RUSSIA', "Russia": "RUSSIA",
'Austria': 'AUSTRIA', "Austria": "AUSTRIA",
'Turkey': 'TURKEY', "Turkey": "TURKEY",
} }
normalized = name_mapping.get(recipient, recipient) normalized = name_mapping.get(recipient, recipient)