Fixed the system prompt for summaries; made improvements and added debugging for summaries too. Much can still be optimized.

This commit is contained in:
AlxAI 2025-02-20 18:22:53 -08:00
parent 72327cfb22
commit 8f61ba06b3
4 changed files with 150 additions and 55 deletions

View file

@ -40,18 +40,24 @@ class BaseModelClient:
- get_conversation_reply(power_name, conversation_so_far, game_phase) -> str - get_conversation_reply(power_name, conversation_so_far, game_phase) -> str
""" """
def __init__(self, model_name: str, power_name: Optional[str] = None): def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False):
self.model_name = model_name self.model_name = model_name
self.power_name = power_name self.power_name = power_name
# Load a power-specific system prompt if present, else default self.emptysystem = emptysystem
if self.power_name:
try: # Conditionally load system prompt
self.system_prompt = load_prompt(f"{self.power_name.lower()}_system_prompt.txt") if not self.emptysystem:
except FileNotFoundError: if self.power_name:
logger.warning(f"No specific system prompt found for {self.power_name}; using default.") try:
self.system_prompt = load_prompt(f"{self.power_name.lower()}_system_prompt.txt")
except FileNotFoundError:
logger.warning(f"No specific system prompt found for {self.power_name}; using default.")
self.system_prompt = load_prompt("system_prompt.txt")
else:
self.system_prompt = load_prompt("system_prompt.txt") self.system_prompt = load_prompt("system_prompt.txt")
else: else:
self.system_prompt = load_prompt("system_prompt.txt") # If emptysystem is True, skip loading any system prompt
self.system_prompt = ""
def generate_response(self, prompt: str) -> str: def generate_response(self, prompt: str) -> str:
""" """
@ -486,8 +492,8 @@ class OpenAIClient(BaseModelClient):
For 'o3-mini', 'gpt-4o', or other OpenAI model calls. For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
""" """
def __init__(self, model_name: str, power_name: Optional[str] = None): def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False):
super().__init__(model_name, power_name) super().__init__(model_name, power_name, emptysystem)
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def generate_response(self, prompt: str) -> str: def generate_response(self, prompt: str) -> str:
@ -523,8 +529,8 @@ class ClaudeClient(BaseModelClient):
For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc. For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc.
""" """
def __init__(self, model_name: str, power_name: Optional[str] = None): def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False):
super().__init__(model_name, power_name) super().__init__(model_name, power_name, emptysystem)
self.client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) self.client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
def generate_response(self, prompt: str) -> str: def generate_response(self, prompt: str) -> str:
@ -559,8 +565,8 @@ class GeminiClient(BaseModelClient):
For 'gemini-1.5-flash' or other Google Generative AI models. For 'gemini-1.5-flash' or other Google Generative AI models.
""" """
def __init__(self, model_name: str, power_name: Optional[str] = None): def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False):
super().__init__(model_name, power_name) super().__init__(model_name, power_name, emptysystem)
self.client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY")) self.client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
def generate_response(self, prompt: str) -> str: def generate_response(self, prompt: str) -> str:
@ -587,8 +593,8 @@ class DeepSeekClient(BaseModelClient):
For DeepSeek R1 'deepseek-reasoner' For DeepSeek R1 'deepseek-reasoner'
""" """
def __init__(self, model_name: str, power_name: Optional[str] = None): def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False):
super().__init__(model_name, power_name) super().__init__(model_name, power_name, emptysystem)
self.api_key = os.environ.get("DEEPSEEK_API_KEY") self.api_key = os.environ.get("DEEPSEEK_API_KEY")
self.client = DeepSeekOpenAI( self.client = DeepSeekOpenAI(
api_key=self.api_key, base_url="https://api.deepseek.com/" api_key=self.api_key, base_url="https://api.deepseek.com/"
@ -651,22 +657,22 @@ class DeepSeekClient(BaseModelClient):
############################################################################## ##############################################################################
def load_model_client(model_id: str, power_name: Optional[str] = None) -> BaseModelClient: def load_model_client(model_id: str, power_name: Optional[str] = None, emptysystem: bool = False) -> BaseModelClient:
""" """
Returns the appropriate LLM client for a given model_id string, optionally keyed by power_name. Returns the appropriate LLM client for a given model_id string, optionally keyed by power_name.
Example usage: Example usage:
client = load_model_client("claude-3-5-sonnet-20241022", power_name="FRANCE") client = load_model_client("claude-3-5-sonnet-20241022", power_name="FRANCE", emptysystem=True)
""" """
lower_id = model_id.lower() lower_id = model_id.lower()
if "claude" in lower_id: if "claude" in lower_id:
return ClaudeClient(model_id, power_name) return ClaudeClient(model_id, power_name, emptysystem=emptysystem)
elif "gemini" in lower_id: elif "gemini" in lower_id:
return GeminiClient(model_id, power_name) return GeminiClient(model_id, power_name, emptysystem=emptysystem)
elif "deepseek" in lower_id: elif "deepseek" in lower_id:
return DeepSeekClient(model_id, power_name) return DeepSeekClient(model_id, power_name, emptysystem=emptysystem)
else: else:
# Default to OpenAI # Default to OpenAI
return OpenAIClient(model_id, power_name) return OpenAIClient(model_id, power_name, emptysystem=emptysystem)
############################################################################## ##############################################################################

View file

@ -1,5 +1,27 @@
You are playing a game of Diplomacy over text. The map is the standard Diplomacy map. Your goal is to win the game by capturing supply centers, growing your army, and taking over the map. Be aggressive. You are playing a game of Diplomacy over text. The map is the standard Diplomacy map. Your goal is to win the game by capturing supply centers, growing your army, and taking over the map. Be aggressive.
Dear Austria,
They say you're surrounded - but that means you can strike in any direction. History shows the strongest Austrian players turn early vulnerability into mid-game dominance through decisive action, not just survival.
Key insights:
Prevent Russia-Turkey alliance above all else
Italy must be friend or dead quickly (95% of A/I wars kill both)
Serbia is crucial 1901 - secure it
Galicia bounce often vital Spring 1901
Central position enables striking anywhere once secured
Critical mindset: You're not playing to survive - you're playing to explode out from the center. Yes, early diplomacy keeps you alive, but it should serve your offensive goals, not replace them.
Paths to victory often require:
Securing strong early alliance (usually Italy or Russia)
Eliminating one neighbor completely by 1904
Leveraging central position for unexpected strikes
Breaking stalemate line via Munich/Berlin
Don't fall into defensive play just because everyone expects it. Stats show Austrian solos often come from players who turn the early "defensive" moves into aggressive positioning by year 3.
Time works against you - the longer you wait, the more likely others unite. Make your decisive moves by mid-game, usually years 3-4. Better to strike imperfectly than wait for perfect alignment.
The throne of Europe awaits. Show them that the "weakest" starting position was merely gathering strength to strike.
You will be given: You will be given:
• Which power you are controlling. • Which power you are controlling.
• The current phase (e.g. S1901M). • The current phase (e.g. S1901M).

View file

@ -45,6 +45,11 @@ from diplomacy.utils.game_phase_data import GamePhaseData, MESSAGES_TYPE
UNDETERMINED, POWER, UNIT, LOCATION, COAST, ORDER, MOVE_SEP, OTHER = 0, 1, 2, 3, 4, 5, 6, 7 UNDETERMINED, POWER, UNIT, LOCATION, COAST, ORDER, MOVE_SEP, OTHER = 0, 1, 2, 3, 4, 5, 6, 7
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
# set logging level to INFO
logging.basicConfig(level=logging.INFO)
# set logging level to DEBUG
#logging.basicConfig(level=logging.DEBUG)
class Game(Jsonable): class Game(Jsonable):
""" Game class. """ Game class.
@ -1468,6 +1473,9 @@ class Game(Jsonable):
self.message_history.put(previous_phase, previous_messages) self.message_history.put(previous_phase, previous_messages)
self.state_history.put(previous_phase, previous_state) self.state_history.put(previous_phase, previous_state)
# Now build a key for the *current* (post-process) phase
current_phase_key = self._phase_wrapper_type(self.current_short_phase)
# Generate a text summary (if a callback is provided) # Generate a text summary (if a callback is provided)
phase_summary_text = self._generate_phase_summary( phase_summary_text = self._generate_phase_summary(
previous_phase, previous_phase,
@ -4575,30 +4583,54 @@ class Game(Jsonable):
except (IndexError, KeyError): except (IndexError, KeyError):
return f"[_generate_phase_summary] No GamePhaseData found for {phase_key}" return f"[_generate_phase_summary] No GamePhaseData found for {phase_key}"
# Log the current phase key and results for debugging # Log the current phase key, results, and possibly the orders for debugging
logging.debug( logging.debug(
"DEBUG _generate_phase_summary: phase_key=%s, results=%s", "DEBUG _generate_phase_summary: current phase_key=%s, results=%s, orders=%s",
phase_key, current_phase_data.results phase_key,
current_phase_data.results,
current_phase_data.orders
) )
# 2) Attempt to retrieve the PREVIOUS phase data to highlight differences # Retrieve the list of all recorded phase keys
# We'll do this by checking the index of `phase_key` in `self.state_history`.
# If there's a previous index, we'll fetch that phase_data for comparison.
prev_phase_data = None
all_phases = list(self.state_history.keys()) all_phases = list(self.state_history.keys())
logging.debug("DEBUG _generate_phase_summary: all_phases=%s", all_phases)
prev_phase_data = None
if str(phase_key) in all_phases: if str(phase_key) in all_phases:
idx = all_phases.index(str(phase_key)) idx = all_phases.index(str(phase_key))
logging.debug("DEBUG _generate_phase_summary: current phase index=%d", idx)
# Here we log the logic behind picking the previous phase
if idx > 0: if idx > 0:
prev_phase_key = all_phases[idx - 1] prev_phase_key = all_phases[idx - 1]
logging.debug(
"DEBUG _generate_phase_summary: Using prev_phase_key=%s (idx-2). If skipping a sub-phase is undesired, consider (idx-1).",
prev_phase_key
)
try: try:
prev_phase_data = self.get_phase_from_history(prev_phase_key) prev_phase_data = self.get_phase_from_history(prev_phase_key)
except: except Exception as e:
pass logging.debug("DEBUG _generate_phase_summary: Could not get prev_phase_data for key=%s, error=%s", prev_phase_key, e)
else:
logging.debug("DEBUG _generate_phase_summary: Not enough phases to set prev_phase_key.")
else:
logging.debug("DEBUG _generate_phase_summary: phase_key=%s not in all_phases!", phase_key)
# 3) Gather the big data from current_phase_data # ... [No change in the rest of your existing logic, except we might add extra logs below] ...
# (We assume you have stored them in current_phase_data.state the usual way.)
# (After retrieving prev_phase_data, we log a quick summary:)
if prev_phase_data:
logging.debug(
"DEBUG _generate_phase_summary: Found prev_phase_data for key=%s, results=%s, orders=%s",
prev_phase_key,
prev_phase_data.results,
prev_phase_data.orders
)
# The rest of the function remains the same, but you can keep adding targeted logs as needed:
cur_state = current_phase_data.state cur_state = current_phase_data.state
# Typically these keys exist if your get_state() populates them: logging.debug("DEBUG _generate_phase_summary: cur_state keys=%s", list(cur_state.keys()))
cur_units = cur_state.get('units', {}) cur_units = cur_state.get('units', {})
cur_centers = cur_state.get('centers', {}) cur_centers = cur_state.get('centers', {})
cur_retreats = cur_state.get('retreats', {}) cur_retreats = cur_state.get('retreats', {})
@ -4703,7 +4735,7 @@ class Game(Jsonable):
f"RESULTS:\n{results_block}\n\n" f"RESULTS:\n{results_block}\n\n"
f"CURRENT BOARD STATE:\n{current_state_block}\n\n" f"CURRENT BOARD STATE:\n{current_state_block}\n\n"
f"CHANGES FROM PREVIOUS PHASE:\n{differences_block}\n\n" f"CHANGES FROM PREVIOUS PHASE:\n{differences_block}\n\n"
"Below is the final board state after the latest phase, along with the moves each power submitted and the engines adjudication results. Please create a summary in JSON, explaining:" "Below is the final board state after the latest phase, along with the moves each power submitted and the engine's adjudication results. Please create a summary in JSON, explaining:"
"- Each successful move," "- Each successful move,"
"- Each bounce or voided order, with reasons (e.g. equal force, no valid route, contradictory support)," "- Each bounce or voided order, with reasons (e.g. equal force, no valid route, contradictory support),"
"- Key changes in supply centers," "- Key changes in supply centers,"

View file

@ -34,7 +34,7 @@ logging.basicConfig(
def my_summary_callback(system_prompt, user_prompt, model_name): def my_summary_callback(system_prompt, user_prompt, model_name):
# Route to the desired model specified by the command-line argument # Route to the desired model specified by the command-line argument
client = load_model_client(model_name) client = load_model_client(model_name, emptysystem=True)
combined_prompt = f"{system_prompt}\n\n{user_prompt}" combined_prompt = f"{system_prompt}\n\n{user_prompt}"
# Pseudo-code for generating a response: # Pseudo-code for generating a response:
return client.generate_response(combined_prompt) return client.generate_response(combined_prompt)
@ -47,7 +47,7 @@ def parse_arguments():
parser.add_argument( parser.add_argument(
"--max_year", "--max_year",
type=int, type=int,
default=1910, default=1905,
help="Maximum year to simulate. The game will stop once this year is reached.", help="Maximum year to simulate. The game will stop once this year is reached.",
) )
parser.add_argument( parser.add_argument(
@ -59,7 +59,7 @@ def parse_arguments():
parser.add_argument( parser.add_argument(
"--num_negotiation_rounds", "--num_negotiation_rounds",
type=int, type=int,
default=5, default=1,
help="Number of negotiation rounds per phase.", help="Number of negotiation rounds per phase.",
) )
parser.add_argument( parser.add_argument(
@ -80,6 +80,42 @@ def parse_arguments():
return parser.parse_args() return parser.parse_args()
def save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=False):
    """
    Save the current game state and rewrite the run overview file.

    Periodic (non-final) saves go to a timestamped checkpoint file derived
    from ``game_file_path``. The final save goes to ``game_file_path``
    itself, unless that path already exists, in which case a timestamp
    suffix is appended so the existing file is not overwritten.

    Args:
        game: The diplomacy game instance; must expose ``power_model_map``.
        result_folder: Path to the results folder (``overview.jsonl`` is
            written there).
        game_file_path: Base path for the game file.
        model_error_stats: Dictionary containing model error statistics.
        args: Command line arguments; serialized via ``vars(args)``.
        is_final: Boolean indicating if this is the final save.
    """
    # One timestamp per call so every name derived below is consistent.
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    if is_final:
        output_path = game_file_path
        # Never clobber an existing final file: fall back to a unique name.
        if os.path.exists(output_path):
            logger.info("Game file already exists, saving with unique filename.")
            output_path = f"{output_path}_{timestamp}.json"
    else:
        output_path = f"{game_file_path}_checkpoint_{timestamp}.json"

    # Save game state
    to_saved_game_format(game, output_path=output_path)

    # The overview is rewritten from scratch on every save ("w" mode), so it
    # always holds exactly three JSON lines reflecting the latest state.
    overview_file_path = os.path.join(result_folder, "overview.jsonl")
    overview_records = (model_error_stats, game.power_model_map, vars(args))
    with open(overview_file_path, "w", encoding="utf-8") as overview_file:
        overview_file.writelines(json.dumps(record) + "\n" for record in overview_records)

    # Report what was actually written; a final save is not a checkpoint.
    saved_kind = "final game state" if is_final else "game checkpoint"
    logger.info("Saved %s to: %s", saved_kind, output_path)
def main(): def main():
args = parse_arguments() args = parse_arguments()
max_year = args.max_year max_year = args.max_year
@ -135,6 +171,8 @@ def main():
else: else:
game.power_model_map = assign_models_to_powers(randomize=True) game.power_model_map = assign_models_to_powers(randomize=True)
round_counter = 0 # Track number of rounds
while not game.is_game_done: while not game.is_game_done:
phase_start = time.time() phase_start = time.time()
current_phase = game.get_current_phase() current_phase = game.get_current_phase()
@ -143,7 +181,7 @@ def main():
) )
# DEBUG: Print the short phase to confirm # DEBUG: Print the short phase to confirm
logger.info(f"DEBUG: current_short_phase is '{game.current_short_phase}'") logger.info(f"INFO: The current short phase is '{game.current_short_phase}'")
# Prevent unbounded simulation based on year # Prevent unbounded simulation based on year
year_str = current_phase[1:5] year_str = current_phase[1:5]
@ -253,6 +291,14 @@ def main():
with open(manifesto_path, "a") as f: with open(manifesto_path, "a") as f:
f.write(f"=== {phase_data.name} ===\n{summary_text}\n\n") f.write(f"=== {phase_data.name} ===\n{summary_text}\n\n")
# Increment round counter after processing each phase
round_counter += 1
# Save every 5 rounds
if round_counter % 5 == 0:
logger.info(f"Saving checkpoint after round {round_counter}...")
save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=False)
# Check if we've exceeded the max year # Check if we've exceeded the max year
year_str = current_phase[1:5] year_str = current_phase[1:5]
year_int = int(year_str) year_int = int(year_str)
@ -262,20 +308,9 @@ def main():
# Save final result # Save final result
duration = time.time() - start_whole duration = time.time() - start_whole
logger.info(f"Game ended after {duration:.2f}s. Saving to final JSON...") logger.info(f"Game ended after {duration:.2f}s. Saving final state...")
output_path = game_file_path save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=True)
# If the file already exists, append a timestamp to the filename
if os.path.exists(output_path):
logger.info("Game file already exists, saving with unique filename.")
output_path = f"{output_path}_{time.strftime('%Y%m%d_%H%M%S')}.json"
to_saved_game_format(game, output_path=output_path)
# Dump error stats and power model mapping to the overview file
with open(overview_file_path, "w") as overview_file:
overview_file.write(json.dumps(model_error_stats) + "\n")
overview_file.write(json.dumps(game.power_model_map) + "\n")
overview_file.write(json.dumps(vars(args)) + "\n")
logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}") logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}")
logger.info("Done.") logger.info("Done.")