diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py
index 624f08d..3eed399 100644
--- a/ai_diplomacy/clients.py
+++ b/ai_diplomacy/clients.py
@@ -13,6 +13,9 @@
 from dotenv import load_dotenv
 from openai import AsyncOpenAI
 from openai import AsyncOpenAI as AsyncDeepSeekOpenAI  # Alias for clarity
 from anthropic import AsyncAnthropic
+import asyncio
+import requests
+from typing import Tuple  # used by _split_model_and_base below
 import google.generativeai as genai
 from together import AsyncTogether
@@ -800,12 +802,29 @@ class BaseModelClient:
 
 
 class OpenAIClient(BaseModelClient):
     """
-    For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
+    Async client for OpenAI-compatible chat-completion endpoints.
+    Accepts an optional base_url override.
     """
-    def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
+    def __init__(
+        self,
+        model_name: str,
+        prompts_dir: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
         super().__init__(model_name, prompts_dir=prompts_dir)
-        self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url='https://gadgets-become-throughout-kenneth.trycloudflare.com/v1')
+
+        # Constructor arg > OPENAI_BASE_URL env var > official endpoint
+        self.base_url = (
+            base_url
+            or os.environ.get("OPENAI_BASE_URL")
+            or "https://api.openai.com/v1"
+        )
+
+        self.client = AsyncOpenAI(
+            api_key=os.environ.get("OPENAI_API_KEY"),
+            base_url=self.base_url,
+        )
 
     async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
         # Updated to new API format
@@ -1224,44 +1243,188 @@ class TogetherAIClient(BaseModelClient):
             return f"Error: Unexpected error - {str(e)}"  # Return a string with error info
 
 
+class RequestsOpenAIClient(BaseModelClient):
+    """
+    Synchronous `requests`-based client for any OpenAI-compatible API.
+    The blocking HTTP call runs on the event loop's default executor,
+    so it can sit beside the async SDK clients without changing call-sites.
+ """ + + def __init__( + self, + model_name: str, + prompts_dir: Optional[str] = None, + base_url: Optional[str] = None, + ): + super().__init__(model_name, prompts_dir=prompts_dir) + + self.api_key = os.environ.get("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OPENAI_API_KEY environment variable is required") + + # Constructor arg > env var > default official endpoint + self.base_url = ( + (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1") + .rstrip("/") + ) + + # Pre-build the endpoint used by Chat Completions + self.endpoint = f"{self.base_url}/chat/completions" + + # ------------------------------------------------------------------ # + # internal synchronous helper + # ------------------------------------------------------------------ # + def _post_sync(self, payload: dict) -> dict: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + r = requests.post(self.endpoint, headers=headers, json=payload, timeout=60) + r.raise_for_status() + return r.json() + + # ------------------------------------------------------------------ # + # async public API + # ------------------------------------------------------------------ # + async def generate_response( + self, + prompt: str, + temperature: float = 0.0, + inject_random_seed: bool = True, + ) -> str: + system_prompt_content = self.system_prompt + if inject_random_seed: + system_prompt_content = f"{generate_random_seed()}\n\n{system_prompt_content}" + + # Standard OpenAI JSON payload + payload = { + "model": self.model_name, + "messages": [ + {"role": "system", "content": system_prompt_content}, + {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"}, + ], + "temperature": temperature, + "max_tokens": self.max_tokens, + } + + # Run blocking call in threadpool + loop = asyncio.get_running_loop() + try: + response_data = await loop.run_in_executor(None, self._post_sync, payload) + except Exception as e: + logger.error(f"[{self.model_name}] HTTP error: {e}", exc_info=True) + return "" + + # Extract assistant content + try: + return ( + response_data["choices"][0]["message"]["content"].strip() + if response_data.get("choices") + else "" + ) + except (KeyError, IndexError, TypeError) as e: + logger.error(f"[{self.model_name}] Bad response format: {e}") + return "" + + ############################################################################## # 3) Factory to Load Model Client ############################################################################## +def _split_model_and_base(spec: str) -> Tuple[str, Optional[str]]: + """ + Accepts strings like 'gpt-4o' -> ('gpt-4o', None) + 'gpt-4o@http://foo' -> ('gpt-4o', 'http://foo') + """ + if "@" in spec: + model, base = spec.split("@", 1) + return model, base.rstrip("/") + return spec, None def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient: """ - Returns the appropriate LLM client for a given model_id string. - - Args: - model_id: The model identifier - prompts_dir: Optional path to the prompts directory. 
-
-    Example usage:
-        client = load_model_client("claude-3-5-sonnet-20241022")
+    Factory supporting:
+      • '<prefix>:<model>[@base_url]'
+      • '<model>[@base_url]'           (heuristic fallback)
+
+    Recognised prefixes → client class
+        openai, oai          → OpenAIClient
+        requests, req        → RequestsOpenAIClient
+        responses, oai-resp  → OpenAIResponsesClient
+        claude               → ClaudeClient
+        gemini               → GeminiClient
+        deepseek             → DeepSeekClient
+        openrouter, or       → OpenRouterClient
+        together             → TogetherAIClient
     """
-    # Basic pattern matching or direct mapping
-    lower_id = model_id.lower()
-
-    # Check for o3-pro model specifically - it needs the Responses API
+
+    # First: pull any '@base_url' suffix off the full string
+    model_part, explicit_base = _split_model_and_base(model_id)
+
+    # Next: check for an explicit '<prefix>:<model>' pattern
+    if ":" in model_part:
+        prefix, rest = model_part.split(":", 1)
+        prefix = prefix.lower()
+        model_name, alt_base = _split_model_and_base(rest)
+        base_url = explicit_base or alt_base  # explicit wins
+
+        if prefix in {"openai", "oai"}:
+            return OpenAIClient(model_name, prompts_dir, base_url)
+
+        if prefix in {"requests", "req"}:
+            return RequestsOpenAIClient(model_name, prompts_dir, base_url)
+
+        if prefix in {"responses", "oai-resp", "openai-responses"}:
+            return OpenAIResponsesClient(model_name, prompts_dir)
+
+        if prefix == "claude":
+            return ClaudeClient(model_name, prompts_dir)
+
+        if prefix == "gemini":
+            return GeminiClient(model_name, prompts_dir)
+
+        if prefix == "deepseek":
+            return DeepSeekClient(model_name, prompts_dir)
+
+        if prefix in {"openrouter", "or"}:
+            return OpenRouterClient(model_name, prompts_dir)
+
+        if prefix == "together":
+            return TogetherAIClient(model_name, prompts_dir)
+
+        logger.warning(f"[load_model_client] Unrecognised prefix '{prefix}', falling back.")
+        model_part, explicit_base = rest, base_url  # keep analysing without prefix
+
+    # ------------------------------------------------------------------
+    # Heuristic fallback (unchanged apart from base_url support)
+    # ------------------------------------------------------------------
+    model_name, base_url = _split_model_and_base(model_part)
+    base_url = explicit_base or base_url
+
+    lower_id = model_name.lower()
+
     if lower_id == "o3-pro":
-        return OpenAIResponsesClient(model_id, prompts_dir=prompts_dir)
-    # Check for OpenRouter first to handle prefixed models like openrouter-deepseek
-    elif model_id.startswith("together-"):
-        actual_model_name = model_id.split("together-", 1)[1]
-        logger.info(f"Loading TogetherAI client for model: {actual_model_name} (original ID: {model_id})")
-        return TogetherAIClient(actual_model_name, prompts_dir=prompts_dir)
-    elif "openrouter" in model_id.lower():  # or "/" in model_id:  # More general check for OpenRouterClient(model_id)
-        return OpenRouterClient(model_id, prompts_dir=prompts_dir)
-    elif "claude" in lower_id:
-        return ClaudeClient(model_id, prompts_dir=prompts_dir)
-    elif "gemini" in lower_id:
-        return GeminiClient(model_id, prompts_dir=prompts_dir)
-    elif "deepseek" in lower_id:
-        return DeepSeekClient(model_id, prompts_dir=prompts_dir)
-    else:
-        # Default to OpenAI (for models like o3-mini, gpt-4o, etc.)
-        return OpenAIClient(model_id, prompts_dir=prompts_dir)
+        return OpenAIResponsesClient(model_name, prompts_dir)
+
+    if model_name.startswith("together-"):
+        return TogetherAIClient(model_name.split("together-", 1)[1], prompts_dir)
+
+    if "openrouter" in lower_id:
+        return OpenRouterClient(model_name, prompts_dir)
+
+    if "claude" in lower_id:
+        return ClaudeClient(model_name, prompts_dir)
+
+    if "gemini" in lower_id:
+        return GeminiClient(model_name, prompts_dir)
+
+    if "deepseek" in lower_id:
+        return DeepSeekClient(model_name, prompts_dir)
+
+    # Default to OpenAI-compatible async client
+    return OpenAIClient(model_name, prompts_dir, base_url)
+
+
 ##############################################################################
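
Usage sketch for the spec format accepted by the new load_model_client (a sketch only; the model names and the localhost URL are illustrative placeholders, not endpoints this patch configures):

    from ai_diplomacy.clients import load_model_client

    # Explicit prefix routes directly to a client class.
    openai_client = load_model_client("openai:gpt-4o")

    # An '@base_url' suffix overrides the endpoint, e.g. a hypothetical
    # local OpenAI-compatible server.
    local_client = load_model_client("requests:my-model@http://localhost:8000/v1")

    # No prefix: the heuristic fallback still routes by substring match.
    claude_client = load_model_client("claude-3-5-sonnet-20241022")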
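Minimal async driver for RequestsOpenAIClient (assumes OPENAI_API_KEY is set; the model name and prompt are placeholders). Because generate_response pushes the blocking requests.post onto the default executor, the await below does not stall other coroutines on the event loop:

    import asyncio

    from ai_diplomacy.clients import RequestsOpenAIClient

    async def main() -> None:
        client = RequestsOpenAIClient("gpt-4o-mini")  # raises if OPENAI_API_KEY is unset
        reply = await client.generate_response("Reply with one word: ready?")
        print(reply)

    asyncio.run(main())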