mirror of https://github.com/GoodStartLabs/AI_Diplomacy.git, synced 2026-04-19 12:58:09 +00:00
All client types can be specified with a prefix in front of the model id
parent 1cb24f1884
commit f29ac9c1c5
1 changed file with 195 additions and 32 deletions
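In practice, the new factory accepts model specs like the following (the model ids below come from the diff's own docstring; the localhost URL is an illustrative placeholder, not a value from this commit):

    client = load_model_client("claude:claude-3-5-sonnet-20241022")         # explicit prefix
    client = load_model_client("requests:gpt-4o@http://localhost:8000/v1")  # prefix plus '@base_url' override
    client = load_model_client("gpt-4o")                                    # no prefix: heuristic fallback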
@@ -13,6 +13,8 @@ from dotenv import load_dotenv
 from openai import AsyncOpenAI
 from openai import AsyncOpenAI as AsyncDeepSeekOpenAI  # Alias for clarity
 from anthropic import AsyncAnthropic
+import asyncio
+import requests
 
 import google.generativeai as genai
 from together import AsyncTogether
@@ -800,12 +802,29 @@ class BaseModelClient:
 
 class OpenAIClient(BaseModelClient):
     """
-    For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
+    Async client for OpenAI-compatible chat-completion endpoints.
+    Accepts an optional base_url override.
     """
 
-    def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
+    def __init__(
+        self,
+        model_name: str,
+        prompts_dir: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
         super().__init__(model_name, prompts_dir=prompts_dir)
-        self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url='https://gadgets-become-throughout-kenneth.trycloudflare.com/v1')
+
+        # Allow env var or constructor arg
+        self.base_url = (
+            base_url
+            or os.environ.get("OPENAI_BASE_URL")
+            or "https://api.openai.com/v1"
+        )
+
+        self.client = AsyncOpenAI(
+            api_key=os.environ.get("OPENAI_API_KEY"),
+            base_url=self.base_url,
+        )
 
     async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
         # Updated to new API format
@@ -1224,44 +1243,188 @@ class TogetherAIClient(BaseModelClient):
             return f"Error: Unexpected error - {str(e)}"  # Return a string with error info
 
 
+class RequestsOpenAIClient(BaseModelClient):
+    """
+    Synchronous `requests`-based client for any OpenAI-compatible API.
+    Wrapped in `async` via `loop.run_in_executor` so it can live beside
+    the aiohttp / SDK implementations without changing call-sites.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        prompts_dir: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
+        super().__init__(model_name, prompts_dir=prompts_dir)
+
+        self.api_key = os.environ.get("OPENAI_API_KEY")
+        if not self.api_key:
+            raise ValueError("OPENAI_API_KEY environment variable is required")
+
+        # Constructor arg > env var > default official endpoint
+        self.base_url = (
+            base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
+        ).rstrip("/")
+
+        # Pre-build the endpoint used by Chat Completions
+        self.endpoint = f"{self.base_url}/chat/completions"
+
+    # ------------------------------------------------------------------ #
+    # Internal synchronous helper
+    # ------------------------------------------------------------------ #
+    def _post_sync(self, payload: dict) -> dict:
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+        }
+        r = requests.post(self.endpoint, headers=headers, json=payload, timeout=60)
+        r.raise_for_status()
+        return r.json()
+
+    # ------------------------------------------------------------------ #
+    # Async public API
+    # ------------------------------------------------------------------ #
+    async def generate_response(
+        self,
+        prompt: str,
+        temperature: float = 0.0,
+        inject_random_seed: bool = True,
+    ) -> str:
+        system_prompt_content = self.system_prompt
+        if inject_random_seed:
+            system_prompt_content = f"{generate_random_seed()}\n\n{system_prompt_content}"
+
+        # Standard OpenAI JSON payload
+        payload = {
+            "model": self.model_name,
+            "messages": [
+                {"role": "system", "content": system_prompt_content},
+                {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
+            ],
+            "temperature": temperature,
+            "max_tokens": self.max_tokens,
+        }
+
+        # Run the blocking call in a thread-pool executor
+        loop = asyncio.get_running_loop()
+        try:
+            response_data = await loop.run_in_executor(None, self._post_sync, payload)
+        except Exception as e:
+            logger.error(f"[{self.model_name}] HTTP error: {e}", exc_info=True)
+            return ""
+
+        # Extract the assistant message content
+        try:
+            return (
+                response_data["choices"][0]["message"]["content"].strip()
+                if response_data.get("choices")
+                else ""
+            )
+        except (KeyError, IndexError, TypeError) as e:
+            logger.error(f"[{self.model_name}] Bad response format: {e}")
+            return ""
+
+
 ##############################################################################
 # 3) Factory to Load Model Client
 ##############################################################################
+def _split_model_and_base(spec: str) -> Tuple[str, Optional[str]]:
+    """
+    Accepts strings like 'gpt-4o'            -> ('gpt-4o', None)
+                         'gpt-4o@http://foo' -> ('gpt-4o', 'http://foo')
+    """
+    if "@" in spec:
+        model, base = spec.split("@", 1)
+        return model, base.rstrip("/")
+    return spec, None
+
+
 def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient:
     """
     Returns the appropriate LLM client for a given model_id string.
+    Factory supporting:
+      • '<prefix>:<model>[@base_url]'
+      • '<model>[@base_url]'   (heuristic fallback)
 
     Args:
         model_id: The model identifier
         prompts_dir: Optional path to the prompts directory.
 
     Example usage:
         client = load_model_client("claude-3-5-sonnet-20241022")
+
+    Recognised prefixes → client class
+      openai, oai          → OpenAIClient
+      requests, req        → RequestsOpenAIClient
+      responses, oai-resp  → OpenAIResponsesClient
+      claude               → ClaudeClient
+      gemini               → GeminiClient
+      deepseek             → DeepSeekClient
+      openrouter, or       → OpenRouterClient
+      together             → TogetherAIClient
     """
-    # Basic pattern matching or direct mapping
-    lower_id = model_id.lower()
-
-    # Check for o3-pro model specifically - it needs the Responses API
+    # First: pull any '@base_url' suffix off the full string
+    model_part, explicit_base = _split_model_and_base(model_id)
+
+    # Next: check for explicit '<prefix>:...' pattern
+    if ":" in model_part:
+        prefix, rest = model_part.split(":", 1)
+        prefix = prefix.lower()
+        model_name, alt_base = _split_model_and_base(rest)
+        base_url = explicit_base or alt_base  # explicit wins
+
+        if prefix in {"openai", "oai"}:
+            return OpenAIClient(model_name, prompts_dir, base_url)
+
+        if prefix in {"requests", "req"}:
+            return RequestsOpenAIClient(model_name, prompts_dir, base_url)
+
+        if prefix in {"responses", "oai-resp", "openai-responses"}:
+            return OpenAIResponsesClient(model_name, prompts_dir)
+
+        if prefix == "claude":
+            return ClaudeClient(model_name, prompts_dir)
+
+        if prefix == "gemini":
+            return GeminiClient(model_name, prompts_dir)
+
+        if prefix == "deepseek":
+            return DeepSeekClient(model_name, prompts_dir)
+
+        if prefix in {"openrouter", "or"}:
+            return OpenRouterClient(model_name, prompts_dir)
+
+        if prefix == "together":
+            return TogetherAIClient(model_name, prompts_dir)
+
+        logger.warning(f"[load_model_client] Unrecognised prefix '{prefix}', falling back.")
+        model_part, explicit_base = rest, base_url  # keep analysing without prefix
+
+    # ------------------------------------------------------------------
+    # Heuristic fallback (unchanged apart from base_url support)
+    # ------------------------------------------------------------------
+    model_name, base_url = _split_model_and_base(model_part)
+    base_url = explicit_base or base_url
+
+    lower_id = model_name.lower()
+
     if lower_id == "o3-pro":
-        return OpenAIResponsesClient(model_id, prompts_dir=prompts_dir)
-    # Check for OpenRouter first to handle prefixed models like openrouter-deepseek
-    elif model_id.startswith("together-"):
-        actual_model_name = model_id.split("together-", 1)[1]
-        logger.info(f"Loading TogetherAI client for model: {actual_model_name} (original ID: {model_id})")
-        return TogetherAIClient(actual_model_name, prompts_dir=prompts_dir)
-    elif "openrouter" in model_id.lower():  # or "/" in model_id: more general check for OpenRouter
-        return OpenRouterClient(model_id, prompts_dir=prompts_dir)
-    elif "claude" in lower_id:
-        return ClaudeClient(model_id, prompts_dir=prompts_dir)
-    elif "gemini" in lower_id:
-        return GeminiClient(model_id, prompts_dir=prompts_dir)
-    elif "deepseek" in lower_id:
-        return DeepSeekClient(model_id, prompts_dir=prompts_dir)
-    else:
-        # Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
-        return OpenAIClient(model_id, prompts_dir=prompts_dir)
+        return OpenAIResponsesClient(model_name, prompts_dir)
+
+    if model_name.startswith("together-"):
+        return TogetherAIClient(model_name.split("together-", 1)[1], prompts_dir)
+
+    if "openrouter" in lower_id:
+        return OpenRouterClient(model_name, prompts_dir)
+
+    if "claude" in lower_id:
+        return ClaudeClient(model_name, prompts_dir)
+
+    if "gemini" in lower_id:
+        return GeminiClient(model_name, prompts_dir)
+
+    if "deepseek" in lower_id:
+        return DeepSeekClient(model_name, prompts_dir)
+
+    # Default to OpenAI-compatible async client
+    return OpenAIClient(model_name, prompts_dir, base_url)
 
 
 ##############################################################################
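For review, a quick sketch of the resolution rules introduced above; the expected values follow directly from the code in this diff:

    _split_model_and_base("gpt-4o")              # -> ("gpt-4o", None)
    _split_model_and_base("gpt-4o@http://foo/")  # -> ("gpt-4o", "http://foo")

    # Base-URL precedence in OpenAIClient and RequestsOpenAIClient:
    #   constructor arg > OPENAI_BASE_URL env var > "https://api.openai.com/v1"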