mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
Added country-specific prompts and more async I/O to speed things up
This commit is contained in:
parent
9fc25f2fec
commit
3b5f3015c1
7 changed files with 225 additions and 81 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -165,3 +165,4 @@ analysis_summary_debug.txt
|
|||
|
||||
./results_alpha
|
||||
/results_alpha/20250607_222757
|
||||
/ai_diplomacy/prompts/famous_leaders_prompts
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import re
|
|||
import json_repair
|
||||
import json5 # More forgiving JSON parser
|
||||
import ast
|
||||
import asyncio
|
||||
|
||||
from config import config
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from together.error import APIError as TogetherAPIError # For specific error ha
|
|||
|
||||
from config import config
|
||||
from .game_history import GameHistory
|
||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
|
||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, log_llm_response_async, generate_random_seed, get_prompt_path
|
||||
|
||||
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
||||
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
|
||||
|
|
@ -52,6 +52,7 @@ class BaseModelClient:
|
|||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||
self.model_name = model_name
|
||||
self.prompts_dir = prompts_dir
|
||||
logger.info(f"[{model_name}] BaseModelClient initialized with prompts_dir: {prompts_dir}")
|
||||
# Load a default initially, can be overwritten by set_system_prompt
|
||||
self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir)
|
||||
self.max_tokens = 16000 # default unless overridden
|
||||
|
|
@ -180,7 +181,7 @@ class BaseModelClient:
|
|||
finally:
|
||||
# Log the attempt regardless of outcome
|
||||
if log_file_path: # Only log if a path is provided
|
||||
log_llm_response(
|
||||
await log_llm_response_async(
|
||||
log_file_path=log_file_path,
|
||||
model_name=self.model_name,
|
||||
power_name=power_name,
|
||||
|
|
@ -441,7 +442,18 @@ class BaseModelClient:
|
|||
agent_private_diary_str: Optional[str] = None, # Added
|
||||
) -> str:
|
||||
# MINIMAL CHANGE: Just change to load unformatted version conditionally
|
||||
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
|
||||
# Check if country-specific prompts are enabled
|
||||
if config.COUNTRY_SPECIFIC_PROMPTS:
|
||||
# Try to load country-specific version first
|
||||
country_specific_file = get_prompt_path(f"conversation_instructions_{power_name.lower()}.txt")
|
||||
instructions = load_prompt(country_specific_file, prompts_dir=self.prompts_dir)
|
||||
|
||||
# Fall back to generic if country-specific not found
|
||||
if not instructions:
|
||||
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
|
||||
else:
|
||||
# Load generic conversation instructions
|
||||
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
|
||||
|
||||
# KEEP ORIGINAL: Use build_context_prompt as before
|
||||
context = build_context_prompt(
|
||||
|
|
@ -670,7 +682,7 @@ class BaseModelClient:
|
|||
messages_to_return = [] # Ensure empty list on general exception
|
||||
finally:
|
||||
if log_file_path:
|
||||
log_llm_response(
|
||||
await log_llm_response_async(
|
||||
log_file_path=log_file_path,
|
||||
model_name=self.model_name,
|
||||
power_name=power_name,
|
||||
|
|
@ -749,7 +761,7 @@ class BaseModelClient:
|
|||
plan_to_return = f"Error: Failed to generate plan for {power_name} due to exception: {e}"
|
||||
finally:
|
||||
if log_file_path: # Only log if a path is provided
|
||||
log_llm_response(
|
||||
await log_llm_response_async(
|
||||
log_file_path=log_file_path,
|
||||
model_name=self.model_name,
|
||||
power_name=power_name,
|
||||
|
|
@ -797,27 +809,34 @@ class OpenAIClient(BaseModelClient):
|
|||
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
|
||||
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
|
||||
|
||||
# Determine which parameter to use based on model
|
||||
completion_params = {
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt_content},
|
||||
{"role": "user", "content": prompt_with_cta},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt_content},
|
||||
{"role": "user", "content": prompt_with_cta},
|
||||
],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
# Handle model-specific parameters
|
||||
# Check if model name starts with 'nectarine' or is in the specific list
|
||||
uses_max_completion_tokens = (
|
||||
self.model_name in ["o4-mini", "o3-mini", "o3", "gpt-4.1"] or
|
||||
self.model_name.startswith("nectarine")
|
||||
)
|
||||
|
||||
if uses_max_completion_tokens:
|
||||
completion_params["max_completion_tokens"] = self.max_tokens
|
||||
# o4-mini, o3-mini, o3 only support default temperature of 1.0
|
||||
if self.model_name in ["o4-mini", "o3-mini", "o3"]:
|
||||
completion_params["temperature"] = 1.0
|
||||
else:
|
||||
completion_params["temperature"] = temperature
|
||||
else:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt_content},
|
||||
{"role": "user", "content": prompt_with_cta},
|
||||
],
|
||||
temperature=temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
completion_params["max_tokens"] = self.max_tokens
|
||||
completion_params["temperature"] = temperature
|
||||
|
||||
response = await self.client.chat.completions.create(**completion_params)
|
||||
|
||||
if (
|
||||
not response
|
||||
|
|
@ -971,16 +990,24 @@ class DeepSeekClient(BaseModelClient):
|
|||
random_seed = generate_random_seed()
|
||||
system_prompt_content = f"{random_seed}\n\n{self.system_prompt}"
|
||||
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
# Determine which parameter to use based on model
|
||||
completion_params = {
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt_content},
|
||||
{"role": "user", "content": prompt_with_cta},
|
||||
],
|
||||
stream=False,
|
||||
temperature=temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
"stream": False,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
# Use max_completion_tokens for o4-mini, o3-mini models and nectarine models
|
||||
if self.model_name in ["o4-mini", "o3-mini"] or self.model_name.startswith("nectarine"):
|
||||
completion_params["max_completion_tokens"] = self.max_tokens
|
||||
else:
|
||||
completion_params["max_tokens"] = self.max_tokens
|
||||
|
||||
response = await self.client.chat.completions.create(**completion_params)
|
||||
|
||||
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
|
||||
|
||||
|
|
@ -1023,7 +1050,7 @@ class OpenAIResponsesClient(BaseModelClient):
|
|||
This client makes direct HTTP requests to the v1/responses endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None, api_key: Optional[str] = None):
|
||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None, api_key: Optional[str] = None, reasoning_effort: Optional[str] = None):
|
||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||
if api_key:
|
||||
self.api_key = api_key
|
||||
|
|
@ -1032,7 +1059,20 @@ class OpenAIResponsesClient(BaseModelClient):
|
|||
if not self.api_key:
|
||||
raise ValueError("OPENAI_API_KEY environment variable is required")
|
||||
self.base_url = "https://api.openai.com/v1/responses"
|
||||
logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client")
|
||||
self._session = None # Lazy initialization for connection pooling
|
||||
self.reasoning_effort = reasoning_effort # For models that support reasoning effort
|
||||
logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client with reasoning_effort={reasoning_effort}")
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
    """Return the client's shared aiohttp session.

    A new ``aiohttp.ClientSession`` is created when none has been stored
    yet or when the stored one reports itself as closed; otherwise the
    existing session is handed back so its connections can be reused.
    """
    existing = self._session
    if existing is not None and not existing.closed:
        return existing

    # No usable session: build one and remember it for later calls.
    self._session = aiohttp.ClientSession()
    return self._session
