all client types can be specified with a prefix in front of the model id

This commit is contained in:
sam-paech 2025-07-03 07:41:29 +10:00
parent 1cb24f1884
commit f29ac9c1c5
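For context, a minimal sketch of the spec format this commit enables (the local base URL and model names below are illustrative placeholders, not taken from the diff):

    # '<prefix>:<model>[@base_url]'  or  '<model>[@base_url]'
    client = load_model_client("oai:gpt-4o")
    client = load_model_client("requests:my-local-model@http://localhost:8000/v1")
    client = load_model_client("claude:claude-3-5-sonnet-20241022")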


@@ -13,6 +13,8 @@ from dotenv import load_dotenv
from openai import AsyncOpenAI
from openai import AsyncOpenAI as AsyncDeepSeekOpenAI # Alias for clarity
from anthropic import AsyncAnthropic
import asyncio
import requests
import google.generativeai as genai
from together import AsyncTogether
@@ -800,12 +802,29 @@ class BaseModelClient:
class OpenAIClient(BaseModelClient):
    """
    For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
    Async client for OpenAI-compatible chat-completion endpoints.
    Accepts an optional base_url override.
    """
    def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
    def __init__(
        self,
        model_name: str,
        prompts_dir: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        super().__init__(model_name, prompts_dir=prompts_dir)
        self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url='https://gadgets-become-throughout-kenneth.trycloudflare.com/v1')
        # Allow env var or constructor arg
        self.base_url = (
            base_url
            or os.environ.get("OPENAI_BASE_URL")
            or "https://api.openai.com/v1"
        )
        self.client = AsyncOpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url=self.base_url,
        )

    async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
        # Updated to new API format
@@ -1224,44 +1243,188 @@ class TogetherAIClient(BaseModelClient):
return f"Error: Unexpected error - {str(e)}" # Return a string with error info
class RequestsOpenAIClient(BaseModelClient):
    """
    Synchronous `requests`-based client for any OpenAI-compatible API.
    The blocking call is wrapped in `async` by running it in a thread-pool
    executor, so it can live beside the aiohttp / SDK implementations
    without changing call-sites.
    """
    def __init__(
        self,
        model_name: str,
        prompts_dir: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        super().__init__(model_name, prompts_dir=prompts_dir)
        self.api_key = os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY environment variable is required")

        # Constructor arg > env var > default official endpoint
        self.base_url = (
            (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1")
            .rstrip("/")
        )
        # Pre-build the endpoint used by Chat Completions
        self.endpoint = f"{self.base_url}/chat/completions"

    # ------------------------------------------------------------------ #
    # internal synchronous helper
    # ------------------------------------------------------------------ #
    def _post_sync(self, payload: dict) -> dict:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        r = requests.post(self.endpoint, headers=headers, json=payload, timeout=60)
        r.raise_for_status()
        return r.json()

    # ------------------------------------------------------------------ #
    # async public API
    # ------------------------------------------------------------------ #
    async def generate_response(
        self,
        prompt: str,
        temperature: float = 0.0,
        inject_random_seed: bool = True,
    ) -> str:
        system_prompt_content = self.system_prompt
        if inject_random_seed:
            system_prompt_content = f"{generate_random_seed()}\n\n{system_prompt_content}"

        # Standard OpenAI JSON payload
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_prompt_content},
                {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
            ],
            "temperature": temperature,
            "max_tokens": self.max_tokens,
        }

        # Run blocking call in threadpool
        loop = asyncio.get_running_loop()
        try:
            response_data = await loop.run_in_executor(None, self._post_sync, payload)
        except Exception as e:
            logger.error(f"[{self.model_name}] HTTP error: {e}", exc_info=True)
            return ""

        # Extract assistant content
        try:
            return (
                response_data["choices"][0]["message"]["content"].strip()
                if response_data.get("choices")
                else ""
            )
        except (KeyError, IndexError, TypeError) as e:
            logger.error(f"[{self.model_name}] Bad response format: {e}")
            return ""

##############################################################################
# 3) Factory to Load Model Client
##############################################################################
def _split_model_and_base(spec: str) -> Tuple[str, Optional[str]]:
    """
    Accepts strings like 'gpt-4o'            -> ('gpt-4o', None)
                         'gpt-4o@http://foo' -> ('gpt-4o', 'http://foo')
    """
    if "@" in spec:
        model, base = spec.split("@", 1)
        return model, base.rstrip("/")
    return spec, None
def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient:
    """
    Returns the appropriate LLM client for a given model_id string.

    Args:
        model_id: The model identifier
        prompts_dir: Optional path to the prompts directory.

    Example usage:
        client = load_model_client("claude-3-5-sonnet-20241022")

    Factory supporting:
        '<prefix>:<model>[@base_url]'
        '<model>[@base_url]'             (heuristic fallback)

    Recognised prefixes          client class
        openai, oai              OpenAIClient
        requests, req            RequestsOpenAIClient
        responses, oai-resp      OpenAIResponsesClient
        claude                   ClaudeClient
        gemini                   GeminiClient
        deepseek                 DeepSeekClient
        openrouter, or           OpenRouterClient
        together                 TogetherAIClient
    """
    # Basic pattern matching or direct mapping
    lower_id = model_id.lower()
    # Check for o3-pro model specifically - it needs the Responses API

    # First: pull any '@base_url' suffix off the full string
    model_part, explicit_base = _split_model_and_base(model_id)

    # Next: check for explicit '<prefix>:...' pattern
    if ":" in model_part:
        prefix, rest = model_part.split(":", 1)
        prefix = prefix.lower()
        model_name, alt_base = _split_model_and_base(rest)
        base_url = explicit_base or alt_base          # explicit wins

        if prefix in {"openai", "oai"}:
            return OpenAIClient(model_name, prompts_dir, base_url)
        if prefix in {"requests", "req"}:
            return RequestsOpenAIClient(model_name, prompts_dir, base_url)
        if prefix in {"responses", "oai-resp", "openai-responses"}:
            return OpenAIResponsesClient(model_name, prompts_dir)
        if prefix == "claude":
            return ClaudeClient(model_name, prompts_dir)
        if prefix == "gemini":
            return GeminiClient(model_name, prompts_dir)
        if prefix == "deepseek":
            return DeepSeekClient(model_name, prompts_dir)
        if prefix in {"openrouter", "or"}:
            return OpenRouterClient(model_name, prompts_dir)
        if prefix == "together":
            return TogetherAIClient(model_name, prompts_dir)

        logger.warning(f"[load_model_client] Unrecognised prefix '{prefix}', falling back.")
        model_part, explicit_base = rest, base_url    # keep analysing without prefix

    # ------------------------------------------------------------------
    # Heuristic fallback (unchanged apart from base_url support)
    # ------------------------------------------------------------------
    model_name, base_url = _split_model_and_base(model_part)
    base_url = explicit_base or base_url
    lower_id = model_name.lower()
if lower_id == "o3-pro":
return OpenAIResponsesClient(model_id, prompts_dir=prompts_dir)
# Check for OpenRouter first to handle prefixed models like openrouter-deepseek
elif model_id.startswith("together-"):
actual_model_name = model_id.split("together-", 1)[1]
logger.info(f"Loading TogetherAI client for model: {actual_model_name} (original ID: {model_id})")
return TogetherAIClient(actual_model_name, prompts_dir=prompts_dir)
elif "openrouter" in model_id.lower(): # or "/" in model_id: # More general check for OpenRouterClient(model_id)
return OpenRouterClient(model_id, prompts_dir=prompts_dir)
elif "claude" in lower_id:
return ClaudeClient(model_id, prompts_dir=prompts_dir)
elif "gemini" in lower_id:
return GeminiClient(model_id, prompts_dir=prompts_dir)
elif "deepseek" in lower_id:
return DeepSeekClient(model_id, prompts_dir=prompts_dir)
else:
# Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
return OpenAIClient(model_id, prompts_dir=prompts_dir)
return OpenAIResponsesClient(model_name, prompts_dir)
if model_name.startswith("together-"):
return TogetherAIClient(model_name.split("together-", 1)[1], prompts_dir)
if "openrouter" in lower_id:
return OpenRouterClient(model_name, prompts_dir)
if "claude" in lower_id:
return ClaudeClient(model_name, prompts_dir)
if "gemini" in lower_id:
return GeminiClient(model_name, prompts_dir)
if "deepseek" in lower_id:
return DeepSeekClient(model_name, prompts_dir)
# Default to OpenAI-compatible async client
return OpenAIClient(model_name, prompts_dir, base_url)
##############################################################################
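For reference, a few hypothetical specs and what the factory above resolves them to under the new logic (model names and URLs are placeholders; the client classes are the ones dispatched in the diff):

    # "requests:my-model@http://10.0.0.5:8080/v1"
    #     -> RequestsOpenAIClient("my-model", prompts_dir, "http://10.0.0.5:8080/v1")
    # "together:meta-llama/Llama-3-70b"
    #     -> TogetherAIClient("meta-llama/Llama-3-70b", prompts_dir)
    # "gpt-4o@http://localhost:8000/v1"      (no prefix: heuristic fallback)
    #     -> OpenAIClient("gpt-4o", prompts_dir, "http://localhost:8000/v1")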