mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-28 17:29:41 +00:00
dramatically improving logging thanks to new 3.7sonnet cursor agent mode
This commit is contained in:
parent
b54a8252d6
commit
eb3de01956
10 changed files with 1025 additions and 219 deletions
21
CLAUDE.md
Normal file
21
CLAUDE.md
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# AI Diplomacy Development Guide
|
||||
|
||||
## Commands
|
||||
- Run game: `python lm_game.py --max_year 1910 --summary_model "gpt-4o-mini" --num_negotiation_rounds 3`
|
||||
- Run tests: `pytest -v diplomacy/tests/` or `pytest -v -k test_name`
|
||||
- Run specific test: `pytest -v diplomacy/tests/path_to_test.py::test_function`
|
||||
- Lint: `pylint diplomacy/path/to/file.py`
|
||||
- Full test suite: `./diplomacy/run_tests.sh`
|
||||
|
||||
## Code Style
|
||||
- Use Python type hints for function parameters and return values
|
||||
- Follow PEP 8 naming: snake_case for functions/variables, UPPER_CASE for constants
|
||||
- Organize imports: standard library, third-party, local modules
|
||||
- Error handling: Use specific exceptions with informative messages
|
||||
- Docstrings: Use multi-line docstrings with parameter descriptions
|
||||
- Keep functions focused on a single responsibility
|
||||
- Models/LLM clients inherit from BaseModelClient and implement required methods
|
||||
- When possible, use concurrent operations (see concurrent.futures in lm_game.py)
|
||||
|
||||
## Environment
|
||||
Python 3.6+ required (the ai_diplomacy code uses f-strings). Use a virtual environment and install dependencies from requirements.txt.
|
||||
|
|
@ -22,11 +22,89 @@ from .game_history import GameHistory
|
|||
from .long_story_short import get_optimized_context
|
||||
from .model_loader import load_model_client
|
||||
|
||||
# set logger back to just info
|
||||
# Configure logger with a more useful format
|
||||
logger = logging.getLogger("client")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Function to configure logging options
|
||||
def configure_logging(
    log_full_prompts=True,
    log_full_responses=True,
    suppress_connection_logs=True,
    log_level=logging.INFO
):
    """
    Set logging levels and the prompt/response logging switches for AI Diplomacy.

    Parameters:
    - log_full_prompts: Stored in the module-level SHOULD_LOG_FULL_PROMPTS flag
    - log_full_responses: Stored in the module-level SHOULD_LOG_FULL_RESPONSES flag
    - suppress_connection_logs: When true, raise the httpx, httpcore, urllib3,
      anthropic and openai loggers to WARNING
    - log_level: Level applied to both the root logger and this module's logger
    """
    global SHOULD_LOG_FULL_PROMPTS, SHOULD_LOG_FULL_RESPONSES

    # Root logger first, then this module's client logger.
    for target in (logging.getLogger(), logger):
        target.setLevel(log_level)

    # Quieten the HTTP / SDK loggers only when asked to.
    if suppress_connection_logs:
        for noisy_name in ("httpx", "httpcore", "urllib3", "anthropic", "openai"):
            logging.getLogger(noisy_name).setLevel(logging.WARNING)

    # Publish the switches read by the prompt/response logging helpers.
    SHOULD_LOG_FULL_PROMPTS = log_full_prompts
    SHOULD_LOG_FULL_RESPONSES = log_full_responses

    level_name = logging.getLevelName(log_level)
    logger.info(f"Logging configured: full_prompts={log_full_prompts}, full_responses={log_full_responses}, level={level_name}")
|
||||
|
||||
# Initialize defaults
|
||||
SHOULD_LOG_FULL_PROMPTS = True
|
||||
SHOULD_LOG_FULL_RESPONSES = True
|
||||
|
||||
# Helper function for truncating long outputs in logs
|
||||
def _truncate_text(text, max_length=500):
|
||||
"""Truncate text for logging purposes with indicator of original length"""
|
||||
if not text or len(text) <= max_length:
|
||||
return text
|
||||
return f"{text[:max_length]}... [truncated, total length: {len(text)} chars]"
|
||||
|
||||
# Helper function to log full model responses
|
||||
def _log_full_response(model_type, model_name, power_name, response):
    """Log a model response in full at INFO level between two border lines.

    Nothing is logged when the response is empty or when the module-level
    SHOULD_LOG_FULL_RESPONSES flag is false. A missing power name is shown
    as 'Unknown'.
    """
    if response and SHOULD_LOG_FULL_RESPONSES:
        separator = "=" * 80
        speaker = power_name or 'Unknown'
        logger.info(f"\nMODEL RESPONSE | {model_type} | {model_name} | {speaker}\n{separator}")
        logger.info(f"{response}")
        logger.info(f"{separator}\n")
|
||||
|
||||
# Helper function to log prompt details
|
||||
def _log_prompt_details(model_type, model_name, power_name, prompt, system_prompt=None):
    """Log a bordered summary of a prompt.

    The header and closing border are logged at INFO. The header carries
    whitespace-split word counts, labelled as approximate tokens. The prompt
    text itself (and the system prompt, when given) is logged at DEBUG after
    being shortened by _truncate_text.

    Nothing is logged when the prompt is empty or when the module-level
    SHOULD_LOG_FULL_PROMPTS flag is false. A missing power name is shown
    as 'Unknown'.
    """
    if not (prompt and SHOULD_LOG_FULL_PROMPTS):
        return

    rule = "=" * 80
    user_words = len(prompt.split())
    # Shared header prefix; the system-prompt variant appends its own counts.
    header = f"\nPROMPT | {model_type} | {model_name} | {power_name or 'Unknown'} | ~{user_words} tokens"

    if system_prompt:
        system_words = len(system_prompt.split())
        logger.info(f"{header} (user) + ~{system_words} tokens (system)\n{rule}")
        logger.debug(f"System prompt: {_truncate_text(system_prompt)}")
    else:
        logger.info(f"{header}\n{rule}")

    logger.debug(f"User prompt: {_truncate_text(prompt)}")
    logger.info(f"{rule}\n")
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
|
|
@ -53,7 +131,7 @@ class BaseModelClient:
|
|||
try:
|
||||
self.system_prompt = load_prompt(f"{self.power_name.lower()}_system_prompt.txt")
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"No specific system prompt found for {self.power_name}; using default.")
|
||||
logger.warning(f"CONFIG | {self.model_name} | No specific system prompt for {self.power_name}, using default")
|
||||
self.system_prompt = load_prompt("system_prompt.txt")
|
||||
else:
|
||||
self.system_prompt = load_prompt("system_prompt.txt")
|
||||
|
|
@ -137,7 +215,7 @@ class BaseModelClient:
|
|||
possible_orders_text = format_possible_orders(game, possible_orders)
|
||||
|
||||
# 8) Convoy Paths
|
||||
logger.debug(f"convoy_paths_possible is: {game.convoy_paths_possible}")
|
||||
logger.debug(f"CONTEXT | {self.model_name} | {power_name} | Convoy paths: {len(game.convoy_paths_possible) if game.convoy_paths_possible else 0} available")
|
||||
convoy_paths_text = format_convoy_paths(game, game.convoy_paths_possible, power_name)
|
||||
|
||||
# 9) Threat Assessment
|
||||
|
|
@ -233,8 +311,8 @@ class BaseModelClient:
|
|||
|
||||
try:
|
||||
raw_response = self.generate_response(prompt)
|
||||
logger.info(
|
||||
f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}"
|
||||
logger.debug(
|
||||
f"ORDERS | {self.model_name} | {power_name} | Raw response: {_truncate_text(raw_response)}"
|
||||
)
|
||||
|
||||
# Attempt to parse the final "orders" from the LLM
|
||||
|
|
@ -242,7 +320,7 @@ class BaseModelClient:
|
|||
|
||||
if not move_list:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback."
|
||||
f"PARSE_ERROR | {self.model_name} | {power_name} | Failed to extract moves, using fallback"
|
||||
)
|
||||
if model_error_stats is not None:
|
||||
# forcibly convert sets to string
|
||||
|
|
@ -255,7 +333,7 @@ class BaseModelClient:
|
|||
return validated_moves
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.model_name}] LLM error for {power_name}: {e}")
|
||||
logger.error(f"LLM_ERROR | {self.model_name} | {power_name} | {str(e)}")
|
||||
if model_error_stats is not None:
|
||||
# forcibly convert sets to string
|
||||
model_name_for_stats = str(self.model_name)
|
||||
|
|
@ -280,7 +358,7 @@ class BaseModelClient:
|
|||
if not matches:
|
||||
# Some LLMs might not put the colon or might have triple backtick fences.
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns."
|
||||
f"PARSE | {self.model_name} | {power_name} | Regex #1 failed, trying alternative patterns"
|
||||
)
|
||||
|
||||
# 1b) Check for inline JSON after "PARSABLE OUTPUT"
|
||||
|
|
@ -289,16 +367,17 @@ class BaseModelClient:
|
|||
|
||||
if not matches:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying triple-backtick code fences."
|
||||
f"PARSE | {self.model_name} | {power_name} | Regex #2 failed, trying triple-backtick code fences"
|
||||
)
|
||||
|
||||
# 2) If still no match, check for triple-backtick code fences containing JSON
|
||||
if not matches:
|
||||
code_fence_pattern = r"```json\s*(\{.*?\})\s*```"
|
||||
matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] Found triple-backtick JSON block for {power_name}."
|
||||
f"PARSE | {self.model_name} | {power_name} | Found triple-backtick JSON block"
|
||||
)
|
||||
|
||||
# 3) Attempt to parse JSON if we found anything
|
||||
|
|
@ -310,7 +389,7 @@ class BaseModelClient:
|
|||
if captured.startswith("{{") and captured.endswith("}}"):
|
||||
# remove ONE leading '{' and ONE trailing '}'
|
||||
# so {{ "orders": [...] }} becomes { "orders": [...] }
|
||||
logger.debug(f"[{self.model_name}] Detected double braces for {power_name}, trimming extra braces.")
|
||||
logger.debug(f"PARSE | {self.model_name} | {power_name} | Detected double braces, trimming extra braces")
|
||||
# strip exactly one brace pair
|
||||
trimmed = captured[1:-1].strip()
|
||||
json_text = trimmed
|
||||
|
|
@ -323,7 +402,7 @@ class BaseModelClient:
|
|||
|
||||
if not json_text:
|
||||
logger.debug(
|
||||
f"[{self.model_name}] No JSON text found in LLM response for {power_name}."
|
||||
f"PARSE | {self.model_name} | {power_name} | No JSON text found in response"
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -333,7 +412,7 @@ class BaseModelClient:
|
|||
return data.get("orders", None)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying bracket fallback."
|
||||
f"PARSE | {self.model_name} | {power_name} | JSON decode failed: {str(e)[:100]}. Trying bracket fallback"
|
||||
)
|
||||
|
||||
# 3b) Attempt bracket fallback: we look for the substring after "orders"
|
||||
|
|
@ -346,10 +425,11 @@ class BaseModelClient:
|
|||
raw_list_str = "[" + bracket_match.group(1).strip() + "]"
|
||||
moves = ast.literal_eval(raw_list_str)
|
||||
if isinstance(moves, list):
|
||||
logger.debug(f"PARSE | {self.model_name} | {power_name} | Bracket fallback successful")
|
||||
return moves
|
||||
except Exception as e2:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}"
|
||||
f"PARSE | {self.model_name} | {power_name} | Bracket fallback failed: {str(e2)[:100]}"
|
||||
)
|
||||
|
||||
# If all attempts failed
|
||||
|
|
@ -361,12 +441,12 @@ class BaseModelClient:
|
|||
"""
|
||||
Filter out invalid moves, fill missing with HOLD, else fallback.
|
||||
"""
|
||||
logger.debug(f"[{self.model_name}] Proposed LLM moves: {moves}")
|
||||
logger.debug(f"VALIDATE | {self.model_name} | Validating {len(moves)} proposed moves")
|
||||
validated = []
|
||||
used_locs = set()
|
||||
|
||||
if not isinstance(moves, list):
|
||||
logger.debug(f"[{self.model_name}] Moves not a list, fallback.")
|
||||
logger.debug(f"VALIDATE | {self.model_name} | Moves not a list type, using fallback")
|
||||
return self.fallback_orders(possible_orders)
|
||||
|
||||
for move in moves:
|
||||
|
|
@ -378,7 +458,7 @@ class BaseModelClient:
|
|||
if len(parts) >= 2:
|
||||
used_locs.add(parts[1][:3])
|
||||
else:
|
||||
logger.debug(f"[{self.model_name}] Invalid move from LLM: {move_str}")
|
||||
logger.debug(f"VALIDATE | {self.model_name} | Invalid move: {move_str}")
|
||||
|
||||
# Fill missing with hold
|
||||
for loc, orders_list in possible_orders.items():
|
||||
|
|
@ -389,10 +469,10 @@ class BaseModelClient:
|
|||
)
|
||||
|
||||
if not validated:
|
||||
logger.warning(f"[{self.model_name}] All moves invalid, fallback.")
|
||||
logger.warning(f"VALIDATE | {self.model_name} | All moves invalid, using fallback")
|
||||
return self.fallback_orders(possible_orders)
|
||||
|
||||
logger.debug(f"[{self.model_name}] Validated moves: {validated}")
|
||||
logger.debug(f"VALIDATE | {self.model_name} | Final valid moves: {len(validated)}")
|
||||
return validated
|
||||
|
||||
def fallback_orders(self, possible_orders: Dict[str, List[str]]) -> List[str]:
|
||||
|
|
@ -461,7 +541,7 @@ class BaseModelClient:
|
|||
if not json_matches:
|
||||
# try normal
|
||||
logger.debug(
|
||||
f"[{self.model_name}] No JSON block found in LLM response for {power_name}. Trying double braces."
|
||||
f"CHAT | {self.model_name} | {power_name} | No JSON block, trying double braces"
|
||||
)
|
||||
json_matches = re.findall(
|
||||
r"PARSABLE OUTPUT:\s*\{(.*?)\}", raw_response, re.DOTALL
|
||||
|
|
@ -470,7 +550,7 @@ class BaseModelClient:
|
|||
if not json_matches:
|
||||
# try backtick fences
|
||||
logger.debug(
|
||||
f"[{self.model_name}] No JSON block found in LLM response for {power_name}. Trying backtick fences."
|
||||
f"CHAT | {self.model_name} | {power_name} | Trying backtick fences"
|
||||
)
|
||||
json_matches = re.findall(
|
||||
r"```json\n(.*?)\n```", raw_response, re.DOTALL
|
||||
|
|
@ -487,23 +567,24 @@ class BaseModelClient:
|
|||
message_type = message_data.get("message_type", "global")
|
||||
content = message_data.get("content", "").strip()
|
||||
recipient = message_data.get("recipient", GLOBAL)
|
||||
recipient = recipient.upper()
|
||||
|
||||
# Validate recipient if private message
|
||||
if message_type == "private" and recipient not in active_powers:
|
||||
logger.warning(
|
||||
f"Invalid recipient {recipient} for private message, defaulting to GLOBAL"
|
||||
f"CHAT | {self.model_name} | {power_name} | Invalid recipient '{recipient}', defaulting to GLOBAL"
|
||||
)
|
||||
recipient = GLOBAL
|
||||
|
||||
# For private messages, ensure recipient is specified
|
||||
if message_type == "private" and recipient == GLOBAL:
|
||||
logger.warning(
|
||||
"Private message without recipient specified, defaulting to GLOBAL"
|
||||
f"CHAT | {self.model_name} | {power_name} | Private message without recipient, defaulting to GLOBAL"
|
||||
)
|
||||
|
||||
# Log for debugging
|
||||
logger.info(
|
||||
f"Power {power_name} sends {message_type} message to {recipient}"
|
||||
f"CHAT | {self.model_name} | {power_name} | Sending {message_type} message to {recipient}"
|
||||
)
|
||||
|
||||
# Keep local record for building future conversation context
|
||||
|
|
@ -519,7 +600,7 @@ class BaseModelClient:
|
|||
message = None
|
||||
|
||||
except AttributeError:
|
||||
logger.error("Error parsing raw response")
|
||||
logger.error(f"CHAT | {self.model_name} | {power_name} | Error parsing raw response")
|
||||
|
||||
# Deduplicate messages
|
||||
messages = list(set([json.dumps(m) for m in messages]))
|
||||
|
|
@ -543,29 +624,34 @@ class OpenAIClient(BaseModelClient):
|
|||
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
def generate_response(self, prompt: str, empty_system: bool = False) -> str:
|
||||
# Updated to new API format
|
||||
try:
|
||||
system_content = self.system_prompt if not empty_system else ""
|
||||
logger.debug(f"API | OpenAI | {self.model_name} | Sending request")
|
||||
|
||||
_log_prompt_details("OpenAI", self.model_name, self.power_name, prompt, system_content)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": self.system_prompt if not empty_system else ""},
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
if not response or not hasattr(response, "choices") or not response.choices:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Empty or invalid result in generate_response. Returning empty."
|
||||
)
|
||||
if not response or not response.choices:
|
||||
logger.warning(f"API | OpenAI | {self.model_name} | Empty or invalid response")
|
||||
return ""
|
||||
return response.choices[0].message.content.strip()
|
||||
logger.debug(f"API | OpenAI | {self.model_name} | Received response of length {len(response.choices[0].message.content)}")
|
||||
content = response.choices[0].message.content.strip()
|
||||
_log_full_response("OpenAI", self.model_name, self.power_name, content)
|
||||
return content
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(
|
||||
f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
|
||||
f"API | OpenAI | {self.model_name} | JSON decode error: {str(json_err)[:100]}"
|
||||
)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
||||
f"API | OpenAI | {self.model_name} | Error: {str(e)[:150]}"
|
||||
)
|
||||
return ""
|
||||
|
||||
|
|
@ -582,27 +668,38 @@ class ClaudeClient(BaseModelClient):
|
|||
def generate_response(self, prompt: str, empty_system: bool = False) -> str:
|
||||
# Updated Claude messages format
|
||||
try:
|
||||
system_content = self.system_prompt if not empty_system else ""
|
||||
|
||||
_log_prompt_details("Claude", self.model_name, self.power_name, prompt, system_content)
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=self.model_name,
|
||||
max_tokens=2000,
|
||||
system=self.system_prompt if not empty_system else "",
|
||||
system=system_content,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
if not response.content:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Empty content in Claude generate_response. Returning empty."
|
||||
)
|
||||
if not response or not response.content:
|
||||
logger.warning(f"API | Claude | {self.model_name} | Empty or invalid response")
|
||||
return ""
|
||||
return response.content[0].text.strip() if response.content else ""
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(
|
||||
f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
|
||||
)
|
||||
return ""
|
||||
logger.debug(f"API | Claude | {self.model_name} | Received response of length {len(response.content)}")
|
||||
|
||||
# Handle the new response format which might be a list of TextBlock objects
|
||||
if isinstance(response.content, list):
|
||||
# Extract text from each TextBlock
|
||||
content = ""
|
||||
for block in response.content:
|
||||
if hasattr(block, 'text'):
|
||||
content += block.text
|
||||
elif isinstance(block, dict) and 'text' in block:
|
||||
content += block['text']
|
||||
logger.debug(f"API | Claude | {self.model_name} | Extracted text from {len(response.content)} TextBlocks")
|
||||
else:
|
||||
content = response.content
|
||||
|
||||
_log_full_response("Claude", self.model_name, self.power_name, content)
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
||||
)
|
||||
logger.error(f"API | Claude | {self.model_name} | Error: {str(e)[:150]}")
|
||||
return ""
|
||||
|
||||
|
||||
|
|
@ -616,24 +713,24 @@ class GeminiClient(BaseModelClient):
|
|||
self.client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
|
||||
|
||||
def generate_response(self, prompt: str, empty_system: bool = False) -> str:
|
||||
if empty_system:
|
||||
full_prompt = prompt
|
||||
else:
|
||||
full_prompt = self.system_prompt + prompt
|
||||
|
||||
try:
|
||||
system_content = self.system_prompt if not empty_system else ""
|
||||
logger.debug(f"API | Gemini | {self.model_name} | Sending request")
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model_name,
|
||||
contents=full_prompt,
|
||||
contents=system_content + prompt,
|
||||
)
|
||||
if not response or not response.text:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Empty Gemini generate_response. Returning empty."
|
||||
f"API | Gemini | {self.model_name} | Empty response"
|
||||
)
|
||||
return ""
|
||||
return response.text.strip()
|
||||
logger.debug(f"API | Gemini | {self.model_name} | Received response of length {len(response.text)}")
|
||||
content = response.text.strip()
|
||||
_log_full_response("Gemini", self.model_name, self.power_name, content)
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.model_name}] Error in Gemini generate_response: {e}")
|
||||
logger.error(f"API | Gemini | {self.model_name} | Error: {str(e)[:150]}")
|
||||
return ""
|
||||
|
||||
|
||||
|
|
@ -651,27 +748,34 @@ class DeepSeekClient(BaseModelClient):
|
|||
|
||||
def generate_response(self, prompt: str, empty_system: bool = False) -> str:
|
||||
try:
|
||||
system_content = self.system_prompt if not empty_system else ""
|
||||
logger.debug(f"API | DeepSeek | {self.model_name} | Sending request")
|
||||
|
||||
_log_prompt_details("DeepSeek", self.model_name, self.power_name, prompt, system_content)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": self.system_prompt if not empty_system else ""},
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
|
||||
logger.debug(f"API | DeepSeek | {self.model_name} | Received response")
|
||||
|
||||
if not response or not response.choices:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] No valid response in generate_response."
|
||||
f"API | DeepSeek | {self.model_name} | No valid response"
|
||||
)
|
||||
return ""
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
if not content:
|
||||
logger.warning(f"[{self.model_name}] DeepSeek returned empty content.")
|
||||
logger.warning(f"API | DeepSeek | {self.model_name} | Empty content")
|
||||
return ""
|
||||
|
||||
|
||||
_log_full_response("DeepSeek", self.model_name, self.power_name, content)
|
||||
|
||||
try:
|
||||
json_response = json.loads(content)
|
||||
required_fields = ["message_type", "content"]
|
||||
|
|
@ -679,13 +783,13 @@ class DeepSeekClient(BaseModelClient):
|
|||
required_fields.append("recipient")
|
||||
if not all(field in json_response for field in required_fields):
|
||||
logger.error(
|
||||
f"[{self.model_name}] Missing required fields in response: {content}"
|
||||
f"API | DeepSeek | {self.model_name} | Missing fields: {_truncate_text(content, 100)}"
|
||||
)
|
||||
return ""
|
||||
return content
|
||||
except JSONDecodeError:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Response is not valid JSON: {content}"
|
||||
f"API | DeepSeek | {self.model_name} | Invalid JSON: {_truncate_text(content, 100)}"
|
||||
)
|
||||
content = content.replace("'", '"')
|
||||
try:
|
||||
|
|
@ -696,7 +800,7 @@ class DeepSeekClient(BaseModelClient):
|
|||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
||||
f"API | DeepSeek | {self.model_name} | Error: {str(e)[:150]}"
|
||||
)
|
||||
return ""
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ def count_tokens(text: str) -> int:
|
|||
class ContextManager:
|
||||
"""
|
||||
Manages context size for Diplomacy game history and messages.
|
||||
Provides summarization functionality when context exceeds thresholds.
|
||||
Provides power-specific recursive summarization functionality
|
||||
when context exceeds thresholds.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -34,21 +35,37 @@ class ContextManager:
|
|||
self.message_token_threshold = message_token_threshold
|
||||
self.summary_model = summary_model
|
||||
|
||||
# Cache for summaries - prevents regenerating summaries unnecessarily
|
||||
self.phase_summary_cache = {}
|
||||
self.message_summary_cache = {}
|
||||
# Per-power tracking of summary states
|
||||
self.power_summary_states = {} # Indexed by power_name
|
||||
|
||||
# Track when we last generated summaries
|
||||
self.last_phase_summary_time = 0
|
||||
self.last_message_summary_time = 0
|
||||
# Shared phase summary state (phases are globally visible)
|
||||
self.phase_summary_state = {
|
||||
'last_summary': None, # The most recent summary of older phases
|
||||
'summarized_phases': [], # List of phase names that have been summarized
|
||||
'last_summary_time': 0, # When we last summarized
|
||||
}
|
||||
|
||||
# Cooldown period (seconds) - don't summarize more frequently than this
|
||||
self.summary_cooldown = 300 # 5 minutes
|
||||
|
||||
logger.debug(f"CONTEXT | Initialized manager with thresholds: phase={phase_token_threshold}, message={message_token_threshold}")
|
||||
|
||||
def load_summarization_prompts(self) -> Tuple[str, str]:
|
||||
def get_power_state(self, power_name):
|
||||
"""
|
||||
Load prompts for phase and message summarization.
|
||||
Returns tuple of (phase_prompt, message_prompt)
|
||||
Gets or initializes the summary state for a specific power.
|
||||
"""
|
||||
if power_name not in self.power_summary_states:
|
||||
self.power_summary_states[power_name] = {
|
||||
'last_message_summary': None, # The most recent message summary
|
||||
'summarized_messages': set(), # Set of message IDs that have been summarized
|
||||
'last_summary_time': 0, # When we last summarized messages for this power
|
||||
}
|
||||
return self.power_summary_states[power_name]
|
||||
|
||||
def load_summarization_prompts(self) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Load prompts for phase, message, and recursive summarization.
|
||||
Returns tuple of (phase_prompt, message_prompt, recursive_prompt)
|
||||
"""
|
||||
try:
|
||||
# Try to load from files
|
||||
|
|
@ -58,10 +75,14 @@ class ContextManager:
|
|||
with open("./ai_diplomacy/prompts/message_summary_prompt.txt", "r") as f:
|
||||
message_prompt = f.read().strip()
|
||||
|
||||
return phase_prompt, message_prompt
|
||||
with open("./ai_diplomacy/prompts/recursive_summary_prompt.txt", "r") as f:
|
||||
recursive_prompt = f.read().strip()
|
||||
|
||||
logger.debug("CONTEXT | Loaded summarization prompts from files")
|
||||
return phase_prompt, message_prompt, recursive_prompt
|
||||
except FileNotFoundError:
|
||||
# Return default prompts if files not found
|
||||
logger.warning("Summarization prompt files not found. Using defaults.")
|
||||
logger.warning("CONTEXT | Summarization prompt files not found, using default templates")
|
||||
|
||||
phase_prompt = """You are summarizing the history of a Diplomacy game.
|
||||
Create a concise summary that preserves all strategically relevant information about:
|
||||
|
|
@ -96,42 +117,136 @@ ORIGINAL MESSAGE HISTORY:
|
|||
{message_history}
|
||||
|
||||
SUMMARY:"""
|
||||
|
||||
recursive_prompt = """You are creating a recursive summary of a Diplomacy game's history.
|
||||
You have a previous summary of earlier events/messages and new content to incorporate.
|
||||
|
||||
Your task is to create a unified, seamless summary that:
|
||||
1. Preserves key strategic information from both sources
|
||||
2. Maintains chronological flow and logical structure
|
||||
3. Presents the most relevant information for decision-making
|
||||
4. Emphasizes developments in alliances, betrayals, and territorial control
|
||||
|
||||
PREVIOUS SUMMARY:
|
||||
{previous_summary}
|
||||
|
||||
NEW CONTENT:
|
||||
{new_content}
|
||||
|
||||
Create a unified summary that reads as a single coherent narrative while preserving critical strategic information:"""
|
||||
|
||||
return phase_prompt, message_prompt
|
||||
return phase_prompt, message_prompt, recursive_prompt
|
||||
|
||||
def generate_recursive_summary(self, previous_summary, new_content, prompt_type="recursive", power_name=None):
    """
    Build a summary that folds new content into an optional earlier summary.

    Args:
        previous_summary: Earlier summary text; may be empty or None
        new_content: Text to be incorporated into the summary
        prompt_type: "phase" or "message" select the matching initial prompt
            when there is no previous summary; every other case uses the
            recursive prompt
        power_name: Power name passed to the model loader and used in log lines

    Returns:
        The text returned by the summary model's generate_response call
    """
    phase_template, message_template, recursive_template = self.load_summarization_prompts()

    starting_fresh = not previous_summary

    if starting_fresh and prompt_type == "phase":
        # First-ever phase summary: no earlier summary to merge.
        prompt = phase_template.replace("{phase_history}", new_content)
        logger.debug(f"CONTEXT | SUMMARY | Creating initial phase summary with {len(new_content)} characters")
    elif starting_fresh and prompt_type == "message":
        # First-ever message summary for this power.
        prompt = message_template.replace("{message_history}", new_content)
        logger.debug(f"CONTEXT | SUMMARY | Creating initial message summary for {power_name} with {len(new_content)} characters")
    else:
        # Recursive case: merge the earlier summary (or a placeholder) with the new content.
        earlier_text = previous_summary or "(No previous summary)"
        prompt = recursive_template.replace("{previous_summary}", earlier_text).replace("{new_content}", new_content)
        logger.debug(f"CONTEXT | SUMMARY | Creating recursive {prompt_type} summary for {power_name or 'game'}")
        logger.debug(f"CONTEXT | SUMMARY | Previous summary: {len(previous_summary or '')} chars, New content: {len(new_content)} chars")

    # Ask the configured summary model for the result.
    summarizer = load_model_client(self.summary_model, power_name=power_name, emptysystem=True)
    result = summarizer.generate_response(prompt)

    logger.debug(f"CONTEXT | Generated {prompt_type} recursive summary ({len(result)} chars)")
    return result
|
||||
|
||||
def should_summarize_phases(self, phase_summaries: Dict[str, str]) -> bool:
|
||||
"""
|
||||
Determine if phase summaries need to be condensed based on token count
|
||||
and cooldown period.
|
||||
Determine if phase summaries need to be condensed based on token count,
|
||||
cooldown period, and new content since last summarization.
|
||||
"""
|
||||
# Check if we're in cooldown period
|
||||
current_time = time.time()
|
||||
if current_time - self.last_phase_summary_time < self.summary_cooldown:
|
||||
if current_time - self.phase_summary_state['last_summary_time'] < self.summary_cooldown:
|
||||
logger.debug("CONTEXT | Phase summarization skipped (in cooldown period)")
|
||||
return False
|
||||
|
||||
# Join all summaries to count total tokens
|
||||
all_text = "\n\n".join(phase_summaries.values())
|
||||
token_count = count_tokens(all_text)
|
||||
# Get unsummarized phase content
|
||||
unsummarized_phase_names = [p for p in phase_summaries.keys()
|
||||
if p not in self.phase_summary_state['summarized_phases']
|
||||
and not p.startswith("SUMMARY_UNTIL_")]
|
||||
|
||||
return token_count > self.phase_token_threshold
|
||||
# If we have a previous summary, count its tokens
|
||||
base_token_count = 0
|
||||
if self.phase_summary_state['last_summary']:
|
||||
base_token_count = count_tokens(self.phase_summary_state['last_summary'])
|
||||
|
||||
# Count tokens in unsummarized phases
|
||||
unsummarized_text = "\n\n".join([phase_summaries[phase] for phase in unsummarized_phase_names])
|
||||
unsummarized_token_count = count_tokens(unsummarized_text)
|
||||
|
||||
# Check if total exceeds threshold
|
||||
total_token_count = base_token_count + unsummarized_token_count
|
||||
should_summarize = total_token_count > self.phase_token_threshold
|
||||
|
||||
if should_summarize:
|
||||
logger.debug(f"CONTEXT | Phase token count ({total_token_count}) exceeds threshold ({self.phase_token_threshold}), will summarize")
|
||||
logger.debug(f"CONTEXT | Phase breakdown: {base_token_count} tokens from previous summary + {unsummarized_token_count} tokens from {len(unsummarized_phase_names)} new phases")
|
||||
|
||||
return should_summarize
|
||||
|
||||
def should_summarize_messages(self, message_history: str) -> bool:
|
||||
def should_summarize_messages(self, message_history: str, power_name: str) -> bool:
|
||||
"""
|
||||
Determine if message history needs to be condensed based on token count
|
||||
and cooldown period.
|
||||
Determine if message history for a specific power needs to be condensed
|
||||
based on token count and cooldown period.
|
||||
"""
|
||||
# Get power-specific state
|
||||
power_state = self.get_power_state(power_name)
|
||||
|
||||
# Check if we're in cooldown period
|
||||
current_time = time.time()
|
||||
if current_time - self.last_message_summary_time < self.summary_cooldown:
|
||||
if current_time - power_state['last_summary_time'] < self.summary_cooldown:
|
||||
logger.debug(f"CONTEXT | Message summarization for {power_name} skipped (in cooldown period)")
|
||||
return False
|
||||
|
||||
token_count = count_tokens(message_history)
|
||||
return token_count > self.message_token_threshold
|
||||
# If we have a previous summary, count its tokens
|
||||
base_token_count = 0
|
||||
if power_state['last_message_summary']:
|
||||
base_token_count = count_tokens(power_state['last_message_summary'])
|
||||
|
||||
# Count tokens in the new content
|
||||
new_token_count = count_tokens(message_history)
|
||||
|
||||
# Check if total exceeds threshold
|
||||
total_token_count = base_token_count + new_token_count
|
||||
should_summarize = total_token_count > self.message_token_threshold
|
||||
|
||||
if should_summarize:
|
||||
logger.debug(f"CONTEXT | Message token count for {power_name} ({total_token_count}) exceeds threshold ({self.message_token_threshold}), will summarize")
|
||||
logger.debug(f"CONTEXT | Message breakdown for {power_name}: {base_token_count} tokens from previous summary + {new_token_count} tokens from new messages")
|
||||
|
||||
return should_summarize
|
||||
|
||||
def summarize_phase_history(self, phase_summaries: Dict[str, str], power_name: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Create a condensed version of phase summaries.
|
||||
Keeps the most recent phases intact and summarizes older ones.
|
||||
Create a recursively updated summary of phase history.
|
||||
Keeps recent phases intact and summarizes older ones.
|
||||
|
||||
Returns a new dictionary with condensed history.
|
||||
"""
|
||||
|
|
@ -139,84 +254,142 @@ SUMMARY:"""
|
|||
return phase_summaries
|
||||
|
||||
# Mark summarization time
|
||||
self.last_phase_summary_time = time.time()
|
||||
self.phase_summary_state['last_summary_time'] = time.time()
|
||||
|
||||
# Sort phases chronologically
|
||||
# Sort phases chronologically
|
||||
sorted_phases = sorted(phase_summaries.keys())
|
||||
|
||||
# Keep the 3 most recent phases intact
|
||||
recent_phases = sorted_phases[-3:] if len(sorted_phases) > 3 else sorted_phases
|
||||
older_phases = sorted_phases[:-3] if len(sorted_phases) > 3 else []
|
||||
# Get unsummarized phase names
|
||||
unsummarized_phase_names = [p for p in sorted_phases
|
||||
if p not in self.phase_summary_state['summarized_phases']
|
||||
and not p.startswith("SUMMARY_UNTIL_")]
|
||||
|
||||
if not older_phases:
|
||||
# Keep the 3 most recent phases intact
|
||||
recent_phases = unsummarized_phase_names[-3:] if len(unsummarized_phase_names) > 3 else unsummarized_phase_names
|
||||
phases_to_summarize = [p for p in unsummarized_phase_names if p not in recent_phases]
|
||||
|
||||
if not phases_to_summarize:
|
||||
logger.debug("CONTEXT | No new phases to summarize")
|
||||
return phase_summaries # Nothing to summarize
|
||||
|
||||
# Get summarization prompt
|
||||
phase_prompt, _ = self.load_summarization_prompts()
|
||||
# Text to summarize: previous summary + new phases to summarize
|
||||
previous_summary = self.phase_summary_state['last_summary'] or ""
|
||||
|
||||
# Generate a summary of the older phases
|
||||
older_text = ""
|
||||
for phase in older_phases:
|
||||
older_text += f"PHASE {phase}:\n{phase_summaries[phase]}\n\n"
|
||||
new_content = ""
|
||||
for phase in phases_to_summarize:
|
||||
new_content += f"PHASE {phase}:\n{phase_summaries[phase]}\n\n"
|
||||
|
||||
# Check if we already have a cached summary for this exact text
|
||||
if older_text in self.phase_summary_cache:
|
||||
summary = self.phase_summary_cache[older_text]
|
||||
# Log before summarization
|
||||
logger.info(f"CONTEXT | PHASE SUMMARIZATION | Starting recursive summarization for {len(phases_to_summarize)} phases")
|
||||
logger.info(f"CONTEXT | PHASE SUMMARIZATION | Phases being summarized: {', '.join(phases_to_summarize)}")
|
||||
|
||||
# Generate recursive summary
|
||||
if previous_summary:
|
||||
# We have a previous summary, do recursive summarization
|
||||
logger.debug(f"CONTEXT | PHASE SUMMARIZATION | Performing recursive summarization with previous summary ({len(previous_summary)} chars)")
|
||||
summary = self.generate_recursive_summary(
|
||||
previous_summary,
|
||||
new_content,
|
||||
prompt_type="recursive",
|
||||
power_name=power_name
|
||||
)
|
||||
else:
|
||||
# Generate new summary
|
||||
summarization_client = load_model_client(self.summary_model, power_name=power_name, emptysystem=True)
|
||||
formatted_prompt = phase_prompt.replace("{phase_history}", older_text)
|
||||
summary = summarization_client.generate_response(formatted_prompt)
|
||||
|
||||
# Cache the result
|
||||
self.phase_summary_cache[older_text] = summary
|
||||
# No previous summary, do initial summarization
|
||||
logger.debug(f"CONTEXT | PHASE SUMMARIZATION | Performing initial phase summarization ({len(new_content)} chars)")
|
||||
summary = self.generate_recursive_summary(
|
||||
None,
|
||||
new_content,
|
||||
prompt_type="phase",
|
||||
power_name=power_name
|
||||
)
|
||||
|
||||
# Update phase summary state
|
||||
self.phase_summary_state['last_summary'] = summary
|
||||
self.phase_summary_state['summarized_phases'].extend(phases_to_summarize)
|
||||
|
||||
# Create new dictionary with summarized older phases and intact recent phases
|
||||
result = {}
|
||||
|
||||
# Add the summary as a special entry
|
||||
summary_key = f"SUMMARY_UNTIL_{older_phases[-1]}"
|
||||
result[summary_key] = summary
|
||||
if phases_to_summarize:
|
||||
last_summarized = max(phases_to_summarize)
|
||||
summary_key = f"SUMMARY_UNTIL_{last_summarized}"
|
||||
result[summary_key] = summary
|
||||
logger.info(f"CONTEXT | PHASE SUMMARIZATION | Created summary key '{summary_key}' ({len(summary)} chars)")
|
||||
|
||||
# Add the recent phases as-is
|
||||
for phase in recent_phases:
|
||||
result[phase] = phase_summaries[phase]
|
||||
|
||||
|
||||
logger.info(f"CONTEXT | PHASE SUMMARIZATION | Recursively condensed {len(phase_summaries)} phase entries to {len(result)}")
|
||||
logger.debug(f"CONTEXT | PHASE SUMMARIZATION | Result contains summary + {len(recent_phases)} intact recent phases")
|
||||
return result
|
||||
|
||||
def summarize_message_history(
|
||||
self,
|
||||
message_history: str,
|
||||
power_name: Optional[str] = None,
|
||||
power_name: str,
|
||||
organized_by_relationship: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Create a condensed version of message history.
|
||||
If organized_by_relationship is True, assumes the history is already
|
||||
organized by power relationships.
|
||||
Create a recursively updated summary of message history for a specific power.
|
||||
|
||||
Returns a condensed message history.
|
||||
Args:
|
||||
message_history: Current unsummarized message history
|
||||
power_name: The power whose history is being summarized
|
||||
organized_by_relationship: If True, assumes messages are organized by relationship
|
||||
|
||||
Returns:
|
||||
Updated message history with recursive summarization applied
|
||||
"""
|
||||
if not self.should_summarize_messages(message_history):
|
||||
if not self.should_summarize_messages(message_history, power_name):
|
||||
return message_history
|
||||
|
||||
# Get power-specific state
|
||||
power_state = self.get_power_state(power_name)
|
||||
|
||||
# Mark summarization time
|
||||
self.last_message_summary_time = time.time()
|
||||
power_state['last_summary_time'] = time.time()
|
||||
|
||||
# Get summarization prompt
|
||||
_, message_prompt = self.load_summarization_prompts()
|
||||
# Log before summarization
|
||||
logger.info(f"CONTEXT | MESSAGE SUMMARIZATION | Starting message summarization for {power_name}")
|
||||
logger.info(f"CONTEXT | MESSAGE SUMMARIZATION | Current message history size: {len(message_history)} chars")
|
||||
|
||||
# Check if we already have a cached summary for this exact text
|
||||
if message_history in self.message_summary_cache:
|
||||
return self.message_summary_cache[message_history]
|
||||
# Generate recursive summary
|
||||
previous_summary = power_state['last_message_summary']
|
||||
|
||||
if previous_summary:
|
||||
# We have a previous summary, do recursive summarization
|
||||
logger.debug(f"CONTEXT | MESSAGE SUMMARIZATION | Performing recursive message summarization for {power_name}")
|
||||
logger.debug(f"CONTEXT | MESSAGE SUMMARIZATION | Previous summary: {len(previous_summary)} chars, New messages: {len(message_history)} chars")
|
||||
|
||||
# Generate new summary
|
||||
summarization_client = load_model_client(self.summary_model, power_name=power_name, emptysystem=True)
|
||||
formatted_prompt = message_prompt.replace("{message_history}", message_history)
|
||||
summary = summarization_client.generate_response(formatted_prompt)
|
||||
summary = self.generate_recursive_summary(
|
||||
previous_summary,
|
||||
message_history,
|
||||
prompt_type="recursive",
|
||||
power_name=power_name
|
||||
)
|
||||
else:
|
||||
# No previous summary, do initial summarization
|
||||
logger.debug(f"CONTEXT | MESSAGE SUMMARIZATION | Performing initial message summarization for {power_name} ({len(message_history)} chars)")
|
||||
|
||||
summary = self.generate_recursive_summary(
|
||||
None,
|
||||
message_history,
|
||||
prompt_type="message",
|
||||
power_name=power_name
|
||||
)
|
||||
|
||||
# Cache the result
|
||||
self.message_summary_cache[message_history] = summary
|
||||
# Update power state
|
||||
power_state['last_message_summary'] = summary
|
||||
|
||||
# Track metrics for logging
|
||||
message_tokens = count_tokens(message_history)
|
||||
summary_tokens = count_tokens(summary)
|
||||
reduction = 100 - (summary_tokens * 100 / message_tokens) if message_tokens > 0 else 0
|
||||
|
||||
logger.info(f"CONTEXT | MESSAGE SUMMARIZATION | Completed for {power_name}: {message_tokens} → {summary_tokens} tokens ({reduction:.1f}% reduction)")
|
||||
logger.debug(f"CONTEXT | MESSAGE SUMMARIZATION | Original size: {len(message_history)} chars, Summary size: {len(summary)} chars")
|
||||
|
||||
return summary
|
||||
|
||||
|
|
@ -228,16 +401,34 @@ SUMMARY:"""
|
|||
"""
|
||||
Main access point for getting optimized phase summaries.
|
||||
If summaries are below threshold, returns original.
|
||||
Otherwise, returns condensed version.
|
||||
Otherwise, returns recursively condensed version.
|
||||
"""
|
||||
if not hasattr(game, "phase_summaries") or not game.phase_summaries:
|
||||
logger.debug("CONTEXT | No phase summaries available")
|
||||
return {}
|
||||
|
||||
logger.debug(f"CONTEXT | Checking phase optimization for {power_name or 'game'} with {len(game.phase_summaries)} phases")
|
||||
|
||||
if self.should_summarize_phases(game.phase_summaries):
|
||||
# Create condensed version
|
||||
return self.summarize_phase_history(game.phase_summaries, power_name)
|
||||
# Create condensed version using recursive summarization
|
||||
logger.debug(f"CONTEXT | Creating optimized phase summaries for {power_name or 'game'}")
|
||||
result = self.summarize_phase_history(game.phase_summaries, power_name)
|
||||
|
||||
# Add a log showing which phases are included in the optimized version
|
||||
phase_keys = list(result.keys())
|
||||
summary_keys = [k for k in phase_keys if k.startswith("SUMMARY_UNTIL_")]
|
||||
regular_phases = [k for k in phase_keys if not k.startswith("SUMMARY_UNTIL_")]
|
||||
|
||||
logger.info(f"CONTEXT | PHASE OPTIMIZATION | Returning {len(summary_keys)} summary entries and {len(regular_phases)} regular phases")
|
||||
if summary_keys:
|
||||
logger.debug(f"CONTEXT | PHASE OPTIMIZATION | Summary entries: {', '.join(summary_keys)}")
|
||||
if regular_phases:
|
||||
logger.debug(f"CONTEXT | PHASE OPTIMIZATION | Regular phases: {', '.join(regular_phases)}")
|
||||
|
||||
return result
|
||||
else:
|
||||
# Return original
|
||||
logger.debug("CONTEXT | Using original phase summaries (below threshold)")
|
||||
return game.phase_summaries
|
||||
|
||||
def get_optimized_message_history(
|
||||
|
|
@ -257,7 +448,11 @@ SUMMARY:"""
|
|||
Returns:
|
||||
Optimized message history as string
|
||||
"""
|
||||
# Get the raw message history
|
||||
if not power_name:
|
||||
logger.warning("CONTEXT | No power_name provided for message history optimization, using raw history")
|
||||
return organized_history or (game_history.get_game_history() if hasattr(game_history, "get_game_history") else str(game_history))
|
||||
|
||||
# Get the raw message history for this power
|
||||
if organized_history is not None:
|
||||
message_history = organized_history
|
||||
elif hasattr(game_history, "get_game_history"):
|
||||
|
|
@ -265,11 +460,28 @@ SUMMARY:"""
|
|||
else:
|
||||
message_history = str(game_history) if game_history else "(No history yet)"
|
||||
|
||||
if self.should_summarize_messages(message_history):
|
||||
# Create condensed version
|
||||
return self.summarize_message_history(message_history, power_name)
|
||||
if message_history == "(No history yet)":
|
||||
return message_history
|
||||
|
||||
logger.debug(f"CONTEXT | Checking message optimization for {power_name} with {len(message_history)} chars")
|
||||
|
||||
# Check if we need to create a recursive summary
|
||||
if self.should_summarize_messages(message_history, power_name):
|
||||
# Create recursively condensed version
|
||||
logger.debug(f"CONTEXT | Creating optimized message history for {power_name}")
|
||||
result = self.summarize_message_history(message_history, power_name)
|
||||
|
||||
# Log the optimization stats
|
||||
power_state = self.get_power_state(power_name)
|
||||
has_previous_summary = power_state['last_message_summary'] is not None
|
||||
|
||||
logger.info(f"CONTEXT | MESSAGE OPTIMIZATION | {power_name} | Original size: {len(message_history)} chars, Optimized size: {len(result)} chars")
|
||||
logger.info(f"CONTEXT | MESSAGE OPTIMIZATION | {power_name} | Using {'recursive' if has_previous_summary else 'initial'} message summary")
|
||||
|
||||
return result
|
||||
else:
|
||||
# Return original
|
||||
logger.debug(f"CONTEXT | Using original message history for {power_name} (below threshold)")
|
||||
return message_history
|
||||
|
||||
|
||||
|
|
@ -287,6 +499,7 @@ def configure_context_manager(
|
|||
Should be called early in the application lifecycle.
|
||||
"""
|
||||
global context_manager
|
||||
logger.info(f"CONTEXT | Configuring manager with thresholds: phase={phase_threshold}, message={message_threshold}, model={summary_model}")
|
||||
context_manager = ContextManager(
|
||||
phase_token_threshold=phase_threshold,
|
||||
message_token_threshold=message_threshold,
|
||||
|
|
@ -305,9 +518,18 @@ def get_optimized_context(
|
|||
Returns:
|
||||
Tuple of (optimized_phase_summaries, optimized_message_history)
|
||||
"""
|
||||
logger.debug(f"CONTEXT | Getting optimized context for {power_name or 'game'}")
|
||||
|
||||
optimized_phases = context_manager.get_optimized_phase_summaries(game, power_name)
|
||||
optimized_messages = context_manager.get_optimized_message_history(
|
||||
game_history, power_name, organized_history
|
||||
)
|
||||
|
||||
# Add a log entry showing what we're returning for tracking
|
||||
phase_count = len(optimized_phases) if optimized_phases else 0
|
||||
message_length = len(optimized_messages) if optimized_messages else 0
|
||||
summary_count = len([k for k in optimized_phases.keys() if k.startswith("SUMMARY_UNTIL_")]) if optimized_phases else 0
|
||||
|
||||
logger.info(f"CONTEXT | OPTIMIZATION RESULT | {power_name or 'game'} | Returning {phase_count} phases ({summary_count} summaries) and {message_length} chars of messages")
|
||||
|
||||
return optimized_phases, optimized_messages
|
||||
|
|
@ -21,12 +21,19 @@ def load_model_client(model_id: str, power_name: Optional[str] = None, emptysyst
|
|||
from .clients import ClaudeClient, GeminiClient, DeepSeekClient, OpenAIClient
|
||||
|
||||
lower_id = model_id.lower()
|
||||
|
||||
logger.debug(f"MODEL | Loading client for {model_id}{' for ' + power_name if power_name else ''}{' with empty system' if emptysystem else ''}")
|
||||
|
||||
if "claude" in lower_id:
|
||||
logger.debug(f"MODEL | Selected Claude client for {model_id}")
|
||||
return ClaudeClient(model_id, power_name, emptysystem=emptysystem)
|
||||
elif "gemini" in lower_id:
|
||||
logger.debug(f"MODEL | Selected Gemini client for {model_id}")
|
||||
return GeminiClient(model_id, power_name, emptysystem=emptysystem)
|
||||
elif "deepseek" in lower_id:
|
||||
logger.debug(f"MODEL | Selected DeepSeek client for {model_id}")
|
||||
return DeepSeekClient(model_id, power_name, emptysystem=emptysystem)
|
||||
else:
|
||||
# Default to OpenAI
|
||||
return OpenAIClient(model_id, power_name)
|
||||
logger.debug(f"MODEL | Selected OpenAI client for {model_id}")
|
||||
return OpenAIClient(model_id, power_name, emptysystem=emptysystem)
|
||||
|
|
@ -20,7 +20,7 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
Each power can send up to 'max_rounds' messages, choosing between private
|
||||
and global messages each turn.
|
||||
"""
|
||||
logger.info("Starting negotiation phase.")
|
||||
logger.info(f"DIPLOMACY | Starting negotiation phase with {max_rounds} rounds")
|
||||
|
||||
# Conversation messages are kept in a local list ONLY to build conversation_so_far text.
|
||||
conversation_messages = []
|
||||
|
|
@ -28,9 +28,12 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
active_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
||||
]
|
||||
|
||||
logger.debug(f"DIPLOMACY | Found {len(active_powers)} active powers for negotiations")
|
||||
|
||||
# We do up to 'max_rounds' single-message turns for each power
|
||||
for round_index in range(max_rounds):
|
||||
logger.debug(f"DIPLOMACY | Starting round {round_index+1}/{max_rounds}")
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(active_powers)
|
||||
) as executor:
|
||||
|
|
@ -40,10 +43,11 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
client = load_model_client(model_id)
|
||||
possible_orders = gather_possible_orders(game, power_name)
|
||||
if not possible_orders:
|
||||
logger.info(f"No orderable locations for {power_name}; skipping.")
|
||||
logger.info(f"DIPLOMACY | {power_name} | No orderable locations, skipping negotiation")
|
||||
continue
|
||||
board_state = game.get_state()
|
||||
|
||||
logger.debug(f"DIPLOMACY | {power_name} | Requesting conversation response from {model_id}")
|
||||
future = executor.submit(
|
||||
client.get_conversation_reply,
|
||||
game,
|
||||
|
|
@ -57,14 +61,21 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
)
|
||||
|
||||
futures[future] = power_name
|
||||
logger.debug(f"Submitted get_conversation_reply task for {power_name}.")
|
||||
|
||||
message_count = 0
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
power_name = futures[future]
|
||||
try:
|
||||
messages = future.result()
|
||||
if messages:
|
||||
logger.debug(f"DIPLOMACY | {power_name} | Generated {len(messages)} messages")
|
||||
for message in messages:
|
||||
recipient = message["recipient"]
|
||||
msg_type = "global" if recipient == GLOBAL else "private"
|
||||
truncated_content = message["content"][:50] + ("..." if len(message["content"]) > 50 else "")
|
||||
|
||||
logger.debug(f"DIPLOMACY | {power_name} → {recipient} | {msg_type.upper()} | {truncated_content}")
|
||||
|
||||
# Create an official message in the Diplomacy engine
|
||||
diplo_message = Message(
|
||||
phase=game.current_short_phase,
|
||||
|
|
@ -80,11 +91,20 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
message["content"],
|
||||
)
|
||||
conversation_messages.append(message)
|
||||
message_count += 1
|
||||
else:
|
||||
logger.debug(f"No valid messages returned for {power_name}.")
|
||||
logger.warning(f"DIPLOMACY | {power_name} | No valid messages generated")
|
||||
model_error_stats[power_name]["conversation_errors"] += 1
|
||||
except Exception as exc:
|
||||
logger.error(f"LLM request failed for {power_name}: {exc}")
|
||||
error_msg = str(exc)[:150]
|
||||
logger.error(f"DIPLOMACY | {power_name} | Request failed: {error_msg}")
|
||||
model_error_stats[power_name]["conversation_errors"] += 1
|
||||
|
||||
logger.debug(f"DIPLOMACY | Round {round_index+1} completed with {message_count} new messages")
|
||||
|
||||
logger.info("Negotiation phase complete.")
|
||||
total_messages = len(conversation_messages)
|
||||
global_msgs = sum(1 for m in conversation_messages if m["recipient"] == GLOBAL)
|
||||
private_msgs = total_messages - global_msgs
|
||||
|
||||
logger.info(f"DIPLOMACY | Negotiation complete: {total_messages} messages ({global_msgs} global, {private_msgs} private)")
|
||||
return conversation_messages
|
||||
|
|
|
|||
|
|
@ -9,15 +9,16 @@ Create a concise summary of the conversations between powers that preserves:
|
|||
Organize by relationships (e.g., FRANCE-GERMANY, ENGLAND-RUSSIA), prioritizing the most
|
||||
significant interactions. Include specific territory names mentioned.
|
||||
|
||||
In your summary, maintain all of the following critical diplomatic information:
|
||||
- Specific agreements about attacking or supporting certain territories
|
||||
- Promises of non-aggression and their scope/duration
|
||||
- Discussions about supply center control and transfers
|
||||
- Stated preferences about other powers (who they want to attack/support)
|
||||
- Explicit lies or deceptions that were revealed
|
||||
- Coordination of moves between powers
|
||||
The summary must reflect the actual diplomatic landscape accurately so players can make informed decisions.
|
||||
When summarizing, preserve:
|
||||
- Any explicit commitments to support or attack specific territories
|
||||
- Agreements about dividing Supply Centers or territories
|
||||
- Offers of non-aggression or alliance
|
||||
- Signals of intent or warnings about future actions
|
||||
- Discussions of shared enemies or mutual threats
|
||||
- Any deceptions or misleading statements that could indicate future betrayal
|
||||
|
||||
The summary must reflect the actual diplomatic landscape accurately so players can make informed decisions and remember past interactions that might influence current negotiations.
|
||||
Be precise about which powers communicated what information to whom, as this impacts trust and future decision-making.
|
||||
|
||||
ORIGINAL MESSAGE HISTORY:
|
||||
{message_history}
|
||||
|
|
|
|||
16
ai_diplomacy/prompts/recursive_summary_prompt.txt
Normal file
16
ai_diplomacy/prompts/recursive_summary_prompt.txt
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
You are creating a recursive summary of a Diplomacy game's history.
|
||||
You have a previous summary of earlier events/messages and new content to incorporate.
|
||||
|
||||
Your task is to create a unified, seamless summary that:
|
||||
1. Preserves key strategic information from both sources
|
||||
2. Maintains chronological flow and logical structure
|
||||
3. Presents the most relevant information for decision-making
|
||||
4. Emphasizes developments in alliances, betrayals, and territorial control
|
||||
|
||||
PREVIOUS SUMMARY:
|
||||
{previous_summary}
|
||||
|
||||
NEW CONTENT:
|
||||
{new_content}
|
||||
|
||||
Create a unified summary that reads as a single coherent narrative while preserving critical strategic information:
|
||||
273
ai_diplomacy/test_recursive_summarization.py
Normal file
273
ai_diplomacy/test_recursive_summarization.py
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to validate the recursive summarization functionality in long_story_short.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
logger = logging.getLogger("test_summarization")
|
||||
|
||||
# Add the parent directory to the path so we can import the module
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Import the module we want to test
|
||||
from ai_diplomacy.long_story_short import (
|
||||
ContextManager,
|
||||
configure_context_manager,
|
||||
get_optimized_context
|
||||
)
|
||||
|
||||
def test_phase_summarization():
    """
    Drive recursive phase summarization through three successive batches.

    A ContextManager with a phase token threshold of 200 is queried after
    each batch of phase summaries is added to a mock game object.

    Returns:
        Tuple of the optimized phase summaries obtained after the first,
        second and third batch, in that order.
    """
    logger.info("Testing phase summarization...")

    # A phase threshold of 200 tokens is low enough to force summarization
    manager = ContextManager(
        phase_token_threshold=200,
        message_token_threshold=1000,
        summary_model="o3-mini",
    )

    def optimize(game):
        # Every query in this test is issued on behalf of FRANCE
        return manager.get_optimized_phase_summaries(game, power_name="FRANCE")

    def log_summary_entries(phases):
        # Only the SUMMARY_UNTIL_* entries are printed, capped at 200 chars
        for name, text in phases.items():
            if not name.startswith("SUMMARY_UNTIL_"):
                continue
            logger.info(f"=== {name} ===")
            if len(text) > 200:
                logger.info(text[:200] + "...")
            else:
                logger.info(text)

    # Bare object that only carries a phase_summaries attribute
    game = type('MockGame', (), {})()
    game.phase_summaries = {
        "S1901M": "Spring 1901 Movement: France moved to Burgundy. Germany attacked Paris but was repelled. Russia and Turkey formed an alliance against Austria.",
        "F1901M": "Fall 1901 Movement: England took Norway. Russia captured Sweden. Italy moved into Tyrolia threatening Vienna.",
        "W1901A": "Winter 1901 Adjustments: France built F Brest. England built F London. Russia built A Moscow.",
        "S1902M": "Spring 1902 Movement: Germany attacked Burgundy with support from Munich. Italy took Vienna from Austria. Turkey moved into Armenia threatening Russia.",
    }

    # --- Batch 1: initial four phases ---
    first_pass = optimize(game)
    logger.info(f"Original phases: {len(game.phase_summaries)}")
    logger.info(f"Optimized phases: {len(first_pass)}")

    # --- Batch 2: three more phases, intended to trigger another round ---
    second_batch = {
        "F1902M": "Fall 1902 Movement: France retook Burgundy. England invaded St. Petersburg. Austria was eliminated by combined Italian and Turkish forces.",
        "W1902A": "Winter 1902 Adjustments: Russia disbanded A Warsaw. Italy built A Rome and F Naples. Turkey built F Smyrna.",
        "S1903M": "Spring 1903 Movement: Germany and France formed an alliance against England. Russia's position in the north collapsed.",
    }
    game.phase_summaries.update(second_batch)

    second_pass = optimize(game)
    logger.info(f"Updated original phases: {len(game.phase_summaries)}")
    logger.info(f"New optimized phases: {len(second_pass)}")
    log_summary_entries(second_pass)

    # --- Batch 3: intended to fold the previous summary into a new one ---
    third_batch = {
        "F1903M": "Fall 1903 Movement: Italy captured Tunis. France took Belgium with German support. England lost Edinburgh to a combined Russian and German attack.",
        "W1903A": "Winter 1903 Adjustments: England disbanded F London. France built A Paris. Germany built F Kiel.",
        "S1904M": "Spring 1904 Movement: Turkey attacked Italy in the Ionian Sea. Russia and Germany continued their assault on England. France moved towards Spain.",
    }
    game.phase_summaries.update(third_batch)

    third_pass = optimize(game)
    logger.info(f"Final original phases: {len(game.phase_summaries)}")
    logger.info(f"Final optimized phases: {len(third_pass)}")
    log_summary_entries(third_pass)

    return first_pass, second_pass, third_pass
|
||||
|
||||
def test_message_summarization():
    """
    Drive per-power message summarization through three growing histories.

    For each of FRANCE, GERMANY and ENGLAND the message history is optimized
    once, then twice more after appending repeated extra messages to the
    previous optimized output.

    Returns:
        Dict mapping each power name to a tuple of its three optimized
        message histories (initial, second, final).
    """
    logger.info("Testing message summarization...")

    # A message threshold of 200 tokens is low enough to force summarization
    manager = ContextManager(
        phase_token_threshold=1000,
        message_token_threshold=200,
        summary_model="o3-mini",
    )

    # Mock starting history for each power under test
    starting_history = {
        "FRANCE": "FROM: FRANCE, TO: GERMANY\nI propose we ally against England. I'll support your move to Belgium if you don't move to Burgundy.\n\nFROM: GERMANY, TO: FRANCE\nAgreed. I won't move to Burgundy. Let's coordinate against England.\n\nFROM: ENGLAND, TO: GLOBAL\nI'm looking for allies against France. Any takers?",
        "GERMANY": "FROM: FRANCE, TO: GERMANY\nI propose we ally against England. I'll support your move to Belgium if you don't move to Burgundy.\n\nFROM: GERMANY, TO: FRANCE\nAgreed. I won't move to Burgundy. Let's coordinate against England.\n\nFROM: GERMANY, TO: RUSSIA\nI suggest we avoid conflict in Sweden and focus on other directions.",
        "ENGLAND": "FROM: ENGLAND, TO: GLOBAL\nI'm looking for allies against France. Any takers?\n\nFROM: RUSSIA, TO: ENGLAND\nI could support you against France if you help me with Germany.\n\nFROM: ENGLAND, TO: RUSSIA\nThat works for me. I'll help you take Denmark if you support me into the English Channel.",
    }

    outcome = {}
    for power in ["FRANCE", "GERMANY", "ENGLAND"]:
        logger.info(f"Testing message summarization for {power}...")
        raw_history = starting_history[power]

        # Round 1: the starting history on its own
        first_pass = manager.get_optimized_message_history(raw_history, power)
        logger.info(f"Original message length: {len(raw_history)}")
        logger.info(f"Optimized message length: {len(first_pass)}")

        # Round 2: previous output plus the extra block repeated three times
        extra_block = f"\n\nFROM: {power}, TO: GLOBAL\nI declare that I am focusing on defense this turn.\n\nFROM: ITALY, TO: {power}\nI propose a mutual non-aggression pact."
        second_input = first_pass + extra_block * 3
        second_pass = manager.get_optimized_message_history(second_input, power)
        logger.info(f"Updated original + new message length: {len(second_input)}")
        logger.info(f"Recursive optimized message length: {len(second_pass)}")

        # Round 3: previous output plus another block repeated four times
        later_block = f"\n\nFROM: TURKEY, TO: {power}\nI suggest we coordinate our moves in the Mediterranean.\n\nFROM: {power}, TO: TURKEY\nI agree to non-aggression in the Mediterranean. Let's focus on other targets."
        third_input = second_pass + later_block * 4
        third_pass = manager.get_optimized_message_history(third_input, power)
        logger.info(f"Final combined message length: {len(third_input)}")
        logger.info(f"Final optimized message length: {len(third_pass)}")

        outcome[power] = (first_pass, second_pass, third_pass)

    return outcome
|
||||
|
||||
def test_with_game_integration():
    """
    Exercise get_optimized_context, the entry point the game itself uses.

    The global context manager is configured with thresholds of 200 so the
    small fixtures built here are enough to reach them. For each tested
    power the context is fetched twice: once with the initial fixtures and
    once after extra phase summaries and messages have been added. The
    outcome of each comparison is only logged; nothing is asserted.

    Returns:
        True, unconditionally, once every power has been processed.
    """
    logger.info("Testing integration with game context...")

    # Very low thresholds so the second call per power has a chance to summarize
    configure_context_manager(
        summary_model="o3-mini",
        message_threshold=200,
        phase_threshold=200,
    )

    # Bare stand-in for the game object: only phase_summaries is populated
    MockGame = type('MockGame', (), {})
    mock_game = MockGame()
    mock_game.phase_summaries = {
        "S1901M": "Spring 1901 Movement: France moved to Burgundy. Germany attacked Paris but was repelled.",
        "F1901M": "Fall 1901 Movement: England took Norway. Russia captured Sweden.",
        "W1901A": "Winter 1901 Adjustments: France built F Brest. England built F London.",
        "S1902M": "Spring 1902 Movement: Germany attacked Burgundy with support from Munich.",
    }

    # Both powers start from the same two-message exchange
    opening_exchange = (
        "FROM: FRANCE, TO: GERMANY\nI propose we ally against England.\n\n"
        "FROM: GERMANY, TO: FRANCE\nAgreed. Let's coordinate against England."
    )
    mock_messages = {name: opening_exchange for name in ("FRANCE", "GERMANY")}

    class MockGameHistory:
        """Minimal history stub that serves text out of mock_messages."""

        def get_game_history(self, power_name=None):
            # No power requested -> empty history; unknown power -> empty history
            return mock_messages.get(power_name, "") if power_name else ""

    mock_history = MockGameHistory()

    # Phase summaries merged in before the second call for every power
    later_phases = {
        "F1902M": "Fall 1902 Movement: France retook Burgundy. England invaded St. Petersburg.",
        "W1902A": "Winter 1902 Adjustments: Russia disbanded A Warsaw. Italy built A Rome.",
        "S1903M": "Spring 1903 Movement: Germany and France formed an alliance against England.",
    }

    for nation in ("FRANCE", "GERMANY"):
        logger.info(f"Testing integration for {nation}...")

        # First pass: fixtures are still small
        first_phases, first_messages = get_optimized_context(
            mock_game, mock_history, power_name=nation
        )
        logger.info(f"{nation} initial optimized_phases count: {len(first_phases)}")
        logger.info(f"{nation} initial optimized_messages length: {len(first_messages)}")

        # Grow both the phase summaries and this power's message history
        mock_game.phase_summaries.update(later_phases)
        declaration = f"\n\nFROM: {nation}, TO: GLOBAL\nI declare that I am focusing on defense this turn."
        mock_messages[nation] = mock_messages[nation] + declaration * 3

        # Second pass: expected to trigger summarization
        second_phases, second_messages = get_optimized_context(
            mock_game, mock_history, power_name=nation
        )

        logger.info(f"{nation} second optimized_phases count: {len(second_phases)}")
        if len(second_phases) >= len(mock_game.phase_summaries):
            logger.warning(f"❌ {nation} phase summarization did not occur as expected")
        else:
            logger.info(f"✅ {nation} phase summarization successful!")

        logger.info(f"{nation} second optimized_messages length: {len(second_messages)}")
        if len(second_messages) >= len(mock_messages[nation]):
            logger.warning(f"❌ {nation} message summarization did not occur as expected")
        else:
            logger.info(f"✅ {nation} message summarization successful!")

    return True
|
||||
|
||||
def main():
    """
    Run the three summarization checks in order and report on uniqueness.

    Calls test_phase_summarization, test_message_summarization and
    test_with_game_integration, then compares the second element of each
    power's message-test result across powers. The verdict is logged only;
    the function returns None either way.
    """
    logger.info("Starting recursive summarization tests...")

    # Order matters only in that it mirrors the log output: phases, messages, integration
    _phase_results = test_phase_summarization()
    per_power_results = test_message_summarization()
    _integration_ok = test_with_game_integration()

    logger.info("All tests completed!")

    # Final validation: different powers should not end up with the same summary.
    # Index 1 of each result tuple is the summary compared here.
    message_summaries = {}
    for power, outcome in per_power_results.items():
        message_summaries[power] = outcome[1]
    distinct_summaries = {summary for summary in message_summaries.values()}

    powers_tested = len(message_summaries)
    distinct_count = len(distinct_summaries)
    logger.info(f"Number of powers tested: {powers_tested}")
    logger.info(f"Number of unique message summaries: {distinct_count}")

    if distinct_count != powers_tested:
        logger.warning("❌ FAILURE: Some powers have identical message summaries!")
    else:
        logger.info("✅ SUCCESS: Each power has a unique message summary!")
|
||||
|
||||
# Run the summarization checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -28,7 +28,7 @@ def assign_models_to_powers(randomize=True):
|
|||
"o3-mini",
|
||||
"gemini-1.5-flash",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite-preview-02-05",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"gemini-1.5-pro",
|
||||
"gpt-4o-mini",
|
||||
"claude-3-5-haiku-20241022",
|
||||
|
|
@ -46,8 +46,10 @@ def assign_models_to_powers(randomize=True):
|
|||
model = random.choice(available_models)
|
||||
available_models.remove(model)
|
||||
result[power] = model
|
||||
logger.debug(f"CONFIG | Generated randomized power-model mapping for {len(POWERS)} powers")
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"CONFIG | Using fixed power-model mapping with {len(model_list)} models")
|
||||
return {
|
||||
power: model_list[i] for i, power in enumerate(POWERS)
|
||||
}
|
||||
|
|
@ -63,6 +65,9 @@ def gather_possible_orders(game, power_name):
|
|||
result = {}
|
||||
for loc in orderable_locs:
|
||||
result[loc] = all_possible.get(loc, [])
|
||||
|
||||
order_count = sum(len(orders) for orders in result.values())
|
||||
logger.debug(f"ORDERS | {power_name} | Found {len(result)} orderable locations with {order_count} total possible orders")
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -81,8 +86,11 @@ def get_valid_orders(
|
|||
If invalid, we append the error feedback to the conversation
|
||||
context for the next retry. If still invalid, return fallback.
|
||||
"""
|
||||
# Track invalid orders for feedback
|
||||
invalid_info = []
|
||||
|
||||
# Ask the LLM for orders
|
||||
logger.debug(f"ORDERS | {power_name} | Requesting orders from {client.model_name}")
|
||||
orders = client.get_orders(
|
||||
game=game,
|
||||
board_state=board_state,
|
||||
|
|
@ -115,12 +123,14 @@ def get_valid_orders(
|
|||
|
||||
if validity == 1:
|
||||
# All orders are fully valid
|
||||
logger.debug(f"ORDERS | {power_name} | Validated {len(orders)} orders successfully")
|
||||
return orders
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{power_name}] failed to produce a valid order, using fallback."
|
||||
f"ORDERS | {power_name} | Failed validation: '{move}' is invalid"
|
||||
)
|
||||
model_error_stats[power_name]["order_decoding_errors"] += 1
|
||||
logger.debug(f"ORDERS | {power_name} | Using fallback orders")
|
||||
fallback = client.fallback_orders(possible_orders)
|
||||
return fallback
|
||||
|
||||
|
|
@ -355,6 +365,13 @@ def format_possible_orders(game, possible_orders):
|
|||
for order, desc in strategic_orders["SUPPORT"]:
|
||||
output += f" {order} {desc}\n"
|
||||
|
||||
# Log order counts for debugging
|
||||
logger.debug(f"ORDERS | Strategic classification: " +
|
||||
f"Offensive: {len(strategic_orders['OFFENSIVE'])}, " +
|
||||
f"Defensive: {len(strategic_orders['DEFENSIVE'])}, " +
|
||||
f"Tactical: {len(strategic_orders['TACTICAL'])}, " +
|
||||
f"Support: {len(strategic_orders['SUPPORT'])}")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
|
@ -433,28 +450,42 @@ def format_convoy_paths(game, convoy_paths_possible, power_name):
|
|||
# check if convoy_paths_possible is empty dictionary or list or none
|
||||
output = ""
|
||||
if not convoy_paths_possible:
|
||||
output = "CONVOY POSSIBILITIES: None currently available.\n"
|
||||
return output
|
||||
return "CONVOY POSSIBILITIES: None currently available.\n"
|
||||
|
||||
|
||||
# Get unit ownership for identifying our convoys vs others
|
||||
# Get our units and all other powers' units
|
||||
our_units = set(game.get_units(power_name))
|
||||
our_unit_locs = {unit[2:5] for unit in our_units}
|
||||
|
||||
# Get all powers' units and centers for context
|
||||
power_units = {}
|
||||
power_centers = {}
|
||||
for pwr in game.powers:
|
||||
power_units[pwr] = {unit[2:5] for unit in game.get_units(pwr)}
|
||||
power_centers[pwr] = set(game.get_centers(pwr))
|
||||
|
||||
# Group convoys by region and relevance
|
||||
# Organize convoys by strategic relationship
|
||||
convoys = {
|
||||
"YOUR ARMY CONVOYS": [], # Convoys using your armies
|
||||
"YOUR FLEET CONVOYS": [], # Convoys using your fleets
|
||||
"ENEMY CONVOYS": [] # Convoys you should watch for
|
||||
"YOUR CONVOYS": [], # Convoys using your armies
|
||||
"CONVOYS YOU CAN ENABLE": [], # Using your fleets to help others
|
||||
"ALLIED OPPORTUNITIES": [], # Convoys that could help contain common enemies
|
||||
"THREATS TO WATCH": [] # Convoys that could threaten your positions
|
||||
}
|
||||
|
||||
# Define major sea regions for better organization
|
||||
# Make sea regions more readable
|
||||
sea_regions = {
|
||||
'NTH': "North Sea",
|
||||
'MAO': "Mid-Atlantic",
|
||||
'TYS': "Tyrrhenian Sea",
|
||||
'BLA': "Black Sea",
|
||||
'SKA': "Skagerrak",
|
||||
'ION': "Ionian Sea",
|
||||
'EAS': "Eastern Mediterranean",
|
||||
'WES': "Western Mediterranean",
|
||||
'BAL': "Baltic Sea",
|
||||
'BOT': "Gulf of Bothnia",
|
||||
'ADR': "Adriatic Sea",
|
||||
'AEG': "Aegean Sea",
|
||||
'ENG': "English Channel"
|
||||
}
|
||||
|
||||
for start, fleets, destinations in convoy_paths_possible:
|
||||
|
|
@ -462,38 +493,77 @@ def format_convoy_paths(game, convoy_paths_possible, power_name):
|
|||
if not destinations or not fleets:
|
||||
continue
|
||||
|
||||
# Determine if this is our army that could be convoyed
|
||||
is_our_army = start in our_unit_locs
|
||||
|
||||
# Determine if these are our fleets that could convoy
|
||||
is_our_fleet = any(fleet_loc in our_unit_locs for fleet_loc in fleets)
|
||||
# Identify the power that owns the army at start (if any)
|
||||
army_owner = None
|
||||
for pwr, locs in power_units.items():
|
||||
if start in locs:
|
||||
army_owner = pwr
|
||||
break
|
||||
|
||||
# Determine if we own any of the required fleets
|
||||
our_fleet_count = sum(1 for fleet_loc in fleets if fleet_loc in our_unit_locs)
|
||||
|
||||
# Format the fleet path nicely
|
||||
fleet_path = " + ".join(f"{sea_regions.get(f, f)}" for f in fleets)
|
||||
fleet_path = " + ".join(sea_regions.get(f, f) for f in fleets)
|
||||
|
||||
# Create a list of destinations with context
|
||||
for dest in destinations:
|
||||
# Get destination owner if any
|
||||
dest_owner = None
|
||||
for pwr, centers in power_centers.items():
|
||||
if dest in centers:
|
||||
dest_owner = pwr
|
||||
break
|
||||
|
||||
# Determine if destination is a supply center
|
||||
is_sc = dest in game.map.scs
|
||||
sc_note = " (SC)" if is_sc else ""
|
||||
|
||||
# Create the basic convoy description
|
||||
# Create base convoy description
|
||||
convoy_desc = f"A {start} -> {dest}{sc_note} via {fleet_path}"
|
||||
|
||||
# Add strategic notes
|
||||
if is_our_army:
|
||||
category = "YOUR ARMY CONVOYS"
|
||||
convoys[category].append(f"{convoy_desc}")
|
||||
elif is_our_fleet:
|
||||
category = "YOUR FLEET CONVOYS"
|
||||
convoys[category].append(f"{convoy_desc} (you provide the convoy)")
|
||||
|
||||
# Add strategic context based on relationships
|
||||
if army_owner == power_name:
|
||||
category = "YOUR CONVOYS"
|
||||
if dest_owner:
|
||||
note = f"attack {dest_owner}'s position"
|
||||
else:
|
||||
note = "gain strategic position" if not is_sc else "capture neutral SC"
|
||||
convoys[category].append(f"{convoy_desc} ({note})")
|
||||
|
||||
elif our_fleet_count > 0:
|
||||
category = "CONVOYS YOU CAN ENABLE"
|
||||
# Add diplomatic context
|
||||
if army_owner:
|
||||
if dest_owner == power_name:
|
||||
note = f"WARNING: {army_owner} could attack your SC"
|
||||
else:
|
||||
note = f"help {army_owner} attack {dest_owner or 'neutral'} position"
|
||||
else:
|
||||
note = "potential diplomatic bargaining chip"
|
||||
convoys[category].append(f"{convoy_desc} ({note})")
|
||||
|
||||
else:
|
||||
category = "ENEMY CONVOYS"
|
||||
convoys[category].append(f"{convoy_desc} (possible enemy convoy)")
|
||||
# Analyze if this convoy represents opportunity or threat
|
||||
if dest_owner == power_name:
|
||||
category = "THREATS TO WATCH"
|
||||
note = f"{army_owner or 'potential'} attack on your position"
|
||||
elif army_owner and dest_owner:
|
||||
category = "ALLIED OPPORTUNITIES"
|
||||
note = f"{army_owner} could attack {dest_owner} - potential alliance"
|
||||
else:
|
||||
category = "ALLIED OPPORTUNITIES"
|
||||
note = "potential diplomatic leverage"
|
||||
|
||||
convoys[category].append(f"{convoy_desc} ({note})")
|
||||
|
||||
# Format output
|
||||
output = "CONVOY POSSIBILITIES:\n\n"
|
||||
|
||||
# Log convoy counts for debugging
|
||||
convoy_counts = {category: len(convoys[category]) for category in convoys}
|
||||
logger.debug(f"CONVOYS | {power_name} | Counts: " +
|
||||
", ".join(f"{category}: {count}" for category, count in convoy_counts.items()))
|
||||
|
||||
for category, convoy_list in convoys.items():
|
||||
if convoy_list:
|
||||
output += f"{category}:\n"
|
||||
|
|
@ -538,8 +608,12 @@ def generate_threat_assessment(game, board_state, power_name):
|
|||
output = "THREAT ASSESSMENT:\n"
|
||||
if not threats:
|
||||
output += " No immediate threats detected.\n\n"
|
||||
logger.debug(f"THREATS | {power_name} | No immediate threats detected")
|
||||
return output
|
||||
|
||||
# Log threat counts for debugging
|
||||
logger.debug(f"THREATS | {power_name} | Detected {len(threats)} threats from {len(set(t[0] for t in threats))} powers")
|
||||
|
||||
for (enemy_pwr, code, targets) in threats:
|
||||
output += f" {enemy_pwr}'s {code} threatens {', '.join(targets)}\n"
|
||||
output += "\n"
|
||||
|
|
@ -606,4 +680,13 @@ def generate_sc_projection(game, board_state, power_name):
|
|||
best_case = len(our_centers) + len(neutral_gains) + len(contestable)
|
||||
worst_case = len(our_centers) - len(at_risk)
|
||||
output += f" Next-phase range: {worst_case} to {best_case} centers\n\n"
|
||||
|
||||
# Log SC projection for debugging
|
||||
logger.debug(f"SC_PROJ | {power_name} | " +
|
||||
f"Current: {len(our_centers)}, " +
|
||||
f"Neutral gains: {len(neutral_gains)}, " +
|
||||
f"Contestable: {len(contestable)}, " +
|
||||
f"At risk: {len(at_risk)}, " +
|
||||
f"Range: {worst_case}-{best_case}")
|
||||
|
||||
return output
|
||||
|
|
|
|||
123
lm_game.py
123
lm_game.py
|
|
@ -22,9 +22,11 @@ from ai_diplomacy.utils import (
|
|||
from ai_diplomacy.negotiations import conduct_negotiations
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
from ai_diplomacy.long_story_short import configure_context_manager
|
||||
from ai_diplomacy.clients import configure_logging
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Configure logger with a consistent format
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -32,12 +34,23 @@ logging.basicConfig(
|
|||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
# Quiet the HTTP/SDK libraries: only warnings and above from these loggers
for _noisy_logger_name in ("httpx", "httpcore", "urllib3", "anthropic", "openai"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)

# Keep the application's own loggers at INFO
for _app_logger_name in ("client", "ai_diplomacy"):
    logging.getLogger(_app_logger_name).setLevel(logging.INFO)
|
||||
|
||||
|
||||
def my_summary_callback(system_prompt, user_prompt, model_name):
    """
    Generate a phase summary with the model named by ``model_name``.

    Args:
        system_prompt: Text placed first in the combined prompt.
        user_prompt: Text placed after the system prompt, separated by a blank line.
        model_name: Identifier passed to load_model_client to pick the client.

    Returns:
        Whatever the client's generate_response returns for the combined prompt.
    """
    # The client is loaded with an empty system prompt, so both prompts are
    # merged into one message and sent together.
    summary_client = load_model_client(model_name, emptysystem=True)
    merged_prompt = "{}\n\n{}".format(system_prompt, user_prompt)

    logger.debug(f"SUMMARY | Requesting phase summary from {model_name}")
    response = summary_client.generate_response(merged_prompt, empty_system=True)
    return response
|
||||
|
||||
|
||||
|
|
@ -78,6 +91,29 @@ def parse_arguments():
|
|||
"The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY."
|
||||
),
|
||||
)
|
||||
# Logging configuration options
|
||||
parser.add_argument(
|
||||
"--log_full_prompts",
|
||||
action="store_true",
|
||||
help="Log the full prompts sent to models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_full_responses",
|
||||
action="store_true",
|
||||
help="Log the full responses from models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Enable verbose logging including HTTP connection details",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_level",
|
||||
type=str,
|
||||
default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the logging level",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
|
@ -101,7 +137,7 @@ def save_game_state(game, result_folder, game_file_path, model_error_stats, args
|
|||
output_path = game_file_path
|
||||
# If final file exists, append timestamp
|
||||
if os.path.exists(output_path):
|
||||
logger.info("Game file already exists, saving with unique filename.")
|
||||
logger.info("STORAGE | Final game file already exists, saving with unique timestamp")
|
||||
output_path = f"{output_path}_{timestamp}.json"
|
||||
|
||||
# Save game state
|
||||
|
|
@ -114,11 +150,21 @@ def save_game_state(game, result_folder, game_file_path, model_error_stats, args
|
|||
overview_file.write(json.dumps(game.power_model_map) + "\n")
|
||||
overview_file.write(json.dumps(vars(args)) + "\n")
|
||||
|
||||
logger.info(f"Saved game checkpoint to: {output_path}")
|
||||
logger.info(f"STORAGE | Game checkpoint saved to: {output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
# Configure logging
|
||||
log_level = getattr(logging, args.log_level)
|
||||
configure_logging(
|
||||
log_full_prompts=args.log_full_prompts,
|
||||
log_full_responses=args.log_full_responses,
|
||||
suppress_connection_logs=not args.verbose,
|
||||
log_level=log_level
|
||||
)
|
||||
|
||||
# Configure the context manager with the same summary model
|
||||
configure_context_manager(
|
||||
phase_threshold=10000,
|
||||
|
|
@ -128,7 +174,7 @@ def main():
|
|||
max_year = args.max_year
|
||||
summary_model = args.summary_model
|
||||
|
||||
logger.info("Starting a new Diplomacy game for testing with multiple LLMs, now concurrent!")
|
||||
logger.info("GAME_START | Initializing Diplomacy game with multiple LLM agents")
|
||||
start_whole = time.time()
|
||||
|
||||
model_error_stats = defaultdict(
|
||||
|
|
@ -153,12 +199,20 @@ def main():
|
|||
# ---------------------------
|
||||
log_file_path = os.path.join(result_folder, "game.log")
|
||||
file_handler = logging.FileHandler(log_file_path)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setLevel(logging.DEBUG) # Ensure we capture all levels in the file
|
||||
file_handler.setFormatter(
|
||||
logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s", datefmt="%H:%M:%S")
|
||||
)
|
||||
logger.addHandler(file_handler)
|
||||
logger.info(f"File handler added. Writing logs to {log_file_path}.")
|
||||
|
||||
# Add the handler to root logger to capture all modules' logs
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
# Also add to specific loggers we care about most for summarization
|
||||
logging.getLogger("ai_diplomacy.long_story_short").addHandler(file_handler)
|
||||
logging.getLogger("ai_diplomacy.long_story_short").setLevel(logging.DEBUG)
|
||||
|
||||
logger.info(f"LOGGING | File handler configured to write logs to {log_file_path}")
|
||||
logger.info(f"LOGGING | Capturing detailed context management logs at DEBUG level")
|
||||
|
||||
# File paths
|
||||
manifesto_path = f"{result_folder}/game_manifesto.txt"
|
||||
|
|
@ -181,16 +235,16 @@ def main():
|
|||
provided_models = [name.strip() for name in args.models.split(",")]
|
||||
if len(provided_models) != len(powers_order):
|
||||
logger.error(
|
||||
f"Expected {len(powers_order)} models for --power-models but got {len(provided_models)}. Exiting."
|
||||
f"CONFIG_ERROR | Expected {len(powers_order)} models in --models argument but got {len(provided_models)}. Exiting."
|
||||
)
|
||||
return
|
||||
game.power_model_map = dict(zip(powers_order, provided_models))
|
||||
else:
|
||||
game.power_model_map = assign_models_to_powers(randomize=True)
|
||||
|
||||
logger.debug("Power model assignments:")
|
||||
logger.debug("POWERS | Model assignments:")
|
||||
for power, model_id in game.power_model_map.items():
|
||||
logger.debug(f"{power} => type={type(model_id)}, value={model_id}")
|
||||
logger.debug(f"POWERS | {power} assigned to {model_id}")
|
||||
|
||||
# Also, if you prefer to fix the negotiation function:
|
||||
# We could do a one-liner ensuring all model_id are strings:
|
||||
|
|
@ -198,7 +252,7 @@ def main():
|
|||
if not isinstance(game.power_model_map[p], str):
|
||||
game.power_model_map[p] = str(game.power_model_map[p])
|
||||
|
||||
logger.info("Post-cleanup: Verified all power model IDs are strings.")
|
||||
logger.debug("POWERS | Verified all power model IDs are strings")
|
||||
|
||||
round_counter = 0 # Track number of rounds
|
||||
|
||||
|
|
@ -206,28 +260,29 @@ def main():
|
|||
phase_start = time.time()
|
||||
current_phase = game.get_current_phase()
|
||||
logger.info(
|
||||
f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)"
|
||||
f"PHASE | {current_phase} | Starting (elapsed game time: {phase_start - start_whole:.2f}s)"
|
||||
)
|
||||
|
||||
# DEBUG: Print the short phase to confirm
|
||||
logger.info(f"INFO: The current short phase is '{game.current_short_phase}'")
|
||||
# Get the current short phase
|
||||
logger.debug(f"PHASE | Current short phase: '{game.current_short_phase}'")
|
||||
|
||||
# Prevent unbounded simulation based on year
|
||||
year_str = current_phase[1:5]
|
||||
year_int = int(year_str)
|
||||
if year_int > max_year:
|
||||
logger.info(f"Reached year {year_int}, stopping the test game early.")
|
||||
logger.info(f"GAME_END | Reached year limit ({year_int} > {max_year}), terminating game")
|
||||
break
|
||||
|
||||
# If it's a movement phase (e.g. ends with "M"), conduct negotiations
|
||||
if game.current_short_phase.endswith("M"):
|
||||
logger.info("Starting negotiation phase block...")
|
||||
logger.info(f"NEGOTIATIONS | {current_phase} | Starting diplomacy round")
|
||||
conversation_messages = conduct_negotiations(
|
||||
game,
|
||||
game_history,
|
||||
model_error_stats,
|
||||
max_rounds=args.num_negotiation_rounds,
|
||||
)
|
||||
logger.debug(f"NEGOTIATIONS | {current_phase} | Completed with {len(conversation_messages)} messages")
|
||||
else:
|
||||
conversation_messages = []
|
||||
|
||||
|
|
@ -237,6 +292,8 @@ def main():
|
|||
for p_name, p_obj in game.powers.items()
|
||||
if not p_obj.is_eliminated()
|
||||
]
|
||||
|
||||
logger.info(f"ORDERS | {current_phase} | Requesting orders from {len(active_powers)} active powers")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(active_powers)
|
||||
|
|
@ -247,7 +304,7 @@ def main():
|
|||
client = load_model_client(model_id, power_name=power_name)
|
||||
possible_orders = gather_possible_orders(game, power_name)
|
||||
if not possible_orders:
|
||||
logger.info(f"No orderable locations for {power_name}; skipping.")
|
||||
logger.info(f"ORDERS | {power_name} | No orderable locations, skipping")
|
||||
continue
|
||||
board_state = game.get_state()
|
||||
|
||||
|
|
@ -263,24 +320,22 @@ def main():
|
|||
model_error_stats,
|
||||
)
|
||||
futures[future] = power_name
|
||||
logger.debug(f"Submitted get_valid_orders task for {power_name}.")
|
||||
logger.debug(f"ORDERS | {power_name} | Requested orders from {model_id}")
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
p_name = futures[future]
|
||||
try:
|
||||
orders = future.result()
|
||||
logger.debug(f"Validated orders for {p_name}: {orders}")
|
||||
if orders:
|
||||
logger.debug(f"ORDERS | {p_name} | Received {len(orders)} valid orders")
|
||||
game.set_orders(p_name, orders)
|
||||
logger.debug(
|
||||
f"Set orders for {p_name} in {game.current_short_phase}: {orders}"
|
||||
)
|
||||
logger.debug(f"ORDERS | {p_name} | Orders set for {game.current_short_phase}")
|
||||
else:
|
||||
logger.debug(f"No valid orders returned for {p_name}.")
|
||||
logger.warning(f"ORDERS | {p_name} | No valid orders returned")
|
||||
except Exception as exc:
|
||||
logger.error(f"LLM request failed for {p_name}: {exc}")
|
||||
logger.error(f"ORDERS | {p_name} | Request failed: {str(exc)[:150]}")
|
||||
|
||||
logger.info("Processing orders...\n")
|
||||
logger.info(f"PROCESSING | {current_phase} | Processing orders")
|
||||
# Pass the summary model to the callback via a lambda function
|
||||
phase_data = game.process(
|
||||
phase_summary_callback=lambda sys, usr: my_summary_callback(
|
||||
|
|
@ -307,42 +362,46 @@ def main():
|
|||
game.order_history[current_phase][power_name],
|
||||
results,
|
||||
)
|
||||
logger.info("Phase complete.\n")
|
||||
logger.info(f"PROCESSING | {current_phase} | Phase completed")
|
||||
|
||||
# Retrieve and log the summary of the phase
|
||||
summary_text = phase_data.summary or "(No summary found.)"
|
||||
border = "=" * 80
|
||||
logger.info(
|
||||
f"{border}\nPHASE SUMMARY for {phase_data.name}:\n{summary_text}\n{border}"
|
||||
f"SUMMARY | {phase_data.name} | Phase summary: {len(summary_text)} chars"
|
||||
)
|
||||
logger.debug(f"SUMMARY | {phase_data.name} | Full text:\n{border}\n{summary_text}\n{border}")
|
||||
|
||||
# Append the summary to the manifesto file
|
||||
with open(manifesto_path, "a") as f:
|
||||
f.write(f"=== {phase_data.name} ===\n{summary_text}\n\n")
|
||||
|
||||
phase_duration = time.time() - phase_start
|
||||
logger.debug(f"PHASE | {current_phase} | Completed in {phase_duration:.2f}s")
|
||||
|
||||
# Increment round counter after processing each phase
|
||||
round_counter += 1
|
||||
|
||||
# Save every 5 rounds
|
||||
if round_counter % 5 == 0:
|
||||
logger.info(f"Saving checkpoint after round {round_counter}...")
|
||||
logger.info(f"CHECKPOINT | Saving after round {round_counter}")
|
||||
save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=False)
|
||||
|
||||
# Check if we've exceeded the max year
|
||||
year_str = current_phase[1:5]
|
||||
year_int = int(year_str)
|
||||
if year_int > max_year:
|
||||
logger.info(f"Reached year {year_int}, stopping the test game early.")
|
||||
logger.info(f"GAME_END | Reached year limit ({year_int} > {max_year}), terminating game")
|
||||
break
|
||||
|
||||
# Save final result
|
||||
duration = time.time() - start_whole
|
||||
logger.info(f"Game ended after {duration:.2f}s. Saving final state...")
|
||||
logger.info(f"GAME_END | Duration: {duration:.2f}s | Saving final state")
|
||||
|
||||
save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=True)
|
||||
|
||||
logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}")
|
||||
logger.info("Done.")
|
||||
logger.info(f"STORAGE | Game data saved in: {result_folder}")
|
||||
logger.info("GAME_END | Simulation complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue