mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
Ruff format and check
This commit is contained in:
parent
aadedd6512
commit
ad817695ef
13 changed files with 738 additions and 973 deletions
|
|
@ -10,7 +10,7 @@ import ast
|
|||
from ..config import config
|
||||
|
||||
# Assuming BaseModelClient is importable from clients.py in the same directory
|
||||
from .clients import BaseModelClient, load_model_client
|
||||
from .clients import BaseModelClient
|
||||
|
||||
# Import load_prompt and the new logging wrapper from utils
|
||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
|
||||
|
|
@ -336,7 +336,7 @@ class DiplomacyAgent:
|
|||
)
|
||||
# If even the last resort doesn't give a dict, return empty dict
|
||||
return {}
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
|
||||
return {}
|
||||
|
||||
|
|
@ -965,7 +965,7 @@ class DiplomacyAgent:
|
|||
other_powers = [p for p in game.powers if p != power_name]
|
||||
|
||||
# Create a readable board state string from the board_state dict
|
||||
board_state_str = f"Board State:\n"
|
||||
board_state_str = "Board State:\n"
|
||||
for p_name, power_data in board_state.get("powers", {}).items():
|
||||
# Get units and centers from the board state
|
||||
units = power_data.get("units", [])
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
import re
|
||||
import logging
|
||||
import ast # For literal_eval in JSON fallback parsing
|
||||
import aiohttp # For direct HTTP requests to Responses API
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple, NamedTuple
|
||||
from typing import List, Dict, Optional, Tuple, NamedTuple
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Use Async versions of clients
|
||||
|
|
@ -21,9 +20,9 @@ from together import AsyncTogether
|
|||
from together.error import APIError as TogetherAPIError # For specific error handling
|
||||
|
||||
from ..config import config
|
||||
from diplomacy.engine.message import GLOBAL
|
||||
from .game_history import GameHistory
|
||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
|
||||
|
||||
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
||||
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
|
||||
# Moved formatter imports to avoid circular import - imported locally where needed
|
||||
|
|
@ -36,6 +35,7 @@ logger.setLevel(logging.DEBUG) # Keep debug for now during async changes
|
|||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
##############################################################################
|
||||
# 1) Base Interface
|
||||
##############################################################################
|
||||
|
|
@ -118,24 +118,19 @@ class BaseModelClient:
|
|||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=phase,
|
||||
response_type='order', # Context for run_llm_and_log's own error logging
|
||||
temperature=0
|
||||
)
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}"
|
||||
response_type="order", # Context for run_llm_and_log's own error logging
|
||||
temperature=0,
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}")
|
||||
|
||||
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
|
||||
if config.USE_UNFORMATTED_PROMPTS:
|
||||
# Local import to avoid circular dependency
|
||||
from .formatter import format_with_gemini_flash, FORMAT_ORDERS
|
||||
|
||||
# Format the natural language response into structured format
|
||||
formatted_response = await format_with_gemini_flash(
|
||||
raw_response,
|
||||
FORMAT_ORDERS,
|
||||
power_name=power_name,
|
||||
phase=phase,
|
||||
log_file_path=log_file_path
|
||||
raw_response, FORMAT_ORDERS, power_name=power_name, phase=phase, log_file_path=log_file_path
|
||||
)
|
||||
else:
|
||||
# Use the raw response directly (already formatted)
|
||||
|
|
@ -145,9 +140,7 @@ class BaseModelClient:
|
|||
move_list = self._extract_moves(formatted_response, power_name)
|
||||
|
||||
if not move_list:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback."
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.")
|
||||
if model_error_stats is not None and self.model_name in model_error_stats:
|
||||
model_error_stats[self.model_name].setdefault("order_decoding_errors", 0)
|
||||
model_error_stats[self.model_name]["order_decoding_errors"] += 1
|
||||
|
|
@ -195,7 +188,7 @@ class BaseModelClient:
|
|||
response_type="order_generation", # Specific type for CSV logging
|
||||
raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
|
||||
raw_response=raw_response,
|
||||
success=success_status
|
||||
success=success_status,
|
||||
# token_usage and cost can be added later if available and if log_llm_response supports them
|
||||
)
|
||||
return parsed_orders_for_return
|
||||
|
|
@ -216,9 +209,7 @@ class BaseModelClient:
|
|||
|
||||
if not matches:
|
||||
# Some LLMs might not put the colon or might have triple backtick fences.
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns."
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.")
|
||||
|
||||
# 1b) Check for inline JSON after "PARSABLE OUTPUT"
|
||||
pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
|
||||
|
|
@ -226,47 +217,35 @@ class BaseModelClient:
|
|||
|
||||
if not matches:
|
||||
# 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks)
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern."
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern.")
|
||||
pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})"
|
||||
matches = re.search(pattern_asterisk, raw_response, re.DOTALL)
|
||||
|
||||
if not matches:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences."
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences.")
|
||||
|
||||
# 2) If still no match, check for triple-backtick code fences containing JSON
|
||||
if not matches:
|
||||
code_fence_pattern = r"```json\n(.*?)\n```"
|
||||
matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
|
||||
if matches:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Found triple-backtick JSON block for {power_name}."
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.")
|
||||
|
||||
# 2b) Also try plain ``` code fences without json marker
|
||||
if not matches:
|
||||
code_fence_plain = r"```\n(.*?)\n```"
|
||||
matches = re.search(code_fence_plain, raw_response, re.DOTALL)
|
||||
if matches:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Found plain triple-backtick block for {power_name}."
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Found plain triple-backtick block for {power_name}.")
|
||||
|
||||
# 2c) Try to find bare JSON object anywhere in the response
|
||||
if not matches:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON."
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON.")
|
||||
# Look for a JSON object that contains "orders" key
|
||||
bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})'
|
||||
matches = re.search(bare_json_pattern, raw_response, re.DOTALL)
|
||||
if matches:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}."
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}.")
|
||||
|
||||
# 3) Attempt to parse JSON if we found anything
|
||||
json_text = None
|
||||
|
|
@ -283,9 +262,7 @@ class BaseModelClient:
|
|||
json_text = json_text.strip()
|
||||
|
||||
if not json_text:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] No JSON text found in LLM response for {power_name}."
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.")
|
||||
return None
|
||||
|
||||
# 3a) Try JSON loading
|
||||
|
|
@ -293,14 +270,12 @@ class BaseModelClient:
|
|||
data = json.loads(json_text)
|
||||
return data.get("orders", None)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues."
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues.")
|
||||
|
||||
# Try to fix common JSON issues
|
||||
try:
|
||||
# Remove trailing commas
|
||||
fixed_json = re.sub(r',\s*([\}\]])', r'\1', json_text)
|
||||
fixed_json = re.sub(r",\s*([\}\]])", r"\1", json_text)
|
||||
# Fix single quotes to double quotes
|
||||
fixed_json = fixed_json.replace("'", '"')
|
||||
# Try parsing again
|
||||
|
|
@ -308,14 +283,12 @@ class BaseModelClient:
|
|||
logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}")
|
||||
return data.get("orders", None)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments."
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments.")
|
||||
|
||||
# Try to remove inline comments (// style)
|
||||
try:
|
||||
# Remove // comments from each line
|
||||
lines = json_text.split('\n')
|
||||
lines = json_text.split("\n")
|
||||
cleaned_lines = []
|
||||
for line in lines:
|
||||
# Find // that's not inside quotes
|
||||
|
|
@ -326,12 +299,12 @@ class BaseModelClient:
|
|||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if char == '\\':
|
||||
if char == "\\":
|
||||
escape_next = True
|
||||
continue
|
||||
if char == '"' and not escape_next:
|
||||
in_quotes = not in_quotes
|
||||
if not in_quotes and line[i:i+2] == '//':
|
||||
if not in_quotes and line[i : i + 2] == "//":
|
||||
comment_pos = i
|
||||
break
|
||||
|
||||
|
|
@ -342,17 +315,15 @@ class BaseModelClient:
|
|||
cleaned_line = line
|
||||
cleaned_lines.append(cleaned_line)
|
||||
|
||||
comment_free_json = '\n'.join(cleaned_lines)
|
||||
comment_free_json = "\n".join(cleaned_lines)
|
||||
# Also remove trailing commas after comment removal
|
||||
comment_free_json = re.sub(r',\s*([\}\]])', r'\1', comment_free_json)
|
||||
comment_free_json = re.sub(r",\s*([\}\]])", r"\1", comment_free_json)
|
||||
|
||||
data = json.loads(comment_free_json)
|
||||
logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}")
|
||||
return data.get("orders", None)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback."
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback.")
|
||||
|
||||
# 3b) Attempt bracket fallback: we look for the substring after "orders"
|
||||
# E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors.
|
||||
|
|
@ -366,16 +337,12 @@ class BaseModelClient:
|
|||
if isinstance(moves, list):
|
||||
return moves
|
||||
except Exception as e2:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}"
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}")
|
||||
|
||||
# If all attempts failed
|
||||
return None
|
||||
|
||||
def _validate_orders(
|
||||
self, moves: List[str], possible_orders: Dict[str, List[str]]
|
||||
) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
|
||||
def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
|
||||
"""
|
||||
Filter out invalid moves, fill missing with HOLD, else fallback.
|
||||
Returns a tuple: (validated_moves, invalid_moves_found)
|
||||
|
|
@ -405,15 +372,15 @@ class BaseModelClient:
|
|||
for loc, orders_list in possible_orders.items():
|
||||
if loc not in used_locs and orders_list:
|
||||
hold_candidates = [o for o in orders_list if o.endswith("H")]
|
||||
validated.append(
|
||||
hold_candidates[0] if hold_candidates else orders_list[0]
|
||||
)
|
||||
validated.append(hold_candidates[0] if hold_candidates else orders_list[0])
|
||||
|
||||
if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
|
||||
logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.")
|
||||
return self.fallback_orders(possible_orders), []
|
||||
elif not validated and invalid_moves_found: # All LLM moves were invalid
|
||||
logger.warning(f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}")
|
||||
logger.warning(
|
||||
f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}"
|
||||
)
|
||||
# We return empty list for validated, but the invalid_moves_found list is populated
|
||||
return self.fallback_orders(possible_orders), invalid_moves_found
|
||||
|
||||
|
|
@ -444,7 +411,6 @@ class BaseModelClient:
|
|||
agent_relationships: Optional[Dict[str, str]] = None,
|
||||
agent_private_diary_str: Optional[str] = None, # Added
|
||||
) -> str:
|
||||
|
||||
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
||||
|
||||
context = self.build_context_prompt(
|
||||
|
|
@ -508,7 +474,15 @@ class BaseModelClient:
|
|||
unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
|
||||
|
||||
final_prompt = context + unanswered_messages + "\n\n" + instructions
|
||||
final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey')
|
||||
final_prompt = (
|
||||
final_prompt.replace("AUSTRIA", "Austria")
|
||||
.replace("ENGLAND", "England")
|
||||
.replace("FRANCE", "France")
|
||||
.replace("GERMANY", "Germany")
|
||||
.replace("ITALY", "Italy")
|
||||
.replace("RUSSIA", "Russia")
|
||||
.replace("TURKEY", "Turkey")
|
||||
)
|
||||
return final_prompt
|
||||
|
||||
async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
|
||||
|
|
@ -524,7 +498,6 @@ class BaseModelClient:
|
|||
agent_relationships: Optional[Dict[str, str]] = None,
|
||||
agent_private_diary_str: Optional[str] = None, # Added
|
||||
) -> str:
|
||||
|
||||
prompt = self.build_planning_prompt(
|
||||
game,
|
||||
board_state,
|
||||
|
|
@ -545,7 +518,7 @@ class BaseModelClient:
|
|||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=game_phase, # Use game_phase for logging
|
||||
response_type='plan_reply', # Changed from 'plan' to avoid confusion
|
||||
response_type="plan_reply", # Changed from 'plan' to avoid confusion
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}")
|
||||
return raw_response
|
||||
|
|
@ -592,7 +565,7 @@ class BaseModelClient:
|
|||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=game_phase,
|
||||
response_type='negotiation', # For run_llm_and_log's internal context
|
||||
response_type="negotiation", # For run_llm_and_log's internal context
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")
|
||||
|
||||
|
|
@ -600,13 +573,10 @@ class BaseModelClient:
|
|||
if config.USE_UNFORMATTED_PROMPTS:
|
||||
# Local import to avoid circular dependency
|
||||
from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION
|
||||
|
||||
# Format the natural language response into structured JSON
|
||||
formatted_response = await format_with_gemini_flash(
|
||||
raw_response,
|
||||
FORMAT_CONVERSATION,
|
||||
power_name=power_name,
|
||||
phase=game_phase,
|
||||
log_file_path=log_file_path
|
||||
raw_response, FORMAT_CONVERSATION, power_name=power_name, phase=game_phase, log_file_path=log_file_path
|
||||
)
|
||||
else:
|
||||
# Use the raw response directly (already formatted)
|
||||
|
|
@ -632,10 +602,10 @@ class BaseModelClient:
|
|||
# Original parsing logic as fallback
|
||||
if not parsed_messages:
|
||||
# Attempt to find blocks enclosed in {{...}}
|
||||
double_brace_blocks = re.findall(r'\{\{(.*?)\}\}', raw_response, re.DOTALL)
|
||||
double_brace_blocks = re.findall(r"\{\{(.*?)\}\}", raw_response, re.DOTALL)
|
||||
if double_brace_blocks:
|
||||
# If {{...}} blocks are found, assume each is a self-contained JSON object
|
||||
json_blocks.extend(['{' + block.strip() + '}' for block in double_brace_blocks])
|
||||
json_blocks.extend(["{" + block.strip() + "}" for block in double_brace_blocks])
|
||||
else:
|
||||
# If no {{...}} blocks, look for ```json ... ``` markdown blocks
|
||||
code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL)
|
||||
|
|
@ -650,10 +620,10 @@ class BaseModelClient:
|
|||
json_blocks = [json.dumps(data)]
|
||||
except json.JSONDecodeError:
|
||||
# If parsing the whole block fails, fall back to regex for individual objects
|
||||
json_blocks = re.findall(r'\{.*?\}', potential_json_array_or_objects, re.DOTALL)
|
||||
json_blocks = re.findall(r"\{.*?\}", potential_json_array_or_objects, re.DOTALL)
|
||||
else:
|
||||
# If no markdown block, fall back to regex for any JSON object in the response
|
||||
json_blocks = re.findall(r'\{.*?\}', raw_response, re.DOTALL)
|
||||
json_blocks = re.findall(r"\{.*?\}", raw_response, re.DOTALL)
|
||||
|
||||
# Process json_blocks if we have them from fallback parsing
|
||||
if not parsed_messages and json_blocks:
|
||||
|
|
@ -661,7 +631,7 @@ class BaseModelClient:
|
|||
try:
|
||||
cleaned_block = block.strip()
|
||||
# Attempt to fix common JSON issues like trailing commas before parsing
|
||||
cleaned_block = re.sub(r',\s*([\}\]])', r'\1', cleaned_block)
|
||||
cleaned_block = re.sub(r",\s*([\}\]])", r"\1", cleaned_block)
|
||||
parsed_message = json.loads(cleaned_block)
|
||||
parsed_messages.append(parsed_message)
|
||||
except json.JSONDecodeError as e:
|
||||
|
|
@ -710,7 +680,7 @@ class BaseModelClient:
|
|||
response_type="negotiation_message",
|
||||
raw_input_prompt=raw_input_prompt,
|
||||
raw_response=raw_response,
|
||||
success=success_status
|
||||
success=success_status,
|
||||
)
|
||||
return messages_to_return
|
||||
|
||||
|
|
@ -770,7 +740,7 @@ class BaseModelClient:
|
|||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=game.current_short_phase,
|
||||
response_type='plan_generation', # More specific type for run_llm_and_log context
|
||||
response_type="plan_generation", # More specific type for run_llm_and_log context
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}")
|
||||
# No parsing needed for the plan, return the raw string
|
||||
|
|
@ -790,7 +760,7 @@ class BaseModelClient:
|
|||
response_type="plan_generation", # Specific type for CSV logging
|
||||
raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
|
||||
raw_response=raw_plan_response,
|
||||
success=success_status
|
||||
success=success_status,
|
||||
# token_usage and cost can be added later
|
||||
)
|
||||
return plan_to_return
|
||||
|
|
@ -813,11 +783,7 @@ class OpenAIClient(BaseModelClient):
|
|||
):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
|
||||
self.base_url = (
|
||||
base_url
|
||||
or os.environ.get("OPENAI_BASE_URL")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
|
||||
|
||||
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||
if not self.api_key:
|
||||
|
|
@ -832,11 +798,7 @@ class OpenAIClient(BaseModelClient):
|
|||
inject_random_seed: bool = True,
|
||||
) -> str:
|
||||
try:
|
||||
system_prompt_content = (
|
||||
f"{generate_random_seed()}\n\n{self.system_prompt}"
|
||||
if inject_random_seed
|
||||
else self.system_prompt
|
||||
)
|
||||
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
|
||||
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
|
||||
|
||||
response = await self.client.chat.completions.create(
|
||||
|
|
@ -888,20 +850,14 @@ class ClaudeClient(BaseModelClient):
|
|||
temperature=temperature,
|
||||
)
|
||||
if not response.content:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Empty content in Claude generate_response. Returning empty."
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.")
|
||||
return ""
|
||||
return response.content[0].text.strip() if response.content else ""
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(
|
||||
f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
|
||||
)
|
||||
logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
||||
)
|
||||
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
|
|
@ -929,19 +885,14 @@ class GeminiClient(BaseModelClient):
|
|||
full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
|
||||
|
||||
try:
|
||||
generation_config = genai.types.GenerationConfig(
|
||||
temperature=temperature,
|
||||
max_output_tokens=self.max_tokens
|
||||
)
|
||||
generation_config = genai.types.GenerationConfig(temperature=temperature, max_output_tokens=self.max_tokens)
|
||||
response = await self.client.generate_content_async(
|
||||
contents=full_prompt,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
|
||||
if not response or not response.text:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Empty Gemini generate_response. Returning empty."
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.")
|
||||
return ""
|
||||
return response.text.strip()
|
||||
except Exception as e:
|
||||
|
|
@ -957,10 +908,7 @@ class DeepSeekClient(BaseModelClient):
|
|||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
|
||||
self.client = AsyncDeepSeekOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url="https://api.deepseek.com/"
|
||||
)
|
||||
self.client = AsyncDeepSeekOpenAI(api_key=self.api_key, base_url="https://api.deepseek.com/")
|
||||
|
||||
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
||||
try:
|
||||
|
|
@ -986,9 +934,7 @@ class DeepSeekClient(BaseModelClient):
|
|||
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
|
||||
|
||||
if not response or not response.choices:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] No valid response in generate_response."
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] No valid response in generate_response.")
|
||||
return ""
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
|
|
@ -998,9 +944,7 @@ class DeepSeekClient(BaseModelClient):
|
|||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
||||
)
|
||||
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
|
|
@ -1037,19 +981,14 @@ class OpenAIResponsesClient(BaseModelClient):
|
|||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
# Make the API call using aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.base_url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"[{self.model_name}] API error (status {response.status}): {error_text}"
|
||||
)
|
||||
logger.error(f"[{self.model_name}] API error (status {response.status}): {error_text}")
|
||||
return ""
|
||||
|
||||
response_data = await response.json()
|
||||
|
|
@ -1059,24 +998,18 @@ class OpenAIResponsesClient(BaseModelClient):
|
|||
try:
|
||||
outputs = response_data.get("output", [])
|
||||
if len(outputs) < 2:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] Unexpected output structure. Full response: {response_data}")
|
||||
return ""
|
||||
|
||||
# The message is typically in the second output item
|
||||
message_output = outputs[1]
|
||||
if message_output.get("type") != "message":
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}"
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}")
|
||||
return ""
|
||||
|
||||
content_list = message_output.get("content", [])
|
||||
if not content_list:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Empty content list in message output"
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] Empty content list in message output")
|
||||
return ""
|
||||
|
||||
# Look for the content item with type 'output_text'
|
||||
|
|
@ -1087,28 +1020,20 @@ class OpenAIResponsesClient(BaseModelClient):
|
|||
break
|
||||
|
||||
if not text_content:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] No output_text found in content. Full content: {content_list}"
|
||||
)
|
||||
logger.warning(f"[{self.model_name}] No output_text found in content. Full content: {content_list}")
|
||||
return ""
|
||||
|
||||
return text_content.strip()
|
||||
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}"
|
||||
)
|
||||
logger.error(f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}")
|
||||
return ""
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] HTTP client error in generate_response: {e}"
|
||||
)
|
||||
logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
||||
)
|
||||
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
|
|
@ -1129,10 +1054,7 @@ class OpenRouterClient(BaseModelClient):
|
|||
if not self.api_key:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=self.api_key
|
||||
)
|
||||
self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self.api_key)
|
||||
|
||||
logger.debug(f"[{self.model_name}] Initialized OpenRouter client")
|
||||
|
||||
|
|
@ -1150,10 +1072,7 @@ class OpenRouterClient(BaseModelClient):
|
|||
# Prepare standard OpenAI-compatible request
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt_content},
|
||||
{"role": "user", "content": prompt_with_cta}
|
||||
],
|
||||
messages=[{"role": "system", "content": system_prompt_content}, {"role": "user", "content": prompt_with_cta}],
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
|
@ -1245,7 +1164,7 @@ class TogetherAIClient(BaseModelClient):
|
|||
##############################################################################
|
||||
# RequestsOpenAIClient – sync requests, wrapped async (original + api_key)
|
||||
##############################################################################
|
||||
import requests, asyncio
|
||||
|
||||
|
||||
class RequestsOpenAIClient(BaseModelClient):
|
||||
"""
|
||||
|
|
@ -1266,11 +1185,7 @@ class RequestsOpenAIClient(BaseModelClient):
|
|||
if not self.api_key:
|
||||
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
|
||||
|
||||
self.base_url = (
|
||||
base_url
|
||||
or os.environ.get("OPENAI_BASE_URL")
|
||||
or "https://api.openai.com/v1"
|
||||
).rstrip("/")
|
||||
self.base_url = (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1").rstrip("/")
|
||||
|
||||
self.endpoint = f"{self.base_url}/chat/completions"
|
||||
|
||||
|
|
@ -1291,11 +1206,7 @@ class RequestsOpenAIClient(BaseModelClient):
|
|||
temperature: float = 0.0,
|
||||
inject_random_seed: bool = True,
|
||||
) -> str:
|
||||
system_prompt_content = (
|
||||
f"{generate_random_seed()}\n\n{self.system_prompt}"
|
||||
if inject_random_seed
|
||||
else self.system_prompt
|
||||
)
|
||||
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
|
|
@ -1322,7 +1233,6 @@ class RequestsOpenAIClient(BaseModelClient):
|
|||
return ""
|
||||
|
||||
|
||||
|
||||
##############################################################################
|
||||
# 3) Factory to Load Model Client
|
||||
##############################################################################
|
||||
|
|
@ -1332,6 +1242,7 @@ class ModelSpec(NamedTuple):
|
|||
base: Optional[str] # 'https://proxy.foo'
|
||||
key: Optional[str] # 'sk-…' (may be None)
|
||||
|
||||
|
||||
def _parse_model_spec(raw: str) -> ModelSpec:
|
||||
"""
|
||||
Splits once on '#' (API key) and once on '@' (base URL). A leading
|
||||
|
|
@ -1441,8 +1352,6 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
|
|||
return OpenAIClient(model_name, prompts_dir, base_url)
|
||||
|
||||
|
||||
|
||||
|
||||
##############################################################################
|
||||
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
|
||||
##############################################################################
|
||||
|
|
@ -1453,11 +1362,6 @@ def get_visible_messages_for_power(conversation_messages, power_name):
|
|||
visible = []
|
||||
for msg in conversation_messages:
|
||||
# GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
|
||||
if (
|
||||
msg["recipient"] == "ALL"
|
||||
or msg["recipient"] == "GLOBAL"
|
||||
or msg["sender"] == power_name
|
||||
or msg["recipient"] == power_name
|
||||
):
|
||||
if msg["recipient"] == "ALL" or msg["recipient"] == "GLOBAL" or msg["sender"] == power_name or msg["recipient"] == power_name:
|
||||
visible.append(msg)
|
||||
return visible # already in chronological order if appended that way
|
||||
|
|
|
|||
|
|
@ -11,25 +11,28 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||
import os
|
||||
|
||||
try:
|
||||
if prompts_dir:
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
||||
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||
filepath = os.path.join(default_prompts_dir, filename)
|
||||
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def run_diary_consolidation(
|
||||
agent: 'DiplomacyAgent',
|
||||
agent: "DiplomacyAgent",
|
||||
game: "Game",
|
||||
log_file_path: str,
|
||||
entries_to_keep_unsummarized: int = 6,
|
||||
|
|
@ -39,70 +42,42 @@ async def run_diary_consolidation(
|
|||
Consolidate older diary entries while keeping recent ones.
|
||||
This is the logic moved from the DiplomacyAgent class.
|
||||
"""
|
||||
logger.info(
|
||||
f"[{agent.power_name}] CONSOLIDATION START — "
|
||||
f"{len(agent.full_private_diary)} total full entries"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
|
||||
|
||||
full_entries = [
|
||||
e for e in agent.full_private_diary
|
||||
if not e.startswith("[CONSOLIDATED HISTORY]")
|
||||
]
|
||||
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
|
||||
|
||||
if len(full_entries) <= entries_to_keep_unsummarized:
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
logger.info(
|
||||
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
|
||||
"skipping consolidation"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
|
||||
return
|
||||
|
||||
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
||||
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
||||
if not match:
|
||||
logger.error(
|
||||
f"[{agent.power_name}] Could not parse year from boundary entry; "
|
||||
"aborting consolidation"
|
||||
)
|
||||
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
return
|
||||
|
||||
cutoff_year = int(match.group(1))
|
||||
logger.info(
|
||||
f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}")
|
||||
|
||||
def _entry_year(entry: str) -> int | None:
|
||||
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
||||
return int(m.group(1)) if m else None
|
||||
|
||||
entries_to_summarize = [
|
||||
e for e in full_entries
|
||||
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
|
||||
]
|
||||
entries_to_keep = [
|
||||
e for e in full_entries
|
||||
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
|
||||
]
|
||||
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
|
||||
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
|
||||
|
||||
logger.info(
|
||||
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
|
||||
f"keeping {len(entries_to_keep)} recent entries verbatim"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
|
||||
|
||||
if not entries_to_summarize:
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
logger.warning(
|
||||
f"[{agent.power_name}] No eligible entries to summarise; "
|
||||
"context diary left unchanged"
|
||||
)
|
||||
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
||||
return
|
||||
|
||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(
|
||||
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
||||
)
|
||||
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
||||
return
|
||||
|
||||
prompt = prompt_template.format(
|
||||
|
|
@ -132,23 +107,14 @@ async def run_diary_consolidation(
|
|||
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
||||
agent.private_diary = [new_summary_entry] + entries_to_keep
|
||||
success_flag = "TRUE"
|
||||
logger.info(
|
||||
f"[{agent.power_name}] Consolidation complete — "
|
||||
f"{len(agent.private_diary)} context entries now"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
|
||||
)
|
||||
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
|
||||
finally:
|
||||
log_llm_response(
|
||||
log_file_path=log_file_path,
|
||||
model_name=(
|
||||
consolidation_client.model_name
|
||||
if consolidation_client is not None
|
||||
else agent.client.model_name
|
||||
),
|
||||
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
|
||||
power_name=agent.power_name,
|
||||
phase=game.current_short_phase,
|
||||
response_type="diary_consolidation",
|
||||
|
|
|
|||
|
|
@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO
|
|||
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
# Import logging function and model configuration
|
||||
from .utils import log_llm_response, get_special_models
|
||||
|
||||
# Import client loading function
|
||||
from .clients import load_model_client
|
||||
|
||||
|
|
@ -26,11 +25,7 @@ FORMAT_ORDER_DIARY = "order_diary"
|
|||
|
||||
|
||||
async def format_with_gemini_flash(
|
||||
raw_response: str,
|
||||
format_type: str,
|
||||
power_name: Optional[str] = None,
|
||||
phase: Optional[str] = None,
|
||||
log_file_path: Optional[str] = None
|
||||
raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Format a natural language response into required JSON structure using Gemini Flash.
|
||||
|
|
@ -52,7 +47,7 @@ async def format_with_gemini_flash(
|
|||
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
|
||||
FORMAT_ORDERS: "formatting/format_orders.txt",
|
||||
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
|
||||
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt"
|
||||
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt",
|
||||
}
|
||||
|
||||
if format_type not in format_prompts:
|
||||
|
|
@ -63,7 +58,7 @@ async def format_with_gemini_flash(
|
|||
if not prompt_file.exists():
|
||||
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
|
||||
|
||||
with open(prompt_file, 'r') as f:
|
||||
with open(prompt_file, "r") as f:
|
||||
format_prompt = f.read()
|
||||
|
||||
# Replace placeholder with actual response
|
||||
|
|
@ -87,7 +82,7 @@ async def format_with_gemini_flash(
|
|||
formatted_response = await formatter_client.generate_response(
|
||||
prompt=format_prompt,
|
||||
temperature=0, # Deterministic formatting
|
||||
inject_random_seed=False # No need for random seed in formatting
|
||||
inject_random_seed=False, # No need for random seed in formatting
|
||||
)
|
||||
|
||||
if not formatted_response:
|
||||
|
|
@ -116,7 +111,7 @@ async def format_with_gemini_flash(
|
|||
response_type=f"format_{format_type}",
|
||||
raw_input_prompt=format_prompt,
|
||||
raw_response=formatted_response,
|
||||
success="Success"
|
||||
success="Success",
|
||||
)
|
||||
|
||||
return formatted_response
|
||||
|
|
@ -134,7 +129,7 @@ async def format_with_gemini_flash(
|
|||
response_type=f"format_{format_type}",
|
||||
raw_input_prompt=format_prompt,
|
||||
raw_response=f"ERROR: {str(e)}",
|
||||
success=f"Failure: {type(e).__name__}"
|
||||
success=f"Failure: {type(e).__name__}",
|
||||
)
|
||||
|
||||
# Return empty structure based on format type
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import logging
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from pprint import pformat
|
||||
|
||||
logger = logging.getLogger("utils")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -24,12 +23,8 @@ class Phase:
|
|||
plans: Dict[str, str] = field(default_factory=dict)
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
|
||||
orders_by_power: Dict[str, List[str]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
results_by_power: Dict[str, List[List[str]]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list))
|
||||
results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list))
|
||||
# NEW: Store phase-end summaries provided by each power
|
||||
phase_summaries: Dict[str, str] = field(default_factory=dict)
|
||||
# NEW: Store experience/journal updates from each power for this phase
|
||||
|
|
@ -39,9 +34,7 @@ class Phase:
|
|||
self.plans[power_name] = plan
|
||||
|
||||
def add_message(self, sender: str, recipient: str, content: str):
|
||||
self.messages.append(
|
||||
Message(sender=sender, recipient=recipient, content=content)
|
||||
)
|
||||
self.messages.append(Message(sender=sender, recipient=recipient, content=content))
|
||||
|
||||
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
|
||||
self.orders_by_power[power].extend(orders)
|
||||
|
|
@ -110,14 +103,10 @@ class GameHistory:
|
|||
phase.plans[power_name] = plan
|
||||
logger.debug(f"Added plan for {power_name} in {phase_name}")
|
||||
|
||||
def add_message(
|
||||
self, phase_name: str, sender: str, recipient: str, message_content: str
|
||||
):
|
||||
def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str):
|
||||
phase = self._get_phase(phase_name)
|
||||
if phase:
|
||||
message = Message(
|
||||
sender=sender, recipient=recipient, content=message_content
|
||||
)
|
||||
message = Message(sender=sender, recipient=recipient, content=message_content)
|
||||
phase.messages.append(message)
|
||||
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
|
||||
|
||||
|
|
@ -153,10 +142,6 @@ class GameHistory:
|
|||
return {}
|
||||
return self.phases[-1].plans
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_order_history_for_prompt(
|
||||
self,
|
||||
game: "Game",
|
||||
|
|
@ -194,10 +179,7 @@ class GameHistory:
|
|||
return str(tag).strip().lower()
|
||||
|
||||
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
|
||||
eng2code = {
|
||||
"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
|
||||
"GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
|
||||
}
|
||||
eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"}
|
||||
norm = game.map.norm
|
||||
|
||||
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
|
||||
|
|
@ -264,13 +246,6 @@ class GameHistory:
|
|||
return "\n(No orders were issued in recent history)\n"
|
||||
return "\n".join(out_lines)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
||||
current_phase: Optional[Phase] = None
|
||||
for phase_obj in self.phases:
|
||||
|
|
@ -325,11 +300,7 @@ class GameHistory:
|
|||
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
|
||||
# Skip if sender is this power (don't need to respond to own messages)
|
||||
if msg.sender != power_name:
|
||||
messages_to_power.append({
|
||||
'sender': msg.sender,
|
||||
'content': msg.content,
|
||||
'phase': phase.name
|
||||
})
|
||||
messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name})
|
||||
|
||||
# Add debug logging
|
||||
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
|
||||
|
|
@ -361,10 +332,10 @@ class GameHistory:
|
|||
for msg in phase.messages:
|
||||
# Handle both Message objects and dict objects
|
||||
if isinstance(msg, Message):
|
||||
if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']:
|
||||
if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]:
|
||||
sender_messages.append(msg)
|
||||
else: # Assume dict
|
||||
if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']:
|
||||
if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]:
|
||||
sender_messages.append(msg)
|
||||
|
||||
# Check for responses in this and next phases
|
||||
|
|
@ -374,8 +345,8 @@ class GameHistory:
|
|||
recipient = msg.recipient
|
||||
msg_content = msg.content
|
||||
else:
|
||||
recipient = msg['recipient']
|
||||
msg_content = msg['content']
|
||||
recipient = msg["recipient"]
|
||||
msg_content = msg["content"]
|
||||
|
||||
# Look for responses in current phase and next phases
|
||||
found_response = False
|
||||
|
|
@ -387,12 +358,14 @@ class GameHistory:
|
|||
response_msgs = []
|
||||
for m in check_phase.messages:
|
||||
if isinstance(m, Message):
|
||||
if m.sender == recipient and (m.recipient == sender_name or
|
||||
(m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)):
|
||||
if m.sender == recipient and (
|
||||
m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content)
|
||||
):
|
||||
response_msgs.append(m)
|
||||
else: # Assume dict
|
||||
if m['sender'] == recipient and (m['recipient'] == sender_name or
|
||||
(m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))):
|
||||
if m["sender"] == recipient and (
|
||||
m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", ""))
|
||||
):
|
||||
response_msgs.append(m)
|
||||
|
||||
if response_msgs:
|
||||
|
|
@ -402,9 +375,6 @@ class GameHistory:
|
|||
if not found_response:
|
||||
if recipient not in ignored_by_power:
|
||||
ignored_by_power[recipient] = []
|
||||
ignored_by_power[recipient].append({
|
||||
'phase': phase.name,
|
||||
'content': msg_content
|
||||
})
|
||||
ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content})
|
||||
|
||||
return ignored_by_power
|
||||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
from typing import Dict, Tuple, Optional, Any
|
||||
from argparse import Namespace
|
||||
|
||||
from diplomacy import Game
|
||||
|
|
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# --- Serialization / Deserialization ---
|
||||
|
||||
|
||||
def serialize_agent(agent: DiplomacyAgent) -> dict:
|
||||
"""Converts an agent object to a JSON-serializable dictionary."""
|
||||
return {
|
||||
|
|
@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
|
|||
"private_diary": agent.private_diary,
|
||||
}
|
||||
|
||||
def deserialize_agent(
|
||||
agent_data: dict,
|
||||
prompts_dir: Optional[str] = None,
|
||||
*,
|
||||
override_model_id: Optional[str] = None
|
||||
) -> DiplomacyAgent:
|
||||
|
||||
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
|
||||
"""
|
||||
Recreates an agent object from a dictionary.
|
||||
|
||||
|
|
@ -70,6 +67,7 @@ def deserialize_agent(
|
|||
# game_logic.py
|
||||
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
|
||||
|
||||
|
||||
def _next_phase_name(short: str) -> str:
|
||||
"""
|
||||
Return the Diplomacy phase string that chronologically follows *short*.
|
||||
|
|
@ -89,13 +87,9 @@ def _next_phase_name(short: str) -> str:
|
|||
else: # winter→ spring, next year
|
||||
return f"S{year + 1}M"
|
||||
|
||||
|
||||
def save_game_state(
|
||||
game: Game,
|
||||
agents: Dict[str, DiplomacyAgent],
|
||||
game_history: GameHistory,
|
||||
output_path: str,
|
||||
run_config: Namespace,
|
||||
completed_phase_name: str
|
||||
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
|
||||
):
|
||||
"""
|
||||
Serialise the entire game to JSON, preserving per-phase custom metadata
|
||||
|
|
@ -130,9 +124,7 @@ def save_game_state(
|
|||
if extras:
|
||||
previous_phase_extras[phase["name"]] = extras
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Could not load previous save to retain metadata: %s", exc, exc_info=True
|
||||
)
|
||||
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
|
||||
|
||||
# -------------------------------------------------------------- #
|
||||
# 2. Build the fresh base structure from the diplomacy library. #
|
||||
|
|
@ -143,11 +135,7 @@ def save_game_state(
|
|||
# 3. Walk every phase and merge the metadata back in. #
|
||||
# -------------------------------------------------------------- #
|
||||
# Capture the *current* snapshot of every live agent exactly once.
|
||||
current_state_agents = {
|
||||
p_name: serialize_agent(p_agent)
|
||||
for p_name, p_agent in agents.items()
|
||||
if not game.powers[p_name].is_eliminated()
|
||||
}
|
||||
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
|
||||
|
||||
for phase_block in saved_game.get("phases", []):
|
||||
if int(phase_block["name"][1:5]) > run_config.max_year:
|
||||
|
|
@ -167,19 +155,13 @@ def save_game_state(
|
|||
|
||||
# Plans for this phase – may be empty in non-movement phases.
|
||||
phase_obj = game_history._get_phase(phase_name)
|
||||
phase_block["state_history_plans"] = (
|
||||
phase_obj.plans if phase_obj else {}
|
||||
)
|
||||
|
||||
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
|
||||
|
||||
# -------------------------------------------------------------- #
|
||||
# 4. Attach top-level metadata and write atomically. #
|
||||
# -------------------------------------------------------------- #
|
||||
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
|
||||
saved_game["final_agent_states"] = {
|
||||
p_name: {"relationships": a.relationships, "goals": a.goals}
|
||||
for p_name, a in agents.items()
|
||||
}
|
||||
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
|
||||
|
||||
# Filter out phases > max_year
|
||||
# saved_game["phases"] = [
|
||||
|
|
@ -191,7 +173,6 @@ def save_game_state(
|
|||
logger.info("Game state saved successfully.")
|
||||
|
||||
|
||||
|
||||
def load_game_state(
|
||||
run_dir: str,
|
||||
game_file_name: str,
|
||||
|
|
@ -204,7 +185,7 @@ def load_game_state(
|
|||
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
|
||||
|
||||
logger.info(f"Loading game state from: {game_file_path}")
|
||||
with open(game_file_path, 'r') as f:
|
||||
with open(game_file_path, "r") as f:
|
||||
saved_game_data = json.load(f)
|
||||
|
||||
# If resuming, find the specified phase and truncate the data after it
|
||||
|
|
@ -213,37 +194,37 @@ def load_game_state(
|
|||
try:
|
||||
# Find the index of the phase *before* the one we want to resume from.
|
||||
# We will start the simulation *at* the resume_from_phase.
|
||||
resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase)
|
||||
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
|
||||
# Truncate the list to exclude everything after the resume phase
|
||||
# Note: the state saved for a given phase represents the state at the beginning of that phase.
|
||||
saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1]
|
||||
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
|
||||
|
||||
# Wipe any data that must be regenerated.
|
||||
for key in ("orders", "results", "messages"):
|
||||
saved_game_data['phases'][-1].pop(key, None)
|
||||
saved_game_data["phases"][-1].pop(key, None)
|
||||
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
|
||||
except StopIteration:
|
||||
# If the phase is not found, maybe it's the first phase (S1901M)
|
||||
if resume_from_phase == "S1901M":
|
||||
saved_game_data['phases'] = []
|
||||
saved_game_data["phases"] = []
|
||||
logger.info("Resuming from S1901M. Starting with a clean history.")
|
||||
else:
|
||||
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
|
||||
|
||||
# Reconstruct the Game object
|
||||
last_phase = saved_game_data['phases'][-1]
|
||||
last_phase = saved_game_data["phases"][-1]
|
||||
|
||||
# Wipe the data that must be regenerated **but preserve the keys**
|
||||
last_phase['orders'] = {} # was dict
|
||||
last_phase['results'] = {} # was dict
|
||||
last_phase['messages'] = []
|
||||
last_phase["orders"] = {} # was dict
|
||||
last_phase["results"] = {} # was dict
|
||||
last_phase["messages"] = []
|
||||
|
||||
game = from_saved_game_format(saved_game_data)
|
||||
|
||||
game.phase_summaries = saved_game_data.get('phase_summaries', {})
|
||||
game.phase_summaries = saved_game_data.get("phase_summaries", {})
|
||||
|
||||
# Reconstruct agents and game history from the *last* valid phase in the data
|
||||
if not saved_game_data['phases']:
|
||||
if not saved_game_data["phases"]:
|
||||
# This happens if we are resuming from the very beginning (S1901M)
|
||||
logger.info("No previous phases found. Initializing fresh agents and history.")
|
||||
agents = {} # Will be created by the main loop
|
||||
|
|
@ -251,10 +232,10 @@ def load_game_state(
|
|||
else:
|
||||
# We save the game state up to & including the current (uncompleted) phase.
|
||||
# So we need to grab the agent state from the previous (completed) phase.
|
||||
if len(saved_game_data['phases']) <= 1:
|
||||
if len(saved_game_data["phases"]) <= 1:
|
||||
last_phase_data = {}
|
||||
else:
|
||||
last_phase_data = saved_game_data['phases'][-2]
|
||||
last_phase_data = saved_game_data["phases"][-2]
|
||||
|
||||
# -------------------- Rebuild agents -------------------- #
|
||||
agents = {}
|
||||
|
|
@ -269,18 +250,11 @@ def load_game_state(
|
|||
elif len(provided) == 1:
|
||||
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid --models argument: expected 1 or {len(powers_order)} "
|
||||
f"items, got {len(provided)}."
|
||||
)
|
||||
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
|
||||
|
||||
if "state_agents" in last_phase_data:
|
||||
logger.info("Rebuilding agents from saved state...")
|
||||
prompts_dir_from_config = (
|
||||
run_config.prompts_dir
|
||||
if run_config and hasattr(run_config, "prompts_dir")
|
||||
else None
|
||||
)
|
||||
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
|
||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||
override_id = power_model_map.get(power_name)
|
||||
agents[power_name] = deserialize_agent(
|
||||
|
|
@ -290,37 +264,27 @@ def load_game_state(
|
|||
)
|
||||
logger.info(f"Rebuilt {len(agents)} agents.")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot resume: 'state_agents' key not found in the last phase "
|
||||
"of the save file."
|
||||
)
|
||||
|
||||
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
||||
|
||||
# Rebuild GameHistory
|
||||
game_history = GameHistory()
|
||||
logger.info("Rebuilding game history...")
|
||||
for phase_data in saved_game_data['phases'][:-1]:
|
||||
phase_name = phase_data['name']
|
||||
for phase_data in saved_game_data["phases"][:-1]:
|
||||
phase_name = phase_data["name"]
|
||||
game_history.add_phase(phase_name)
|
||||
# Add messages
|
||||
for msg in phase_data.get('messages', []):
|
||||
game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message'])
|
||||
for msg in phase_data.get("messages", []):
|
||||
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
|
||||
# Add plans
|
||||
if 'state_history_plans' in phase_data:
|
||||
for p_name, plan in phase_data['state_history_plans'].items():
|
||||
if "state_history_plans" in phase_data:
|
||||
for p_name, plan in phase_data["state_history_plans"].items():
|
||||
game_history.add_plan(phase_name, p_name, plan)
|
||||
logger.info("Game history rebuilt.")
|
||||
|
||||
|
||||
return game, agents, game_history, run_config
|
||||
|
||||
|
||||
async def initialize_new_game(
|
||||
args: Namespace,
|
||||
game: Game,
|
||||
game_history: GameHistory,
|
||||
llm_log_file_path: str
|
||||
) -> Dict[str, DiplomacyAgent]:
|
||||
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
|
||||
"""Initializes agents for a new game."""
|
||||
powers_order = sorted(list(ALL_POWERS))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# ai_diplomacy/initialization.py
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
from ..config import config
|
||||
|
||||
|
|
@ -18,10 +17,11 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def initialize_agent_state_ext(
|
||||
agent: 'DiplomacyAgent',
|
||||
game: 'Game',
|
||||
game_history: 'GameHistory',
|
||||
agent: "DiplomacyAgent",
|
||||
game: "Game",
|
||||
game_history: "GameHistory",
|
||||
log_file_path: str,
|
||||
prompts_dir: Optional[str] = None,
|
||||
):
|
||||
|
|
@ -37,21 +37,20 @@ async def initialize_agent_state_ext(
|
|||
try:
|
||||
# Load the prompt template
|
||||
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
||||
initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir)
|
||||
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
|
||||
if not initial_prompt_template:
|
||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
|
||||
return
|
||||
|
||||
# Format the prompt with variables
|
||||
initial_prompt = initial_prompt_template.format(
|
||||
power_name=power_name,
|
||||
allowed_labels_str=allowed_labels_str
|
||||
)
|
||||
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
|
||||
|
||||
board_state = game.get_state() if game else {}
|
||||
possible_orders = game.get_all_possible_orders() if game else {}
|
||||
|
||||
logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}")
|
||||
logger.debug(
|
||||
f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}"
|
||||
)
|
||||
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
|
||||
# For initialization, game should always be present.
|
||||
|
||||
|
|
@ -76,7 +75,7 @@ async def initialize_agent_state_ext(
|
|||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=current_phase,
|
||||
response_type='initialization', # Context for run_llm_and_log internal error logging
|
||||
response_type="initialization", # Context for run_llm_and_log internal error logging
|
||||
)
|
||||
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
|
||||
|
||||
|
|
@ -86,11 +85,7 @@ async def initialize_agent_state_ext(
|
|||
if config.USE_UNFORMATTED_PROMPTS:
|
||||
# Format the natural language response into JSON
|
||||
formatted_response = await format_with_gemini_flash(
|
||||
response,
|
||||
FORMAT_INITIAL_STATE,
|
||||
power_name=power_name,
|
||||
phase=current_phase,
|
||||
log_file_path=log_file_path
|
||||
response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
|
||||
)
|
||||
else:
|
||||
# Use the raw response directly (already formatted)
|
||||
|
|
@ -108,12 +103,16 @@ async def initialize_agent_state_ext(
|
|||
# Defensive check for update_data type if parsing was initially considered successful
|
||||
if parsed_successfully:
|
||||
if isinstance(update_data, str):
|
||||
logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...")
|
||||
logger.error(
|
||||
f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..."
|
||||
)
|
||||
update_data = {} # Treat as parsing failure
|
||||
parsed_successfully = False
|
||||
success_status = "Failure: ParsedAsStr"
|
||||
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
|
||||
logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}")
|
||||
logger.error(
|
||||
f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}"
|
||||
)
|
||||
update_data = {} # Treat as parsing failure
|
||||
parsed_successfully = False
|
||||
success_status = "Failure: NotADict"
|
||||
|
|
@ -122,8 +121,8 @@ async def initialize_agent_state_ext(
|
|||
initial_relationships_applied = False
|
||||
|
||||
if parsed_successfully:
|
||||
initial_goals = update_data.get('initial_goals') or update_data.get('goals')
|
||||
initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships')
|
||||
initial_goals = update_data.get("initial_goals") or update_data.get("goals")
|
||||
initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships")
|
||||
|
||||
if isinstance(initial_goals, list) and initial_goals:
|
||||
agent.goals = initial_goals
|
||||
|
|
@ -200,7 +199,7 @@ async def initialize_agent_state_ext(
|
|||
response_type="initial_state_setup", # Specific type for CSV logging
|
||||
raw_input_prompt=full_prompt,
|
||||
raw_response=response,
|
||||
success=success_status
|
||||
success=success_status,
|
||||
)
|
||||
|
||||
# Final log of state after initialization attempt
|
||||
|
|
|
|||
|
|
@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None
|
|||
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
|
||||
|
||||
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING
|
|||
from diplomacy.engine.message import Message, GLOBAL
|
||||
|
||||
from .agent import DiplomacyAgent
|
||||
from .clients import load_model_client
|
||||
from .utils import gather_possible_orders, load_prompt, normalize_recipient_name
|
||||
from .utils import gather_possible_orders, normalize_recipient_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .game_history import GameHistory
|
||||
|
|
@ -21,9 +20,9 @@ load_dotenv()
|
|||
|
||||
|
||||
async def conduct_negotiations(
|
||||
game: 'Game',
|
||||
game: "Game",
|
||||
agents: Dict[str, DiplomacyAgent],
|
||||
game_history: 'GameHistory',
|
||||
game_history: "GameHistory",
|
||||
model_error_stats: Dict[str, Dict[str, int]],
|
||||
log_file_path: str,
|
||||
max_rounds: int = 3,
|
||||
|
|
@ -35,12 +34,8 @@ async def conduct_negotiations(
|
|||
"""
|
||||
logger.info("Starting negotiation phase.")
|
||||
|
||||
active_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
||||
]
|
||||
eliminated_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
|
||||
]
|
||||
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
||||
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
||||
|
||||
logger.info(f"Active powers for negotiations: {active_powers}")
|
||||
if eliminated_powers:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
from dotenv import load_dotenv
|
||||
import logging
|
||||
import concurrent.futures
|
||||
from typing import Dict
|
||||
|
||||
from .clients import load_model_client
|
||||
from .game_history import GameHistory
|
||||
from .agent import DiplomacyAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def planning_phase(
|
||||
game,
|
||||
agents: Dict[str, DiplomacyAgent],
|
||||
|
|
@ -20,12 +19,8 @@ async def planning_phase(
|
|||
Lets each power generate a strategic plan using their DiplomacyAgent.
|
||||
"""
|
||||
logger.info(f"Starting planning phase for {game.current_short_phase}...")
|
||||
active_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
||||
]
|
||||
eliminated_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
|
||||
]
|
||||
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
||||
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
||||
|
||||
logger.info(f"Active powers for planning: {active_powers}")
|
||||
if eliminated_powers:
|
||||
|
|
@ -35,9 +30,7 @@ async def planning_phase(
|
|||
|
||||
board_state = game.get_state()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(active_powers)
|
||||
) as executor:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor:
|
||||
futures = {}
|
||||
for power_name in active_powers:
|
||||
if power_name not in agents:
|
||||
|
|
@ -71,16 +64,14 @@ async def planning_phase(
|
|||
if plan_result.startswith("Error:"):
|
||||
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
|
||||
if power_name in model_error_stats:
|
||||
model_error_stats[power_name].setdefault('planning_generation_errors', 0)
|
||||
model_error_stats[power_name]['planning_generation_errors'] += 1
|
||||
model_error_stats[power_name].setdefault("planning_generation_errors", 0)
|
||||
model_error_stats[power_name]["planning_generation_errors"] += 1
|
||||
else:
|
||||
model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0)
|
||||
model_error_stats[f'{power_name}_planning_generation_errors'] += 1
|
||||
model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0)
|
||||
model_error_stats[f"{power_name}_planning_generation_errors"] += 1
|
||||
elif plan_result:
|
||||
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
|
||||
game_history.add_plan(
|
||||
game.current_short_phase, power_name, plan_result
|
||||
)
|
||||
game_history.add_plan(game.current_short_phase, power_name, plan_result)
|
||||
logger.debug(f"Added plan for {power_name} to history.")
|
||||
else:
|
||||
logger.warning(f"Agent {power_name} returned an empty plan.")
|
||||
|
|
@ -88,11 +79,11 @@ async def planning_phase(
|
|||
except Exception as e:
|
||||
logger.error(f"Exception during planning result processing for {power_name}: {e}")
|
||||
if power_name in model_error_stats:
|
||||
model_error_stats[power_name].setdefault('planning_execution_errors', 0)
|
||||
model_error_stats[power_name]['planning_execution_errors'] += 1
|
||||
model_error_stats[power_name].setdefault("planning_execution_errors", 0)
|
||||
model_error_stats[power_name]["planning_execution_errors"] += 1
|
||||
else:
|
||||
model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0)
|
||||
model_error_stats[f'{power_name}_planning_execution_errors'] += 1
|
||||
model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0)
|
||||
model_error_stats[f"{power_name}_planning_execution_errors"] += 1
|
||||
|
||||
logger.info("Planning phase processing complete.")
|
||||
return game_history
|
||||
|
|
@ -14,6 +14,7 @@ import re
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||
"""
|
||||
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
|
||||
|
|
@ -27,9 +28,7 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
|||
if len(loc.split("/")[0]) == 3
|
||||
}
|
||||
|
||||
graph: Dict[str, Dict[str, List[str]]] = {
|
||||
p: {"ARMY": [], "FLEET": []} for p in provs
|
||||
}
|
||||
graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs}
|
||||
|
||||
# ── helper: list every concrete variant of a province ─────────────────
|
||||
def variants(code: str) -> List[str]:
|
||||
|
|
@ -55,11 +54,7 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
|||
graph[src]["ARMY"].append(dest)
|
||||
|
||||
# FLEET — any src variant that can host a fleet is valid
|
||||
if any(
|
||||
game_map.abuts("F", sv, "-", dv)
|
||||
for sv in src_vers
|
||||
for dv in dest_vers
|
||||
):
|
||||
if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers):
|
||||
graph[src]["FLEET"].append(dest)
|
||||
|
||||
# ── tidy up duplicates / order ---------------------------------------
|
||||
|
|
@ -70,24 +65,23 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
|||
return graph
|
||||
|
||||
|
||||
|
||||
def bfs_shortest_path(
|
||||
graph: Dict[str, Dict[str, List[str]]],
|
||||
board_state: BoardState,
|
||||
game_map: GameMap, # Added game_map
|
||||
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
|
||||
unit_type: str,
|
||||
is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc
|
||||
is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc
|
||||
) -> Optional[List[str]]: # Returns path of SHORT names
|
||||
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
|
||||
|
||||
# Convert full start location to short province name
|
||||
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
|
||||
if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
|
||||
if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
|
||||
start_loc_short = start_loc_short[:3]
|
||||
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
|
||||
# A simpler way for non-coastal full (like 'VIE') or already short:
|
||||
if '/' not in start_loc_full:
|
||||
if "/" not in start_loc_full:
|
||||
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
|
||||
else: # Has '/', e.g. 'STP/SC'
|
||||
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
|
||||
|
|
@ -119,10 +113,11 @@ def bfs_shortest_path(
|
|||
queue.append((next_loc_short, new_path))
|
||||
return None
|
||||
|
||||
|
||||
# --- Helper functions for context generation ---
|
||||
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
|
||||
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
|
||||
for power, unit_list in board_state.get('units', {}).items():
|
||||
for power, unit_list in board_state.get("units", {}).items():
|
||||
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
|
||||
parts = unit_str.split(" ")
|
||||
if len(parts) == 2:
|
||||
|
|
@ -131,24 +126,26 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str
|
|||
return f"{parts[0]} {location} ({power})"
|
||||
return None
|
||||
|
||||
|
||||
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
|
||||
"""Returns the controlling power's name if the location is an SC, else None."""
|
||||
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts
|
||||
loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
|
||||
if loc_province_name not in game_map.scs:
|
||||
return None
|
||||
for power, sc_list in board_state.get('centers', {}).items():
|
||||
for power, sc_list in board_state.get("centers", {}).items():
|
||||
if loc_province_name in sc_list:
|
||||
return power
|
||||
return None # Unowned SC
|
||||
|
||||
|
||||
def get_shortest_path_to_friendly_unit(
|
||||
board_state: BoardState,
|
||||
graph: Dict[str, Dict[str, List[str]]],
|
||||
game_map: GameMap, # Added game_map
|
||||
power_name: str,
|
||||
start_unit_loc_full: str,
|
||||
start_unit_type: str
|
||||
start_unit_type: str,
|
||||
) -> Optional[Tuple[str, List[str]]]:
|
||||
"""Finds the shortest path to any friendly unit of the same power."""
|
||||
|
||||
|
|
@ -183,7 +180,7 @@ def get_nearest_enemy_units(
|
|||
power_name: str,
|
||||
start_unit_loc_full: str,
|
||||
start_unit_type: str,
|
||||
n: int = 3
|
||||
n: int = 3,
|
||||
) -> List[Tuple[str, List[str]]]:
|
||||
"""Finds up to N nearest enemy units, sorted by path length."""
|
||||
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
|
||||
|
|
@ -205,9 +202,9 @@ def get_nearest_enemy_units(
|
|||
|
||||
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
|
||||
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
|
||||
if '/' in target_enemy_loc_short:
|
||||
if "/" in target_enemy_loc_short:
|
||||
target_enemy_loc_short = target_enemy_loc_short[:3]
|
||||
if '/' not in target_enemy_loc_full:
|
||||
if "/" not in target_enemy_loc_full:
|
||||
target_enemy_loc_short = target_enemy_loc_full[:3]
|
||||
else:
|
||||
target_enemy_loc_short = target_enemy_loc_full[:3]
|
||||
|
|
@ -281,7 +278,7 @@ def get_adjacent_territory_details(
|
|||
board_state: BoardState,
|
||||
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
|
||||
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
|
||||
graph: Dict[str, Dict[str, List[str]]]
|
||||
graph: Dict[str, Dict[str, List[str]]],
|
||||
) -> str:
|
||||
"""Generates a string describing adjacent territories and units that can interact with them."""
|
||||
output_lines: List[str] = []
|
||||
|
|
@ -291,9 +288,9 @@ def get_adjacent_territory_details(
|
|||
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
|
||||
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
|
||||
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
|
||||
if '/' in unit_loc_short:
|
||||
if "/" in unit_loc_short:
|
||||
unit_loc_short = unit_loc_short[:3]
|
||||
if '/' not in unit_loc_full:
|
||||
if "/" not in unit_loc_full:
|
||||
unit_loc_short = unit_loc_full[:3]
|
||||
else:
|
||||
unit_loc_short = unit_loc_full[:3]
|
||||
|
|
@ -308,11 +305,11 @@ def get_adjacent_territory_details(
|
|||
continue
|
||||
processed_adj_provinces.add(adj_loc_short)
|
||||
|
||||
adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper()
|
||||
if adj_loc_type == 'COAST' or adj_loc_type == 'LAND':
|
||||
adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST'
|
||||
elif adj_loc_type == 'WATER':
|
||||
adj_loc_type_display = 'WATER'
|
||||
adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper()
|
||||
if adj_loc_type == "COAST" or adj_loc_type == "LAND":
|
||||
adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST"
|
||||
elif adj_loc_type == "WATER":
|
||||
adj_loc_type_display = "WATER"
|
||||
else: # SHUT etc.
|
||||
adj_loc_type_display = adj_loc_type
|
||||
|
||||
|
|
@ -330,8 +327,7 @@ def get_adjacent_territory_details(
|
|||
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
|
||||
# A true "can support/move to" would require checking possible orders of those further units.
|
||||
# further_adj_provinces are short names from the graph
|
||||
further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \
|
||||
graph.get(adj_loc_short, {}).get('FLEET', [])
|
||||
further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", [])
|
||||
|
||||
supporting_units_info = []
|
||||
processed_further_provinces = set()
|
||||
|
|
@ -388,12 +384,12 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
|||
continue
|
||||
|
||||
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
|
||||
unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET"
|
||||
unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET"
|
||||
|
||||
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
|
||||
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
|
||||
if loc_type_short == 'COAST' or loc_type_short == 'LAND':
|
||||
loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST'
|
||||
if loc_type_short == "COAST" or loc_type_short == "LAND":
|
||||
loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST"
|
||||
else:
|
||||
loc_type_display = loc_type_short
|
||||
|
||||
|
|
@ -401,7 +397,7 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
|||
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
|
||||
|
||||
# Unit Information section
|
||||
current_unit_lines.append(' <UnitInformation>')
|
||||
current_unit_lines.append(" <UnitInformation>")
|
||||
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
|
||||
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
|
||||
if sc_owner_at_loc == power_name:
|
||||
|
|
@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
|||
header_content += f" (SC controlled by {sc_owner_at_loc})"
|
||||
current_unit_lines.append(f" {header_content}")
|
||||
current_unit_lines.append(f" Units present: {unit_str_full}")
|
||||
current_unit_lines.append(' </UnitInformation>')
|
||||
current_unit_lines.append(" </UnitInformation>")
|
||||
|
||||
# Possible moves section
|
||||
current_unit_lines.append(' <PossibleMoves>')
|
||||
current_unit_lines.append(" <PossibleMoves>")
|
||||
current_unit_lines.append(" Possible moves:")
|
||||
for order_str in unit_specific_possible_orders:
|
||||
current_unit_lines.append(f" {order_str}")
|
||||
current_unit_lines.append(' </PossibleMoves>')
|
||||
current_unit_lines.append(" </PossibleMoves>")
|
||||
|
||||
# Nearest enemy units section
|
||||
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
|
||||
current_unit_lines.append(' <NearestEnemyUnits>')
|
||||
current_unit_lines.append(" <NearestEnemyUnits>")
|
||||
if enemy_units_info:
|
||||
current_unit_lines.append(" Nearest units (not ours):")
|
||||
for enemy_unit_str, enemy_path_short in enemy_units_info:
|
||||
current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]")
|
||||
current_unit_lines.append(
|
||||
f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]"
|
||||
)
|
||||
else:
|
||||
current_unit_lines.append(" Nearest units (not ours): None found")
|
||||
current_unit_lines.append(' </NearestEnemyUnits>')
|
||||
current_unit_lines.append(" </NearestEnemyUnits>")
|
||||
|
||||
# Nearest supply centers (not controlled by us) section
|
||||
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
|
||||
current_unit_lines.append(' <NearestUncontrolledSupplyCenters>')
|
||||
current_unit_lines.append(" <NearestUncontrolledSupplyCenters>")
|
||||
if uncontrolled_scs_info:
|
||||
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
|
||||
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
|
||||
current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]")
|
||||
current_unit_lines.append(
|
||||
f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]"
|
||||
)
|
||||
else:
|
||||
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
|
||||
current_unit_lines.append(' </NearestUncontrolledSupplyCenters>')
|
||||
current_unit_lines.append(" </NearestUncontrolledSupplyCenters>")
|
||||
|
||||
# Adjacent territories details section
|
||||
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
|
||||
current_unit_lines.append(' <AdjacentTerritories>')
|
||||
current_unit_lines.append(" <AdjacentTerritories>")
|
||||
if adj_details_str:
|
||||
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
|
||||
# Assuming adj_details_str is already formatted with newlines and indentation for its content
|
||||
# We might need to indent adj_details_str if it's a single block of text
|
||||
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
|
||||
if '\n' in adj_details_str:
|
||||
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')])
|
||||
if "\n" in adj_details_str:
|
||||
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")])
|
||||
current_unit_lines.append(indented_adj_details)
|
||||
else:
|
||||
current_unit_lines.append(f" {adj_details_str}")
|
||||
else:
|
||||
current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else
|
||||
current_unit_lines.append(' </AdjacentTerritories>')
|
||||
current_unit_lines.append(
|
||||
" Adjacent territories: None relevant or all are empty/uncontested by direct threats."
|
||||
) # Added more descriptive else
|
||||
current_unit_lines.append(" </AdjacentTerritories>")
|
||||
|
||||
current_unit_lines.append(' </UnitContext>')
|
||||
current_unit_lines.append(" </UnitContext>")
|
||||
final_context_lines.extend(current_unit_lines)
|
||||
|
||||
final_context_lines.append("</PossibleOrdersContext>")
|
||||
|
|
@ -469,7 +471,6 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
|||
# Regex and tiny helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import re
|
||||
from typing import Tuple, List, Dict, Optional, Any
|
||||
|
||||
# ── order-syntax matchers ─────────────────────────────────────────────────
|
||||
|
|
@ -478,6 +479,7 @@ _HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
|
|||
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
|
||||
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
|
||||
|
||||
|
||||
def _is_hold_order(order: str) -> bool: # NEW
|
||||
return bool(_HOLD_RE.match(order.strip()))
|
||||
|
||||
|
|
@ -486,15 +488,19 @@ def _norm_power(name: str) -> str:
|
|||
"""Trim & uppercase for reliable comparisons."""
|
||||
return name.strip().upper()
|
||||
|
||||
|
||||
def _is_simple_move(order: str) -> bool:
|
||||
return bool(_SIMPLE_MOVE_RE.match(order.strip()))
|
||||
|
||||
|
||||
def _is_retreat_order(order: str) -> bool:
|
||||
return bool(_RETREAT_RE.match(order.strip()))
|
||||
|
||||
|
||||
def _is_adjust_order(order: str) -> bool:
|
||||
return bool(_ADJUST_RE.match(order.strip()))
|
||||
|
||||
|
||||
def _split_move(order: str) -> Tuple[str, str]:
|
||||
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
|
||||
unit_part, dest = order.split(" - ")
|
||||
|
|
@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]:
|
|||
# Gather *all* friendly support orders for a given move
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _all_support_examples(
|
||||
mover: str,
|
||||
dest: str,
|
||||
|
|
@ -526,6 +533,7 @@ def _all_support_examples(
|
|||
|
||||
return supports
|
||||
|
||||
|
||||
def _all_support_hold_examples(
|
||||
holder: str,
|
||||
all_orders: Dict[str, List[str]],
|
||||
|
|
@ -550,6 +558,7 @@ def _all_support_hold_examples(
|
|||
# Province-type resolver (handles short codes, coasts, seas)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _province_type_display(game_map, prov_short: str) -> str:
|
||||
"""
|
||||
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
|
||||
|
|
@ -593,6 +602,7 @@ def _dest_occupancy_desc(
|
|||
# Adjacent-territory lines (used by movement-phase builder)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _adjacent_territory_lines(
|
||||
graph,
|
||||
game_map,
|
||||
|
|
@ -633,13 +643,9 @@ def _adjacent_territory_lines(
|
|||
pwr = unit_here.split(" ")[-1].strip("()")
|
||||
if pwr == our_power:
|
||||
friend_descr = unit_here.split(" (")[0]
|
||||
lines.append(
|
||||
f"{indent2}Support hold: {mover_descr} S {friend_descr}"
|
||||
)
|
||||
lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}")
|
||||
else:
|
||||
lines.append(
|
||||
f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa"
|
||||
)
|
||||
lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa")
|
||||
|
||||
return lines
|
||||
|
||||
|
|
@ -648,6 +654,7 @@ def _adjacent_territory_lines(
|
|||
# Movement-phase generator (UNCHANGED LOGIC)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_rich_order_context_movement(
|
||||
game: Any,
|
||||
power_name: str,
|
||||
|
|
@ -672,18 +679,14 @@ def _generate_rich_order_context_movement(
|
|||
if _norm_power(unit_power) != me:
|
||||
continue # Skip units that aren’t ours
|
||||
|
||||
mover_descr, _ = _split_move(
|
||||
f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}"
|
||||
)
|
||||
mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}")
|
||||
|
||||
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
|
||||
prov_type_disp = _province_type_display(game_map, prov_short)
|
||||
sc_tag = " (SC)" if prov_short in game_map.scs else ""
|
||||
|
||||
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
|
||||
owner_line = (
|
||||
f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
|
||||
)
|
||||
owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
|
||||
|
||||
ind = " "
|
||||
block: List[str] = [f"<Territory {prov_short}>"]
|
||||
|
|
@ -693,18 +696,16 @@ def _generate_rich_order_context_movement(
|
|||
|
||||
# ----- Adjacent territories -----
|
||||
block.append("# Adjacent territories:")
|
||||
block.extend(
|
||||
_adjacent_territory_lines(
|
||||
graph, game_map, board_state,
|
||||
unit_loc_full, mover_descr, power_name
|
||||
)
|
||||
)
|
||||
block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name))
|
||||
|
||||
# ----- Nearest enemy units -----
|
||||
block.append("# Nearest units (not ours):")
|
||||
enemies = get_nearest_enemy_units(
|
||||
board_state, graph, game_map,
|
||||
power_name, unit_loc_full,
|
||||
board_state,
|
||||
graph,
|
||||
game_map,
|
||||
power_name,
|
||||
unit_loc_full,
|
||||
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
||||
n=3,
|
||||
)
|
||||
|
|
@ -715,8 +716,11 @@ def _generate_rich_order_context_movement(
|
|||
# ----- Nearest uncontrolled SCs -----
|
||||
block.append("# Nearest supply centers (not controlled by us):")
|
||||
scs = get_nearest_uncontrolled_scs(
|
||||
game_map, board_state, graph,
|
||||
power_name, unit_loc_full,
|
||||
game_map,
|
||||
board_state,
|
||||
graph,
|
||||
power_name,
|
||||
unit_loc_full,
|
||||
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
||||
n=3,
|
||||
)
|
||||
|
|
@ -737,9 +741,7 @@ def _generate_rich_order_context_movement(
|
|||
# ---- Moves (same behaviour as before) ----
|
||||
for mv in simple_moves:
|
||||
mover, dest = _split_move(mv)
|
||||
occ = _dest_occupancy_desc(
|
||||
dest.split("/")[0][:3], game_map, board_state, power_name
|
||||
)
|
||||
occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name)
|
||||
block.append(f"{ind}{mv} {occ}")
|
||||
|
||||
for s in _all_support_examples(mover, dest, possible_orders_for_power):
|
||||
|
|
@ -753,7 +755,6 @@ def _generate_rich_order_context_movement(
|
|||
for s in _all_support_hold_examples(holder, possible_orders_for_power):
|
||||
block.append(f"{ind * 2}Available Support: {s}")
|
||||
|
||||
|
||||
block.append(f"</Territory {prov_short}>")
|
||||
blocks.append("\n".join(block))
|
||||
|
||||
|
|
@ -764,6 +765,7 @@ def _generate_rich_order_context_movement(
|
|||
# Retreat-phase builder – echo orders verbatim, no tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_rich_order_context_retreat(
|
||||
game: Any,
|
||||
power_name: str,
|
||||
|
|
@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat(
|
|||
return "\n".join(lines) if lines else "(No dislodged units)"
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adjustment-phase builder – summary line + orders, no WAIVEs, no tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_rich_order_context_adjustment(
|
||||
game: Any,
|
||||
power_name: str,
|
||||
|
|
@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment(
|
|||
return "\n".join(lines) if len(lines) > 1 else summary
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase-dispatch wrapper (public entry point)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def generate_rich_order_context(
|
||||
game: Any,
|
||||
power_name: str,
|
||||
|
|
@ -847,21 +847,13 @@ def generate_rich_order_context(
|
|||
phase_type = game.current_short_phase[-1]
|
||||
|
||||
if phase_type == "M": # Movement
|
||||
return _generate_rich_order_context_movement(
|
||||
game, power_name, possible_orders_for_power
|
||||
)
|
||||
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
|
||||
|
||||
if phase_type == "R": # Retreat
|
||||
return _generate_rich_order_context_retreat(
|
||||
game, power_name, possible_orders_for_power
|
||||
)
|
||||
return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power)
|
||||
|
||||
if phase_type == "A": # Adjustment (build / disband)
|
||||
return _generate_rich_order_context_adjustment(
|
||||
game, power_name, possible_orders_for_power
|
||||
)
|
||||
return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power)
|
||||
|
||||
# Fallback – treat unknown formats as movement
|
||||
return _generate_rich_order_context_movement(
|
||||
game, power_name, possible_orders_for_power
|
||||
)
|
||||
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from .possible_order_context import (
|
|||
generate_rich_order_context,
|
||||
generate_rich_order_context_xml,
|
||||
)
|
||||
import os
|
||||
from .game_history import GameHistory # Assuming GameHistory is correctly importable
|
||||
|
||||
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
|
||||
|
|
@ -230,4 +229,3 @@ def construct_order_generation_prompt(
|
|||
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
|
||||
|
||||
return final_prompt
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ def atomic_write_json(data: dict, filepath: str):
|
|||
|
||||
# Write to a temporary file in the same directory
|
||||
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
|
||||
with open(temp_filepath, 'w', encoding='utf-8') as f:
|
||||
with open(temp_filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
# Atomically rename the temporary file to the final destination
|
||||
|
|
@ -110,11 +110,7 @@ def get_special_models() -> Dict[str, str]:
|
|||
# Use a different Gemini model for formatting
|
||||
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
|
||||
"""
|
||||
return {
|
||||
"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
|
||||
"formatter": config.AI_DIPLOMACY_FORMATTER_MODEL
|
||||
}
|
||||
|
||||
return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
|
||||
|
||||
|
||||
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
||||
|
|
@ -175,15 +171,19 @@ async def get_valid_orders(
|
|||
|
||||
# ── 2. Type check ──────────────────────────────────────────
|
||||
if not isinstance(raw_orders, list):
|
||||
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
|
||||
power_name, raw_orders)
|
||||
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
|
||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
||||
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
|
||||
|
||||
# ── 3. Round-trip validation with engine ───────────────────
|
||||
CODE_TO_ENGINE = {
|
||||
"AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
|
||||
"GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
|
||||
"AUT": "AUSTRIA",
|
||||
"ENG": "ENGLAND",
|
||||
"FRA": "FRANCE",
|
||||
"GER": "GERMANY",
|
||||
"ITA": "ITALY",
|
||||
"RUS": "RUSSIA",
|
||||
"TUR": "TURKEY",
|
||||
}
|
||||
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
|
||||
|
||||
|
|
@ -226,8 +226,6 @@ async def get_valid_orders(
|
|||
return {"valid": valid, "invalid": invalid}
|
||||
|
||||
|
||||
|
||||
|
||||
def normalize_and_compare_orders(
|
||||
issued_orders: Dict[str, List[str]],
|
||||
accepted_orders_dict: Dict[str, List[str]],
|
||||
|
|
@ -348,8 +346,6 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
|
||||
|
||||
# == New LLM Response Logging Function ==
|
||||
def log_llm_response(
|
||||
log_file_path: str,
|
||||
|
|
@ -374,14 +370,18 @@ def log_llm_response(
|
|||
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
|
||||
# Added "raw_input" to fieldnames
|
||||
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames,
|
||||
writer = csv.DictWriter(
|
||||
csvfile,
|
||||
fieldnames=fieldnames,
|
||||
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
|
||||
escapechar='\\') # Use backslash for escaping
|
||||
escapechar="\\",
|
||||
) # Use backslash for escaping
|
||||
|
||||
if not file_exists:
|
||||
writer.writeheader() # Write header only if file is new
|
||||
|
||||
writer.writerow({
|
||||
writer.writerow(
|
||||
{
|
||||
"model": model_name,
|
||||
"power": power_name if power_name else "game", # Use 'game' if no specific power
|
||||
"phase": phase,
|
||||
|
|
@ -389,12 +389,12 @@ def log_llm_response(
|
|||
"raw_input": raw_input_prompt, # Added raw_input to the row
|
||||
"raw_response": raw_response,
|
||||
"success": success,
|
||||
})
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
|
||||
|
||||
|
||||
|
||||
async def run_llm_and_log(
|
||||
client: "BaseModelClient",
|
||||
prompt: str,
|
||||
|
|
@ -428,8 +428,7 @@ async def run_llm_and_log(
|
|||
except Exception as e:
|
||||
if attempt == attempts - 1:
|
||||
logger.error(
|
||||
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} "
|
||||
f"in phase {phase}: {e}",
|
||||
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Back-off before the next attempt (unless this was the last)
|
||||
|
|
@ -439,7 +438,6 @@ async def run_llm_and_log(
|
|||
return raw_response
|
||||
|
||||
|
||||
|
||||
# This generates a few lines of random alphanum chars to inject into the
|
||||
# system prompt. This lets us use temp=0 while still getting variation
|
||||
# between trials.
|
||||
|
|
@ -447,15 +445,8 @@ async def run_llm_and_log(
|
|||
# ensure valid json outputs.
|
||||
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
|
||||
# Generate x lines of y random alphanumeric characters
|
||||
seed_lines = [
|
||||
''.join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line))
|
||||
for _ in range(n_lines)
|
||||
]
|
||||
random_seed_block = (
|
||||
"<RANDOM SEED PLEASE IGNORE>\n" +
|
||||
"\n".join(seed_lines) +
|
||||
"\n</RANDOM SEED>"
|
||||
)
|
||||
seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)]
|
||||
random_seed_block = "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(seed_lines) + "\n</RANDOM SEED>"
|
||||
return random_seed_block
|
||||
|
||||
|
||||
|
|
@ -473,6 +464,7 @@ def get_prompt_path(prompt_name: str) -> str:
|
|||
else:
|
||||
return prompt_name
|
||||
|
||||
|
||||
def normalize_recipient_name(recipient: str) -> str:
|
||||
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
||||
if not recipient:
|
||||
|
|
@ -482,24 +474,24 @@ def normalize_recipient_name(recipient: str) -> str:
|
|||
|
||||
# Handle common LLM typos and abbreviations found in data
|
||||
name_mapping = {
|
||||
'EGMANY': 'GERMANY',
|
||||
'GERMAN': 'GERMANY',
|
||||
'UK': 'ENGLAND',
|
||||
'BRIT': 'ENGLAND',
|
||||
'ENGLAND': 'ENGLAND', # Keep as-is
|
||||
'FRANCE': 'FRANCE', # Keep as-is
|
||||
'GERMANY': 'GERMANY', # Keep as-is
|
||||
'ITALY': 'ITALY', # Keep as-is
|
||||
'AUSTRIA': 'AUSTRIA', # Keep as-is
|
||||
'RUSSIA': 'RUSSIA', # Keep as-is
|
||||
'TURKEY': 'TURKEY', # Keep as-is
|
||||
'Germany': 'GERMANY',
|
||||
'England': 'ENGLAND',
|
||||
'France': 'FRANCE',
|
||||
'Italy': 'ITALY',
|
||||
'Russia': 'RUSSIA',
|
||||
'Austria': 'AUSTRIA',
|
||||
'Turkey': 'TURKEY',
|
||||
"EGMANY": "GERMANY",
|
||||
"GERMAN": "GERMANY",
|
||||
"UK": "ENGLAND",
|
||||
"BRIT": "ENGLAND",
|
||||
"ENGLAND": "ENGLAND", # Keep as-is
|
||||
"FRANCE": "FRANCE", # Keep as-is
|
||||
"GERMANY": "GERMANY", # Keep as-is
|
||||
"ITALY": "ITALY", # Keep as-is
|
||||
"AUSTRIA": "AUSTRIA", # Keep as-is
|
||||
"RUSSIA": "RUSSIA", # Keep as-is
|
||||
"TURKEY": "TURKEY", # Keep as-is
|
||||
"Germany": "GERMANY",
|
||||
"England": "ENGLAND",
|
||||
"France": "FRANCE",
|
||||
"Italy": "ITALY",
|
||||
"Russia": "RUSSIA",
|
||||
"Austria": "AUSTRIA",
|
||||
"Turkey": "TURKEY",
|
||||
}
|
||||
|
||||
normalized = name_mapping.get(recipient, recipient)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue