AI_Diplomacy/ai_diplomacy/narrative.py

106 lines
No EOL
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Generate engaging narrative summaries and transparently patch the Diplomacy
Game engine to use them.
Usage: simply import `ai_diplomacy.narrative` *before* the game loop starts
(e.g. at the top of `lm_game.py`). Import side-effects monkey-patch
`diplomacy.engine.game.Game._generate_phase_summary` so that:
1. The original (statistical) summary logic still runs.
2. The returned text is stored in `GamePhaseData.statistical_summary`.
3. A short narrative is produced via OpenAI `o3` and saved as the main
`.summary`.
"""
from __future__ import annotations
import logging
import os
from typing import Callable
from openai import OpenAI # Import the new OpenAI client
from diplomacy.engine.game import Game
LOGGER = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
OPENAI_MODEL = os.getenv("AI_DIPLOMACY_NARRATIVE_MODEL", "gpt-o3")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
LOGGER.warning("OPENAI_API_KEY not set narrative summaries will be stubbed.")
# ---------------------------------------------------------------------------
# Helper to call the model synchronously
# ---------------------------------------------------------------------------
def _call_openai(statistical_summary: str, phase_key: str) -> str:
"""Return a 24 sentence spectator-friendly narrative."""
if not OPENAI_API_KEY:
return "(Narrative generation disabled missing API key)."
system = (
"You are an energetic e-sports commentator narrating a game of Diplomacy. "
"Turn the provided phase recap into a concise, thrilling story (max 4 sentences). "
"Highlight pivotal moves, supply-center swings, betrayals, and momentum shifts."
)
user = f"PHASE {phase_key}\n\nSTATISTICAL SUMMARY:\n{statistical_summary}\n\nNow narrate this phase for spectators."
try:
# Initialize the OpenAI client with the API key
client = OpenAI(api_key=OPENAI_API_KEY)
# Use the new API format
resp = client.chat.completions.create(
model="o3",
messages=[{"role": "system", "content": system}, {"role": "user", "content": user}],
)
return resp.choices[0].message.content.strip()
except Exception as exc: # Broad we only log and degrade gracefully
LOGGER.error("Narrative generation failed: %s", exc, exc_info=True)
return "(Narrative generation failed)"
# ---------------------------------------------------------------------------
# Patch _generate_phase_summary
# ---------------------------------------------------------------------------
_original_gps: Callable = Game._generate_phase_summary # type: ignore[attr-defined]
def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None): # type: ignore[override]
# 1) Call original implementation → statistical summary
statistical = _original_gps(self, phase_key, summary_callback)
LOGGER.debug(f"[{phase_key}] Original summary returned: {statistical!r}")
# 2) Persist statistical summary separately
phase_data = None
try:
phase_data = self.get_phase_from_history(str(phase_key))
if hasattr(phase_data, "statistical_summary"):
LOGGER.debug(f"[{phase_key}] Assigning to phase_data.statistical_summary: {statistical!r}")
phase_data.statistical_summary = statistical # type: ignore[attr-defined]
else:
LOGGER.warning(f"[{phase_key}] phase_data object does not have attribute 'statistical_summary'. Type: {type(phase_data)}")
except Exception as exc:
LOGGER.warning("Could not retrieve phase_data or store statistical_summary for %s: %s", phase_key, exc)
# 3) Generate narrative summary
narrative = _call_openai(statistical, phase_key)
# 4) Save narrative as the canonical summary
try:
if phase_data:
phase_data.summary = narrative # type: ignore[attr-defined]
self.phase_summaries[str(phase_key)] = narrative # type: ignore[attr-defined]
LOGGER.debug(f"[{phase_key}] Narrative summary stored successfully.")
else:
LOGGER.warning(f"[{phase_key}] Cannot store narrative summary because phase_data is None.")
except Exception as exc:
LOGGER.warning("Could not store narrative summary for %s: %s", phase_key, exc)
return narrative
# Monkey-patch
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")