diff --git a/ai_diplomacy/narrative.py b/ai_diplomacy/narrative.py new file mode 100644 index 0000000..756b469 --- /dev/null +++ b/ai_diplomacy/narrative.py @@ -0,0 +1,106 @@ +"""Generate engaging narrative summaries and transparently patch the Diplomacy +Game engine to use them. + +Usage: simply import `ai_diplomacy.narrative` *before* the game loop starts +(e.g. at the top of `lm_game.py`). Import side-effects monkey-patch +`diplomacy.engine.game.Game._generate_phase_summary` so that: + +1. The original (statistical) summary logic still runs. +2. The returned text is stored in `GamePhaseData.statistical_summary`. +3. A short narrative is produced via OpenAI `o3` and saved as the main + `.summary`. +""" +from __future__ import annotations + +import logging +import os +from typing import Callable + +from openai import OpenAI # Import the new OpenAI client +from diplomacy.engine.game import Game + +LOGGER = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- +OPENAI_MODEL = os.getenv("AI_DIPLOMACY_NARRATIVE_MODEL", "gpt-o3") +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") + +if not OPENAI_API_KEY: + LOGGER.warning("OPENAI_API_KEY not set – narrative summaries will be stubbed.") + +# --------------------------------------------------------------------------- +# Helper to call the model synchronously +# --------------------------------------------------------------------------- + +def _call_openai(statistical_summary: str, phase_key: str) -> str: + """Return a 2–4 sentence spectator-friendly narrative.""" + if not OPENAI_API_KEY: + return "(Narrative generation disabled – missing API key)." + + system = ( + "You are an energetic e-sports commentator narrating a game of Diplomacy. " + "Turn the provided phase recap into a concise, thrilling story (max 4 sentences). " + "Highlight pivotal moves, supply-center swings, betrayals, and momentum shifts." + ) + user = f"PHASE {phase_key}\n\nSTATISTICAL SUMMARY:\n{statistical_summary}\n\nNow narrate this phase for spectators." + + try: + # Initialize the OpenAI client with the API key + client = OpenAI(api_key=OPENAI_API_KEY) + + # Use the new API format + resp = client.chat.completions.create( + model="o3", + messages=[{"role": "system", "content": system}, {"role": "user", "content": user}], + ) + return resp.choices[0].message.content.strip() + except Exception as exc: # Broad – we only log and degrade gracefully + LOGGER.error("Narrative generation failed: %s", exc, exc_info=True) + return "(Narrative generation failed)" + +# --------------------------------------------------------------------------- +# Patch _generate_phase_summary +# --------------------------------------------------------------------------- + +_original_gps: Callable = Game._generate_phase_summary # type: ignore[attr-defined] + + +def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None): # type: ignore[override] + # 1) Call original implementation → statistical summary + statistical = _original_gps(self, phase_key, summary_callback) + LOGGER.debug(f"[{phase_key}] Original summary returned: {statistical!r}") + + # 2) Persist statistical summary separately + phase_data = None + try: + phase_data = self.get_phase_from_history(str(phase_key)) + if hasattr(phase_data, "statistical_summary"): + LOGGER.debug(f"[{phase_key}] Assigning to phase_data.statistical_summary: {statistical!r}") + phase_data.statistical_summary = statistical # type: ignore[attr-defined] + else: + LOGGER.warning(f"[{phase_key}] phase_data object does not have attribute 'statistical_summary'. Type: {type(phase_data)}") + except Exception as exc: + LOGGER.warning("Could not retrieve phase_data or store statistical_summary for %s: %s", phase_key, exc) + + # 3) Generate narrative summary + narrative = _call_openai(statistical, phase_key) + + # 4) Save narrative as the canonical summary + try: + if phase_data: + phase_data.summary = narrative # type: ignore[attr-defined] + self.phase_summaries[str(phase_key)] = narrative # type: ignore[attr-defined] + LOGGER.debug(f"[{phase_key}] Narrative summary stored successfully.") + else: + LOGGER.warning(f"[{phase_key}] Cannot store narrative summary because phase_data is None.") + except Exception as exc: + LOGGER.warning("Could not store narrative summary for %s: %s", phase_key, exc) + + return narrative + +# Monkey-patch +Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment] + +LOGGER.info("Game._generate_phase_summary patched with narrative generation.") \ No newline at end of file diff --git a/diplomacy/utils/game_phase_data.py b/diplomacy/utils/game_phase_data.py index 42cc3d1..703861a 100644 --- a/diplomacy/utils/game_phase_data.py +++ b/diplomacy/utils/game_phase_data.py @@ -27,7 +27,7 @@ class GamePhaseData(Jsonable): """ Small class to represent data for a game phase: phase name, state, orders, orders results and messages for this phase. """ - __slots__ = ['name', 'state', 'orders', 'results', 'messages', 'summary'] + __slots__ = ['name', 'state', 'orders', 'results', 'messages', 'summary', 'statistical_summary'] model = { strings.NAME: str, @@ -35,10 +35,11 @@ class GamePhaseData(Jsonable): strings.ORDERS: parsing.DictType(str, parsing.OptionalValueType(parsing.SequenceType(str))), strings.RESULTS: parsing.DictType(str, parsing.SequenceType(parsing.StringableType(common.StringableCode))), strings.MESSAGES: MESSAGES_TYPE, - 'summary': parsing.OptionalValueType(str) + 'summary': parsing.OptionalValueType(str), + 'statistical_summary': parsing.OptionalValueType(str) } - def __init__(self, name=None, state=None, orders=None, messages=None, results=None, summary=None, **kwargs): + def __init__(self, name=None, state=None, orders=None, messages=None, results=None, summary=None, statistical_summary=None, **kwargs): """ Constructor. """ self.name = '' @@ -50,5 +51,5 @@ class GamePhaseData(Jsonable): self.messages = {} self.summary = None - super(GamePhaseData, self).__init__(name=name, state=state, orders=orders, results=results, messages=messages, summary=summary, **kwargs) - + self.statistical_summary = None + super(GamePhaseData, self).__init__(name=name, state=state, orders=orders, results=results, messages=messages, summary=summary, statistical_summary=statistical_summary, **kwargs) diff --git a/lm_game.py b/lm_game.py index 33b4f6f..ce122fa 100644 --- a/lm_game.py +++ b/lm_game.py @@ -25,6 +25,7 @@ from ai_diplomacy.negotiations import conduct_negotiations from ai_diplomacy.planning import planning_phase from ai_diplomacy.game_history import GameHistory from ai_diplomacy.agent import DiplomacyAgent +import ai_diplomacy.narrative dotenv.load_dotenv()