mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-26 17:13:19 +00:00
Support for o3-pro with openai responses api
This commit is contained in:
parent
fbd92d91ba
commit
77e7921b9c
2 changed files with 113 additions and 7 deletions
|
|
@ -4,6 +4,7 @@ from json import JSONDecodeError
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import ast # For literal_eval in JSON fallback parsing
|
import ast # For literal_eval in JSON fallback parsing
|
||||||
|
import aiohttp # For direct HTTP requests to Responses API
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -922,6 +923,107 @@ class DeepSeekClient(BaseModelClient):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIResponsesClient(BaseModelClient):
    """
    For OpenAI o3-pro model using the new Responses API endpoint.
    This client makes direct HTTP requests to the v1/responses endpoint.

    The API key is read from the ``OPENAI_API_KEY`` environment variable when
    the client is constructed. ``generate_response`` never raises: every
    failure (HTTP error, transport error, unexpected payload) is logged and
    reported to the caller as an empty string.
    """

    def __init__(self, model_name: str):
        """
        Create a client for ``model_name``.

        Raises:
            ValueError: if ``OPENAI_API_KEY`` is unset or empty.
        """
        super().__init__(model_name)
        self.api_key = os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY environment variable is required")
        self.base_url = "https://api.openai.com/v1/responses"
        logger.info(f"[{self.model_name}] Initialized OpenAI Responses API client")

    async def generate_response(self, prompt: str) -> str:
        """
        Send ``prompt`` (prefixed with the system prompt) to the Responses API.

        Args:
            prompt: The user prompt text.

        Returns:
            The stripped text of the model's reply, or ``""`` if the request
            failed or no text could be extracted from the response.
        """
        try:
            # The Responses API uses a different format than chat completions
            # Combine system prompt and user prompt into a single input
            full_prompt = f"{self.system_prompt}\n\n{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"

            # Prepare the request payload
            payload = {
                "model": self.model_name,
                "input": full_prompt,
            }
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }

            # Make the API call using aiohttp
            async with aiohttp.ClientSession() as session:
                async with session.post(self.base_url, json=payload, headers=headers) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        logger.error(
                            f"[{self.model_name}] API error (status {response.status}): {error_text}"
                        )
                        return ""

                    response_data = await response.json()

            # Parsing happens after the connection is released; it stays inside
            # the outer try so an unexpected payload can never raise to callers.
            return self._extract_output_text(response_data)

        except aiohttp.ClientError as e:
            logger.error(
                f"[{self.model_name}] HTTP client error in generate_response: {e}"
            )
            return ""
        except Exception as e:
            logger.error(
                f"[{self.model_name}] Unexpected error in generate_response: {e}"
            )
            return ""

    def _extract_output_text(self, response_data: Any) -> str:
        """
        Pull the reply text out of a decoded Responses API payload.

        Message items are located by their ``type`` rather than by a fixed
        position in ``output``: the message is only *typically* the second
        item, so indexing ``output[1]`` drops valid replies whenever the
        array has a different length or ordering.

        Args:
            response_data: The JSON-decoded response body.

        Returns:
            The concatenated, stripped ``output_text`` parts of all message
            items, or ``""`` (after logging a warning) if none are present.
        """
        outputs = response_data.get("output") if isinstance(response_data, dict) else None
        if not isinstance(outputs, list) or not outputs:
            logger.warning(
                f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
            )
            return ""

        message_items = [
            item for item in outputs
            if isinstance(item, dict) and item.get("type") == "message"
        ]
        if not message_items:
            found_types = [
                item.get("type") if isinstance(item, dict) else type(item).__name__
                for item in outputs
            ]
            logger.warning(
                f"[{self.model_name}] No message item found in output. Got types: {found_types}"
            )
            return ""

        text_parts = []
        for message in message_items:
            # ``content`` may be missing or null; treat both as empty.
            for content_item in message.get("content") or []:
                if isinstance(content_item, dict) and content_item.get("type") == "output_text":
                    text_parts.append(content_item.get("text") or "")

        text_content = "".join(text_parts).strip()
        if not text_content:
            logger.warning(
                f"[{self.model_name}] No output_text found in content. Full output: {outputs}"
            )
            return ""

        return text_content
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterClient(BaseModelClient):
|
class OpenRouterClient(BaseModelClient):
|
||||||
"""
|
"""
|
||||||
For OpenRouter models, with default being 'openrouter/quasar-alpha'
|
For OpenRouter models, with default being 'openrouter/quasar-alpha'
|
||||||
|
|
@ -1007,8 +1109,12 @@ def load_model_client(model_id: str) -> BaseModelClient:
|
||||||
"""
|
"""
|
||||||
# Basic pattern matching or direct mapping
|
# Basic pattern matching or direct mapping
|
||||||
lower_id = model_id.lower()
|
lower_id = model_id.lower()
|
||||||
|
|
||||||
|
# Check for o3-pro model specifically - it needs the Responses API
|
||||||
|
if lower_id == "o3-pro":
|
||||||
|
return OpenAIResponsesClient(model_id)
|
||||||
# Check for OpenRouter first to handle prefixed models like openrouter-deepseek
|
# Check for OpenRouter first to handle prefixed models like openrouter-deepseek
|
||||||
if "openrouter" in lower_id or "quasar" in lower_id:
|
elif "openrouter" in lower_id or "quasar" in lower_id:
|
||||||
return OpenRouterClient(model_id)
|
return OpenRouterClient(model_id)
|
||||||
elif "claude" in lower_id:
|
elif "claude" in lower_id:
|
||||||
return ClaudeClient(model_id)
|
return ClaudeClient(model_id)
|
||||||
|
|
@ -1017,7 +1123,7 @@ def load_model_client(model_id: str) -> BaseModelClient:
|
||||||
elif "deepseek" in lower_id:
|
elif "deepseek" in lower_id:
|
||||||
return DeepSeekClient(model_id)
|
return DeepSeekClient(model_id)
|
||||||
else:
|
else:
|
||||||
# Default to OpenAI
|
# Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
|
||||||
return OpenAIClient(model_id)
|
return OpenAIClient(model_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,20 +34,19 @@ def assign_models_to_powers() -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# POWER MODELS
|
# POWER MODELS
|
||||||
"""
|
|
||||||
return {
|
return {
|
||||||
"AUSTRIA": "o3",
|
"AUSTRIA": "o3",
|
||||||
"ENGLAND": "gpt-4.1-2025-04-14",
|
"ENGLAND": "gpt-4.1-2025-04-14",
|
||||||
"FRANCE": "o4-mini",
|
"FRANCE": "o3-pro",
|
||||||
"GERMANY": "gpt-4o",
|
"GERMANY": "gpt-4o",
|
||||||
"ITALY": "gpt-4.1-2025-04-14",
|
"ITALY": "o3",
|
||||||
"RUSSIA": "gpt-4o",
|
"RUSSIA": "gpt-4o",
|
||||||
"TURKEY": "o4-mini",
|
"TURKEY": "o4-mini",
|
||||||
}
|
}
|
||||||
"""
|
|
||||||
|
|
||||||
# TEST MODELS
|
# TEST MODELS
|
||||||
|
"""
|
||||||
return {
|
return {
|
||||||
"AUSTRIA": "openrouter-google/gemini-2.5-flash-preview",
|
"AUSTRIA": "openrouter-google/gemini-2.5-flash-preview",
|
||||||
"ENGLAND": "openrouter-google/gemini-2.5-flash-preview",
|
"ENGLAND": "openrouter-google/gemini-2.5-flash-preview",
|
||||||
|
|
@ -57,6 +56,7 @@ def assign_models_to_powers() -> Dict[str, str]:
|
||||||
"RUSSIA": "openrouter-google/gemini-2.5-flash-preview",
|
"RUSSIA": "openrouter-google/gemini-2.5-flash-preview",
|
||||||
"TURKEY": "openrouter-google/gemini-2.5-flash-preview",
|
"TURKEY": "openrouter-google/gemini-2.5-flash-preview",
|
||||||
}
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue