mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
Support for o3-pro with openai responses api
This commit is contained in:
parent
fbd92d91ba
commit
77e7921b9c
2 changed files with 113 additions and 7 deletions
|
|
@ -4,6 +4,7 @@ from json import JSONDecodeError
|
|||
import re
|
||||
import logging
|
||||
import ast # For literal_eval in JSON fallback parsing
|
||||
import aiohttp # For direct HTTP requests to Responses API
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from dotenv import load_dotenv
|
||||
|
|
@ -922,6 +923,107 @@ class DeepSeekClient(BaseModelClient):
|
|||
return ""
|
||||
|
||||
|
||||
class OpenAIResponsesClient(BaseModelClient):
|
||||
"""
|
||||
For OpenAI o3-pro model using the new Responses API endpoint.
|
||||
This client makes direct HTTP requests to the v1/responses endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
self.api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("OPENAI_API_KEY environment variable is required")
|
||||
self.base_url = "https://api.openai.com/v1/responses"
|
||||
logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client")
|
||||
|
||||
async def generate_response(self, prompt: str) -> str:
|
||||
try:
|
||||
# The Responses API uses a different format than chat completions
|
||||
# Combine system prompt and user prompt into a single input
|
||||
full_prompt = f"{self.system_prompt}\n\n{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
|
||||
|
||||
# Prepare the request payload
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": full_prompt
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
|
||||
# Make the API call using aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.base_url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"[{self.model_name}] API error (status {response.status}): {error_text}"
|
||||
)
|
||||
return ""
|
||||
|
||||
response_data = await response.json()
|
||||
|
||||
# Extract the text from the nested response structure
|
||||
# The text is in output[1].content[0].text based on the response
|
||||
try:
|
||||
outputs = response_data.get("output", [])
|
||||
if len(outputs) < 2:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
|
||||
)
|
||||
return ""
|
||||
|
||||
# The message is typically in the second output item
|
||||
message_output = outputs[1]
|
||||
if message_output.get("type") != "message":
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}"
|
||||
)
|
||||
return ""
|
||||
|
||||
content_list = message_output.get("content", [])
|
||||
if not content_list:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] Empty content list in message output"
|
||||
)
|
||||
return ""
|
||||
|
||||
# Look for the content item with type 'output_text'
|
||||
text_content = ""
|
||||
for content_item in content_list:
|
||||
if content_item.get("type") == "output_text":
|
||||
text_content = content_item.get("text", "")
|
||||
break
|
||||
|
||||
if not text_content:
|
||||
logger.warning(
|
||||
f"[{self.model_name}] No output_text found in content. Full content: {content_list}"
|
||||
)
|
||||
return ""
|
||||
|
||||
return text_content.strip()
|
||||
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}"
|
||||
)
|
||||
return ""
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] HTTP client error in generate_response: {e}"
|
||||
)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
class OpenRouterClient(BaseModelClient):
|
||||
"""
|
||||
For OpenRouter models, with default being 'openrouter/quasar-alpha'
|
||||
|
|
@ -1007,8 +1109,12 @@ def load_model_client(model_id: str) -> BaseModelClient:
|
|||
"""
|
||||
# Basic pattern matching or direct mapping
|
||||
lower_id = model_id.lower()
|
||||
|
||||
# Check for o3-pro model specifically - it needs the Responses API
|
||||
if lower_id == "o3-pro":
|
||||
return OpenAIResponsesClient(model_id)
|
||||
# Check for OpenRouter first to handle prefixed models like openrouter-deepseek
|
||||
if "openrouter" in lower_id or "quasar" in lower_id:
|
||||
elif "openrouter" in lower_id or "quasar" in lower_id:
|
||||
return OpenRouterClient(model_id)
|
||||
elif "claude" in lower_id:
|
||||
return ClaudeClient(model_id)
|
||||
|
|
@ -1017,7 +1123,7 @@ def load_model_client(model_id: str) -> BaseModelClient:
|
|||
elif "deepseek" in lower_id:
|
||||
return DeepSeekClient(model_id)
|
||||
else:
|
||||
# Default to OpenAI
|
||||
# Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
|
||||
return OpenAIClient(model_id)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,20 +34,19 @@ def assign_models_to_powers() -> Dict[str, str]:
|
|||
"""
|
||||
|
||||
# POWER MODELS
|
||||
"""
|
||||
|
||||
return {
|
||||
"AUSTRIA": "o3",
|
||||
"ENGLAND": "gpt-4.1-2025-04-14",
|
||||
"FRANCE": "o4-mini",
|
||||
"FRANCE": "o3-pro",
|
||||
"GERMANY": "gpt-4o",
|
||||
"ITALY": "gpt-4.1-2025-04-14",
|
||||
"ITALY": "o3",
|
||||
"RUSSIA": "gpt-4o",
|
||||
"TURKEY": "o4-mini",
|
||||
}
|
||||
"""
|
||||
|
||||
# TEST MODELS
|
||||
|
||||
"""
|
||||
return {
|
||||
"AUSTRIA": "openrouter-google/gemini-2.5-flash-preview",
|
||||
"ENGLAND": "openrouter-google/gemini-2.5-flash-preview",
|
||||
|
|
@ -57,6 +56,7 @@ def assign_models_to_powers() -> Dict[str, str]:
|
|||
"RUSSIA": "openrouter-google/gemini-2.5-flash-preview",
|
||||
"TURKEY": "openrouter-google/gemini-2.5-flash-preview",
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue