From 77e7921b9c9de2b873454c8af4680ef92738edfb Mon Sep 17 00:00:00 2001
From: AlxAI
Date: Sat, 14 Jun 2025 17:28:01 -0400
Subject: [PATCH] Support for o3-pro with openai responses api

---
 ai_diplomacy/clients.py | 110 +++++++++++++++++++++++++++++++++++++++-
 ai_diplomacy/utils.py   |  10 ++--
 2 files changed, 113 insertions(+), 7 deletions(-)

diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py
index 29b327c..18fc65e 100644
--- a/ai_diplomacy/clients.py
+++ b/ai_diplomacy/clients.py
@@ -4,6 +4,7 @@ from json import JSONDecodeError
 import re
 import logging
 import ast  # For literal_eval in JSON fallback parsing
+import aiohttp  # For direct HTTP requests to Responses API
 from typing import List, Dict, Optional, Any, Tuple
 from dotenv import load_dotenv
 
@@ -922,6 +923,107 @@ class DeepSeekClient(BaseModelClient):
             return ""
 
 
+class OpenAIResponsesClient(BaseModelClient):
+    """
+    For OpenAI o3-pro model using the new Responses API endpoint.
+    This client makes direct HTTP requests to the v1/responses endpoint.
+    """
+
+    def __init__(self, model_name: str):
+        super().__init__(model_name)
+        self.api_key = os.environ.get("OPENAI_API_KEY")
+        if not self.api_key:
+            raise ValueError("OPENAI_API_KEY environment variable is required")
+        self.base_url = "https://api.openai.com/v1/responses"
+        logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client")
+
+    async def generate_response(self, prompt: str) -> str:
+        try:
+            # The Responses API uses a different format than chat completions
+            # Combine system prompt and user prompt into a single input
+            full_prompt = f"{self.system_prompt}\n\n{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
+
+            # Prepare the request payload
+            payload = {
+                "model": self.model_name,
+                "input": full_prompt
+            }
+
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.api_key}"
+            }
+
+            # Make the API call using aiohttp
+            async with aiohttp.ClientSession() as session:
+                async with session.post(self.base_url, json=payload, headers=headers) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        logger.error(
+                            f"[{self.model_name}] API error (status {response.status}): {error_text}"
+                        )
+                        return ""
+
+                    response_data = await response.json()
+
+                    # Extract the text from the nested response structure
+                    # The text is in output[1].content[0].text based on the response
+                    try:
+                        outputs = response_data.get("output", [])
+                        if len(outputs) < 2:
+                            logger.warning(
+                                f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
+                            )
+                            return ""
+
+                        # The message is typically in the second output item
+                        message_output = outputs[1]
+                        if message_output.get("type") != "message":
+                            logger.warning(
+                                f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}"
+                            )
+                            return ""
+
+                        content_list = message_output.get("content", [])
+                        if not content_list:
+                            logger.warning(
+                                f"[{self.model_name}] Empty content list in message output"
+                            )
+                            return ""
+
+                        # Look for the content item with type 'output_text'
+                        text_content = ""
+                        for content_item in content_list:
+                            if content_item.get("type") == "output_text":
+                                text_content = content_item.get("text", "")
+                                break
+
+                        if not text_content:
+                            logger.warning(
+                                f"[{self.model_name}] No output_text found in content. Full content: {content_list}"
+                            )
+                            return ""
+
+                        return text_content.strip()
+
+                    except (KeyError, IndexError, TypeError) as e:
+                        logger.error(
+                            f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}"
+                        )
+                        return ""
+
+        except aiohttp.ClientError as e:
+            logger.error(
+                f"[{self.model_name}] HTTP client error in generate_response: {e}"
+            )
+            return ""
+        except Exception as e:
+            logger.error(
+                f"[{self.model_name}] Unexpected error in generate_response: {e}"
+            )
+            return ""
+
+
 class OpenRouterClient(BaseModelClient):
     """
     For OpenRouter models, with default being 'openrouter/quasar-alpha'
@@ -1007,8 +1109,12 @@ def load_model_client(model_id: str) -> BaseModelClient:
     """
     # Basic pattern matching or direct mapping
     lower_id = model_id.lower()
+
+    # Check for o3-pro model specifically - it needs the Responses API
+    if lower_id == "o3-pro":
+        return OpenAIResponsesClient(model_id)
     # Check for OpenRouter first to handle prefixed models like openrouter-deepseek
-    if "openrouter" in lower_id or "quasar" in lower_id:
+    elif "openrouter" in lower_id or "quasar" in lower_id:
         return OpenRouterClient(model_id)
     elif "claude" in lower_id:
         return ClaudeClient(model_id)
@@ -1017,7 +1123,7 @@
     elif "deepseek" in lower_id:
         return DeepSeekClient(model_id)
     else:
-        # Default to OpenAI
+        # Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
         return OpenAIClient(model_id)
 
 
diff --git a/ai_diplomacy/utils.py b/ai_diplomacy/utils.py
index 81c2640..76c265b 100644
--- a/ai_diplomacy/utils.py
+++ b/ai_diplomacy/utils.py
@@ -34,20 +34,19 @@ def assign_models_to_powers() -> Dict[str, str]:
     """
 
     # POWER MODELS
-    """
+
     return {
         "AUSTRIA": "o3",
         "ENGLAND": "gpt-4.1-2025-04-14",
-        "FRANCE": "o4-mini",
+        "FRANCE": "o3-pro",
         "GERMANY": "gpt-4o",
-        "ITALY": "gpt-4.1-2025-04-14",
+        "ITALY": "o3",
         "RUSSIA": "gpt-4o",
         "TURKEY": "o4-mini",
     }
-    """
 
     # TEST MODELS
-
+    """
     return {
         "AUSTRIA": "openrouter-google/gemini-2.5-flash-preview",
         "ENGLAND": "openrouter-google/gemini-2.5-flash-preview",
@@ -57,6 +56,7 @@
         "RUSSIA": "openrouter-google/gemini-2.5-flash-preview",
         "TURKEY": "openrouter-google/gemini-2.5-flash-preview",
     }
+    """
 
 
 def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
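
Usage sketch (not part of the patch): a minimal, untested illustration of how the new client is expected to be selected and invoked once this change is applied. The power name and prompt below are placeholders, and it assumes OPENAI_API_KEY is set in the environment and that BaseModelClient exposes a settable system_prompt attribute, as the patched generate_response implies.

    import asyncio

    from ai_diplomacy.clients import load_model_client

    async def main():
        # "o3-pro" is routed to the new OpenAIResponsesClient by load_model_client()
        client = load_model_client("o3-pro")
        # system_prompt is read by generate_response(); the value here is a placeholder
        client.system_prompt = "You are playing FRANCE in a game of Diplomacy."
        reply = await client.generate_response("Propose your Spring 1901 orders.")
        print(reply)

    asyncio.run(main())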