|
||||
|
||||
async def close(self):
    """Close the stored aiohttp session if there is one and it is still open.

    Does nothing when no session was ever created or when the stored
    session already reports itself as closed.
    """
    session = self._session
    if not session or session.closed:
        return
    await session.close()
|
||||
|
||||
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
||||
try:
|
||||
|
|
@ -1049,51 +1089,59 @@ class OpenAIResponsesClient(BaseModelClient):
|
|||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": full_prompt,
|
||||
"temperature": temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
|
||||
if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
|
||||
del payload["temperature"]
|
||||
del payload["max_tokens"]
|
||||
payload["max_completion_tokens"] = self.max_tokens
|
||||
# The Responses API uses max_output_tokens for all models
|
||||
payload["max_output_tokens"] = self.max_tokens
|
||||
|
||||
# Only add temperature for models that support it
|
||||
models_without_temp = ['o3', 'o4-mini', 'gpt-5-reasoning-alpha-2025-07-19', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
|
||||
if self.model_name not in models_without_temp:
|
||||
payload["temperature"] = temperature
|
||||
|
||||
# Add reasoning effort for models that support it
|
||||
reasoning_models = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'o4-mini-alpha-2025-07-11', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
|
||||
if self.reasoning_effort and self.model_name in reasoning_models:
|
||||
payload["reasoning"] = {"effort": self.reasoning_effort}
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
# Make the API call using aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.base_url, json=payload, headers=headers) as response:
|
||||
response.raise_for_status() # Will raise for non-2xx responses
|
||||
response_data = await response.json()
|
||||
# Make the API call using the pooled session
|
||||
session = await self._get_session()
|
||||
async with session.post(self.base_url, json=payload, headers=headers) as response:
|
||||
response.raise_for_status() # Will raise for non-2xx responses
|
||||
response_data = await response.json()
|
||||
|
||||
# Extract the text from the nested response structure
|
||||
try:
|
||||
outputs = response_data.get("output", [])
|
||||
if len(outputs) < 2:
|
||||
raise ValueError(f"[{self.model_name}] Unexpected output structure: 'output' list has < 2 items.")
|
||||
# Extract the text from the nested response structure
|
||||
try:
|
||||
outputs = response_data.get("output", [])
|
||||
if len(outputs) < 2:
|
||||
# Log the actual response for debugging
|
||||
logger.error(f"[{self.model_name}] Response structure: {json.dumps(response_data, indent=2)}")
|
||||
raise ValueError(f"[{self.model_name}] Unexpected output structure: 'output' list has < 2 items.")
|
||||
|
||||
message_output = outputs[1]
|
||||
if message_output.get("type") != "message":
|
||||
raise ValueError(f"[{self.model_name}] Expected 'message' type in output[1], got '{message_output.get('type')}'.")
|
||||
message_output = outputs[1]
|
||||
if message_output.get("type") != "message":
|
||||
raise ValueError(f"[{self.model_name}] Expected 'message' type in output[1], got '{message_output.get('type')}'.")
|
||||
|
||||
content_list = message_output.get("content", [])
|
||||
if not content_list:
|
||||
raise ValueError(f"[{self.model_name}] Empty 'content' list in message output.")
|
||||
content_list = message_output.get("content", [])
|
||||
if not content_list:
|
||||
raise ValueError(f"[{self.model_name}] Empty 'content' list in message output.")
|
||||
|
||||
text_content = ""
|
||||
for content_item in content_list:
|
||||
if content_item.get("type") == "output_text":
|
||||
text_content = content_item.get("text", "")
|
||||
break
|
||||
text_content = ""
|
||||
for content_item in content_list:
|
||||
if content_item.get("type") == "output_text":
|
||||
text_content = content_item.get("text", "")
|
||||
break
|
||||
|
||||
if not text_content:
|
||||
raise ValueError(f"[{self.model_name}] No 'output_text' found in content or it was empty.")
|
||||
if not text_content:
|
||||
raise ValueError(f"[{self.model_name}] No 'output_text' found in content or it was empty.")
|
||||
|
||||
return text_content.strip()
|
||||
return text_content.strip()
|
||||
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
# Wrap parsing error in a more informative exception
|
||||
raise ValueError(f"[{self.model_name}] Error parsing response structure: {e}") from e
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
# Wrap parsing error in a more informative exception
|
||||
raise ValueError(f"[{self.model_name}] Error parsing response structure: {e}") from e
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
|
||||
|
|
@ -1302,8 +1350,13 @@ class RequestsOpenAIClient(BaseModelClient):
|
|||
{"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
|
||||
],
|
||||
"temperature": temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
|
||||
# Use max_completion_tokens for o4-mini, o3-mini, o3, gpt-4.1 models and nectarine models
|
||||
if self.model_name in ["o4-mini", "o3-mini", "o3", "gpt-4.1"] or self.model_name.startswith("nectarine"):
|
||||
payload["max_completion_tokens"] = self.max_tokens
|
||||
else:
|
||||
payload["max_tokens"] = self.max_tokens
|
||||
|
||||
#if self.model_name == "qwen/qwen3-235b-a22b" and self.base_url == "https://openrouter.ai/api/v1":
|
||||
# payload["provider"] = {
|
||||
|
|
@ -1313,7 +1366,8 @@ class RequestsOpenAIClient(BaseModelClient):
|
|||
|
||||
if (self.model_name == 'o3' or self.model_name == 'o4-mini'):
|
||||
del payload["temperature"]
|
||||
del payload["max_tokens"]
|
||||
if "max_tokens" in payload:
|
||||
del payload["max_tokens"]
|
||||
payload["max_completion_tokens"] = self.max_tokens
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
@ -1381,13 +1435,33 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
|
|||
gpt-4o
|
||||
anthropic:claude-3.7-sonnet
|
||||
openai:llama-3-2-3b@https://localhost:8000#myapikey
|
||||
gpt-5-reasoning-alpha-2025-07-19:minimal
|
||||
and returns the appropriate client.
|
||||
|
||||
• If a prefix is omitted the function falls back to the original
|
||||
heuristic mapping exactly as before.
|
||||
• If an inline API-key (‘#…’) is present it overrides environment vars.
|
||||
• If an inline API-key ('#…') is present it overrides environment vars.
|
||||
• For reasoning models, effort can be specified with :minimal, :medium, or :high
|
||||
"""
|
||||
spec = _parse_model_spec(model_id)
|
||||
# Extract reasoning effort if present (before general parsing)
|
||||
reasoning_effort = None
|
||||
actual_model_id = model_id
|
||||
|
||||
# Check if this is a reasoning model with effort specified
|
||||
reasoning_models = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
|
||||
for model in reasoning_models:
|
||||
if model_id.startswith(model + ':'):
|
||||
parts = model_id.split(':', 1)
|
||||
effort_part = parts[1]
|
||||
# Check if the effort part is valid before treating it as effort
|
||||
# (it could be a prefix like "openai:")
|
||||
if effort_part.lower() in ['minimal', 'medium', 'high']:
|
||||
actual_model_id = parts[0]
|
||||
reasoning_effort = effort_part.lower()
|
||||
break
|
||||
|
||||
spec = _parse_model_spec(actual_model_id)
|
||||
logger.info(f"[load_model_client] Loading client for model_id='{model_id}', parsed spec: prefix={spec.prefix}, model={spec.model}, reasoning_effort={reasoning_effort}")
|
||||
|
||||
# Inline key overrides env; otherwise fall back as usual *per client*
|
||||
inline_key = spec.key
|
||||
|
|
@ -1421,7 +1495,7 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
|
|||
api_key=inline_key,
|
||||
)
|
||||
case Prefix.OPENAI_RESPONSES:
|
||||
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
|
||||
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key, reasoning_effort=reasoning_effort)
|
||||
case Prefix.ANTHROPIC:
|
||||
return ClaudeClient(spec.model, prompts_dir)
|
||||
case Prefix.GEMINI:
|
||||
|
|
@ -1437,27 +1511,41 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
|
|||
# 2. Heuristic fallback path (identical to the original behaviour) #
|
||||
# ------------------------------------------------------------------ #
|
||||
lower_id = spec.model.lower()
|
||||
logger.info(f"[load_model_client] Heuristic path: checking model='{spec.model}', lower_id='{lower_id}'")
|
||||
|
||||
# Check if this is a reasoning model that should use Responses API
|
||||
reasoning_models_requiring_responses = ['gpt-5-reasoning-alpha-2025-07-19', 'o4-mini', 'nectarine-alpha-2025-07-25', 'nectarine-alpha-new-reasoning-effort-2025-07-25']
|
||||
if spec.model in reasoning_models_requiring_responses:
|
||||
logger.info(f"[load_model_client] Selected OpenAIResponsesClient for reasoning model '{spec.model}'")
|
||||
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key, reasoning_effort=reasoning_effort)
|
||||
|
||||
if lower_id == "o3-pro":
|
||||
logger.info(f"[load_model_client] Selected OpenAIResponsesClient for '{spec.model}'")
|
||||
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
|
||||
|
||||
if spec.model.startswith("together-"):
|
||||
# e.g. "together-mixtral-8x7b"
|
||||
logger.info(f"[load_model_client] Selected TogetherAIClient for '{spec.model}'")
|
||||
return TogetherAIClient(spec.model.split("together-", 1)[1], prompts_dir)
|
||||
|
||||
if "openrouter" in lower_id:
|
||||
logger.info(f"[load_model_client] Selected OpenRouterClient for '{spec.model}'")
|
||||
return OpenRouterClient(spec.model, prompts_dir)
|
||||
|
||||
if "claude" in lower_id:
|
||||
logger.info(f"[load_model_client] Selected ClaudeClient for '{spec.model}'")
|
||||
return ClaudeClient(spec.model, prompts_dir)
|
||||
|
||||
if "gemini" in lower_id:
|
||||
logger.info(f"[load_model_client] Selected GeminiClient for '{spec.model}'")
|
||||
return GeminiClient(spec.model, prompts_dir)
|
||||
|
||||
if "deepseek" in lower_id:
|
||||
logger.info(f"[load_model_client] Selected DeepSeekClient for '{spec.model}'")
|
||||
return DeepSeekClient(spec.model, prompts_dir)
|
||||
|
||||
# Default: OpenAI-compatible async client
|
||||
logger.info(f"[load_model_client] No specific match found, using default OpenAIClient for '{spec.model}'")
|
||||
return OpenAIClient(
|
||||
model_name=spec.model,
|
||||
prompts_dir=prompts_dir,
|
||||
|
|
|
|||
|
|
@ -35,19 +35,23 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
|
|||
}
|
||||
|
||||
|
||||
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
|
||||
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None, override_max_tokens: Optional[int] = None) -> DiplomacyAgent:
|
||||
"""
|
||||
Recreates an agent object from a dictionary.
|
||||
|
||||
If *override_model_id* is provided (e.g. because the CLI argument
|
||||
``--models`` was used when resuming a game), that model is loaded
|
||||
instead of the one stored in the save file.
|
||||
|
||||
If *override_max_tokens* is provided (e.g. because the CLI argument
|
||||
``--max_tokens`` was used when resuming a game), that value is used
|
||||
instead of the one stored in the save file.
|
||||
"""
|
||||
model_id = override_model_id or agent_data["model_id"]
|
||||
client = load_model_client(model_id, prompts_dir=prompts_dir)
|
||||
|
||||
# Keep the original or fallback token limit exactly as before.
|
||||
client.max_tokens = agent_data.get("max_tokens", 16000)
|
||||
# Use override if provided, otherwise use saved value, otherwise default to 16000
|
||||
client.max_tokens = override_max_tokens or agent_data.get("max_tokens", 16000)
|
||||
|
||||
agent = DiplomacyAgent(
|
||||
power_name=agent_data["power_name"],
|
||||
|
|
@ -208,9 +212,22 @@ def load_game_state(
|
|||
# --- Rebuild agents -------------------------------------------------------
|
||||
agents: Dict[str, "DiplomacyAgent"] = {}
|
||||
power_model_map: Dict[str, str] = {}
|
||||
powers_order = sorted(list(ALL_POWERS))
|
||||
|
||||
# Parse token limits from run_config
|
||||
default_max_tokens = run_config.max_tokens if run_config and hasattr(run_config, 'max_tokens') else 16000
|
||||
model_max_tokens = {p: default_max_tokens for p in powers_order}
|
||||
|
||||
if run_config and hasattr(run_config, 'max_tokens_per_model') and run_config.max_tokens_per_model:
|
||||
per_model_values = [s.strip() for s in run_config.max_tokens_per_model.split(",")]
|
||||
if len(per_model_values) == 7:
|
||||
for power, token_val_str in zip(powers_order, per_model_values):
|
||||
model_max_tokens[power] = int(token_val_str)
|
||||
else:
|
||||
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
|
||||
|
||||
if run_config and getattr(run_config, "models", None):
|
||||
provided = [m.strip() for m in run_config.models.split(",")]
|
||||
powers_order = sorted(list(ALL_POWERS))
|
||||
if len(provided) == len(powers_order):
|
||||
power_model_map = dict(zip(powers_order, provided))
|
||||
elif len(provided) == 1:
|
||||
|
|
@ -234,6 +251,7 @@ def load_game_state(
|
|||
agent_data,
|
||||
prompts_dir=prompts_dir_from_config,
|
||||
override_model_id=override_id,
|
||||
override_max_tokens=model_max_tokens.get(power_name),
|
||||
)
|
||||
|
||||
# --- Rebuild GameHistory --------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -175,16 +175,31 @@ def construct_order_generation_prompt(
|
|||
_ = load_prompt("few_shot_example.txt", prompts_dir=prompts_dir) # Loaded but not used, as per original logic
|
||||
# Pick the phase-specific instruction file (using unformatted versions)
|
||||
phase_code = board_state["phase"][-1] # 'M' (movement), 'R', or 'A' / 'B'
|
||||
|
||||
# Determine base instruction file name
|
||||
if phase_code == "M":
|
||||
instructions_file = get_prompt_path("order_instructions_movement_phase.txt")
|
||||
base_instruction_file = "order_instructions_movement_phase"
|
||||
elif phase_code in ("A", "B"): # builds / adjustments
|
||||
instructions_file = get_prompt_path("order_instructions_adjustment_phase.txt")
|
||||
base_instruction_file = "order_instructions_adjustment_phase"
|
||||
elif phase_code == "R": # retreats
|
||||
instructions_file = get_prompt_path("order_instructions_retreat_phase.txt")
|
||||
base_instruction_file = "order_instructions_retreat_phase"
|
||||
else: # unexpected – default to movement rules
|
||||
instructions_file = get_prompt_path("order_instructions_movement_phase.txt")
|
||||
|
||||
instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
|
||||
base_instruction_file = "order_instructions_movement_phase"
|
||||
|
||||
# Check if country-specific prompts are enabled
|
||||
if config.COUNTRY_SPECIFIC_PROMPTS:
|
||||
# Try to load country-specific version first
|
||||
country_specific_file = get_prompt_path(f"{base_instruction_file}_{power_name.lower()}.txt")
|
||||
instructions = load_prompt(country_specific_file, prompts_dir=prompts_dir)
|
||||
|
||||
# Fall back to generic if country-specific not found
|
||||
if not instructions:
|
||||
instructions_file = get_prompt_path(f"{base_instruction_file}.txt")
|
||||
instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
|
||||
else:
|
||||
# Load generic instruction file
|
||||
instructions_file = get_prompt_path(f"{base_instruction_file}.txt")
|
||||
instructions = load_prompt(instructions_file, prompts_dir=prompts_dir)
|
||||
_use_simple = config.SIMPLE_PROMPTS
|
||||
|
||||
include_order_history = False # defaulting to not include order history in order generation prompt for now
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ class Configuration(BaseSettings):
|
|||
log_file_path: Path | None = None
|
||||
USE_UNFORMATTED_PROMPTS: bool = False
|
||||
SIMPLE_PROMPTS: bool = True
|
||||
COUNTRY_SPECIFIC_PROMPTS: bool = False
|
||||
|
||||
# Default models for tasks
|
||||
AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20"
|
||||
|
|
|
|||
20
lm_game.py
20
lm_game.py
|
|
@ -173,6 +173,18 @@ def parse_arguments():
|
|||
"Set to false (0 / false / no) to use original single-step formatted prompts."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--country_specific_prompts",
|
||||
type=_str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help=(
|
||||
"When true (1 / true / yes) enables country-specific order and conversation prompts. "
|
||||
"Each power will use their own custom prompts if available (e.g., order_instructions_movement_phase_france.txt). "
|
||||
"Falls back to generic prompts if country-specific not found."
|
||||
),
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -223,6 +235,14 @@ async def main():
|
|||
else:
|
||||
config.USE_UNFORMATTED_PROMPTS = False
|
||||
logger.info("Using original single-step formatted prompts")
|
||||
|
||||
# Handle country-specific prompts flag
|
||||
if args.country_specific_prompts:
|
||||
config.COUNTRY_SPECIFIC_PROMPTS = True
|
||||
logger.info("Country-specific prompts enabled - powers will use their custom prompts when available")
|
||||
else:
|
||||
config.COUNTRY_SPECIFIC_PROMPTS = False
|
||||
logger.info("Using generic prompts for all powers")
|
||||
|
||||
if args.max_year == None:
|
||||
if args.end_at_phase:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue