allow specification of base url + api key for each model

sam-paech 2025-07-03 08:55:10 +10:00
parent f29ac9c1c5
commit 4fc1f370be
3 changed files with 177 additions and 135 deletions


@@ -273,6 +273,31 @@ python lm_game.py --run_dir results/game_run_004 \
python lm_game.py --run_dir results/game_run_005 --prompts_dir ./prompts/my_variants
```
### Setting `--models` (quick guide)
* Pass **one comma-separated list of model IDs**: either a single ID (applied to every power) or exactly seven, in this fixed order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY.
* **Model-ID syntax**
```
[prefix:]model[@base_url][#api_key]
```
* `prefix:` optional client selector (`openai`, `requests`, `claude`, `together`, …).
* `@base_url` routes requests to a proxy or alternative endpoint.
* `#api_key` supplies an inline API key (overrides environment variables); see the parsing sketch after the examples.
* **Examples**
```bash
# gpt-4o via OpenRouter for all powers:
--models "openrouter:gpt-4o"
# custom base URL + API key for Austria only:
--models "openai:llama-3.2-3b@http://localhost:8000#myapikey,openai:gpt-4o,openai:gpt-4o,openai:gpt-4o,openai:gpt-4o,openai:gpt-4o,openai:gpt-4o"
```
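The factory peels these pieces off in a fixed order: the inline key first (`#`), then the base URL (`@`), then the optional prefix (`:`). A minimal standalone sketch of that order (illustrative only, not the repo's exact function):
```python
def parse_spec(raw: str):
    """Split '[prefix:]model[@base_url][#api_key]' into its parts."""
    pre_hash, _, key = raw.strip().partition("#")  # inline API key, if any
    pre_at, _, base = pre_hash.partition("@")      # base URL, if any
    prefix, sep, model = pre_at.partition(":")     # optional client prefix
    if not sep:                                    # no ':' means no prefix given
        prefix, model = None, pre_at
    return prefix, model, base or None, key or None

parse_spec("openai:llama-3.2-3b@http://localhost:8000#myapikey")
# -> ('openai', 'llama-3.2-3b', 'http://localhost:8000', 'myapikey')
```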
### Running Batch Experiments with **`experiment_runner.py`**
`experiment_runner.py` is a lightweight orchestrator: it spins up many `lm_game.py` runs in parallel, gathers their artefacts under one *experiment directory*, and then executes the analysis modules you specify.


@@ -6,7 +6,7 @@ import logging
import ast # For literal_eval in JSON fallback parsing
import aiohttp # For direct HTTP requests to Responses API
from typing import List, Dict, Optional, Any, Tuple
from typing import List, Dict, Optional, Any, Tuple, NamedTuple
from dotenv import load_dotenv
# Use Async versions of clients
@@ -801,41 +801,42 @@ class BaseModelClient:
class OpenAIClient(BaseModelClient):
"""
Async client for OpenAI-compatible chat-completion endpoints.
Accepts an optional base_url override.
"""
"""Async client for OpenAI-compatible chat-completion endpoints."""
def __init__(
self,
model_name: str,
prompts_dir: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
):
super().__init__(model_name, prompts_dir=prompts_dir)
# Constructor arg > env var > default official endpoint
self.base_url = (
base_url
or os.environ.get("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
)
self.client = AsyncOpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
base_url=self.base_url,
)
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
# Updated to new API format
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
async def generate_response(
self,
prompt: str,
temperature: float = 0.0,
inject_random_seed: bool = True,
) -> str:
try:
# Append the call to action to the user's prompt
prompt_with_cta = prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
system_prompt_content = self.system_prompt
if inject_random_seed:
random_seed = generate_random_seed()
system_prompt_content = f"{random_seed}\n\n{self.system_prompt}"
system_prompt_content = (
f"{generate_random_seed()}\n\n{self.system_prompt}"
if inject_random_seed
else self.system_prompt
)
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
response = await self.client.chat.completions.create(
model=self.model_name,
@@ -846,21 +847,18 @@ class OpenAIClient(BaseModelClient):
temperature=temperature,
max_tokens=self.max_tokens,
)
if not response or not hasattr(response, "choices") or not response.choices:
logger.warning(
f"[{self.model_name}] Empty or invalid result in generate_response. Returning empty."
)
if not response or not response.choices:
logger.warning(f"[{self.model_name}] Empty result in generate_response.")
return ""
return response.choices[0].message.content.strip()
except json.JSONDecodeError as json_err:
logger.error(
f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
)
logger.error(f"[{self.model_name}] JSON decode error: {json_err}")
return ""
except Exception as e:
logger.error(
f"[{self.model_name}] Unexpected error in generate_response: {e}"
)
logger.error(f"[{self.model_name}] Unexpected error: {e}", exc_info=True)
return ""
@@ -1243,11 +1241,15 @@ class TogetherAIClient(BaseModelClient):
return f"Error: Unexpected error - {str(e)}" # Return a string with error info
##############################################################################
# RequestsOpenAIClient: sync `requests`, wrapped async (original + api_key)
##############################################################################
import requests, asyncio
class RequestsOpenAIClient(BaseModelClient):
"""
Synchronous `requests`-based client for any OpenAI-compatible API.
Wrapped in `async` via `asyncio.to_thread` so it can live beside
the aiohttp / SDK implementations without changing call-sites.
Wrapped in a thread executor (`run_in_executor`) so call-sites remain async.
"""
def __init__(
@@ -1255,25 +1257,23 @@ class RequestsOpenAIClient(BaseModelClient):
model_name: str,
prompts_dir: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
):
super().__init__(model_name, prompts_dir=prompts_dir)
self.api_key = os.environ.get("OPENAI_API_KEY")
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("OPENAI_API_KEY environment variable is required")
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
# Constructor arg > env var > default official endpoint
self.base_url = (
(base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1")
.rstrip("/")
)
base_url
or os.environ.get("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
).rstrip("/")
# Pre-build the endpoint used by Chat Completions
self.endpoint = f"{self.base_url}/chat/completions"
# ------------------------------------------------------------------ #
# internal synchronous helper
# ------------------------------------------------------------------ #
# ---------------- internal blocking helper ---------------- #
def _post_sync(self, payload: dict) -> dict:
headers = {
"Content-Type": "application/json",
@@ -1283,20 +1283,19 @@ class RequestsOpenAIClient(BaseModelClient):
r.raise_for_status()
return r.json()
# ------------------------------------------------------------------ #
# async public API
# ------------------------------------------------------------------ #
# ---------------- public async API ---------------- #
async def generate_response(
self,
prompt: str,
temperature: float = 0.0,
inject_random_seed: bool = True,
) -> str:
system_prompt_content = self.system_prompt
if inject_random_seed:
system_prompt_content = f"{generate_random_seed()}\n\n{system_prompt_content}"
system_prompt_content = (
f"{generate_random_seed()}\n\n{self.system_prompt}"
if inject_random_seed
else self.system_prompt
)
# Standard OpenAI JSON payload
payload = {
"model": self.model_name,
"messages": [
@@ -1307,121 +1306,137 @@ class RequestsOpenAIClient(BaseModelClient):
"max_tokens": self.max_tokens,
}
# Run blocking call in threadpool
loop = asyncio.get_running_loop()
try:
response_data = await loop.run_in_executor(None, self._post_sync, payload)
except Exception as e:
data = await loop.run_in_executor(None, self._post_sync, payload)
return data["choices"][0]["message"]["content"].strip()
except (KeyError, IndexError, TypeError) as e:
logger.error(f"[{self.model_name}] Bad response format: {e}", exc_info=True)
return ""
except requests.RequestException as e:
logger.error(f"[{self.model_name}] HTTP error: {e}", exc_info=True)
return ""
# Extract assistant content
try:
return (
response_data["choices"][0]["message"]["content"].strip()
if response_data.get("choices")
else ""
)
except (KeyError, IndexError, TypeError) as e:
logger.error(f"[{self.model_name}] Bad response format: {e}")
except Exception as e:
logger.error(f"[{self.model_name}] Unexpected error: {e}", exc_info=True)
return ""
##############################################################################
# 3) Factory to Load Model Client
##############################################################################
def _split_model_and_base(spec: str) -> Tuple[str, Optional[str]]:
class ModelSpec(NamedTuple):
prefix: Optional[str] # 'openai', 'requests', …
model: str # 'gpt-4o'
base: Optional[str] # 'https://proxy.foo'
key: Optional[str] # 'sk-…' (may be None)
def _parse_model_spec(raw: str) -> ModelSpec:
"""
Accepts strings like 'gpt-4o' -> ('gpt-4o', None)
'gpt-4o@http://foo' -> ('gpt-4o', 'http://foo')
Splits once on '#' (API key) and once on '@' (base URL). A leading
'<prefix>:' is optional. Nothing else is interpreted.
"""
if "@" in spec:
model, base = spec.split("@", 1)
return model, base.rstrip("/")
return spec, None
raw = raw.strip()
pre_hash, _, key_part = raw.partition("#")
pre_at, _, base_part = pre_hash.partition("@")
maybe_pref, sep, model_part = pre_at.partition(":")
if sep: # explicit prefix was present
prefix, model = maybe_pref.lower(), model_part
else:
prefix, model = None, maybe_pref
return ModelSpec(prefix, model, base_part or None, key_part or None)
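# Illustrative decompositions (key split off first, then base URL, then prefix):
#   'gpt-4o'                             -> ModelSpec(None, 'gpt-4o', None, None)
#   'gpt-4o@https://proxy'               -> ModelSpec(None, 'gpt-4o', 'https://proxy', None)
#   'openai:gpt-4o@https://proxy#sk-ABC' -> ModelSpec('openai', 'gpt-4o', 'https://proxy', 'sk-ABC')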
##############################################################################
# Factory load_model_client
##############################################################################
def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient:
"""
Factory supporting:
'<prefix>:<model>[@base_url]'
'<model>[@base_url]' (heuristic fallback)
Recognises strings like
gpt-4o
gpt-4o@https://proxy
gpt-4o#sk-123
openai:gpt-4o@https://proxy#sk-ABC
and returns the appropriate client.
Recognised prefixes                      client class
openai, oai                              OpenAIClient
requests, req                            RequestsOpenAIClient
responses, oai-resp, openai-responses    OpenAIResponsesClient
claude                                   ClaudeClient
gemini                                   GeminiClient
deepseek                                 DeepSeekClient
openrouter, or                           OpenRouterClient
together                                 TogetherAIClient
If a prefix is omitted the function falls back to the original
heuristic mapping exactly as before.
If an inline API key ('#…') is present it overrides environment variables.
"""
spec = _parse_model_spec(model_id)
# First: pull any '@base_url' suffix off the full string
model_part, explicit_base = _split_model_and_base(model_id)
# Inline key overrides env; otherwise fall back as usual *per client*
inline_key = spec.key
# Next: check for explicit '<prefix>:...' pattern
if ":" in model_part:
prefix, rest = model_part.split(":", 1)
prefix = prefix.lower()
model_name, alt_base = _split_model_and_base(rest)
base_url = explicit_base or alt_base # explicit wins
# ------------------------------------------------------------------ #
# 1. Explicit prefix path #
# ------------------------------------------------------------------ #
if spec.prefix:
match spec.prefix:
case "openai" | "oai":
return OpenAIClient(
model_name = spec.model,
prompts_dir = prompts_dir,
base_url = spec.base,
api_key = inline_key,
)
case "requests" | "req":
return RequestsOpenAIClient(
model_name = spec.model,
prompts_dir = prompts_dir,
base_url = spec.base,
api_key = inline_key,
)
case "responses" | "oai-resp" | "openai-responses":
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
case "claude":
return ClaudeClient(spec.model, prompts_dir)
case "gemini":
return GeminiClient(spec.model, prompts_dir)
case "deepseek":
return DeepSeekClient(spec.model, prompts_dir)
case "openrouter" | "or":
return OpenRouterClient(spec.model, prompts_dir)
case "together":
return TogetherAIClient(spec.model, prompts_dir)
case _:
logger.warning(f"[load_model_client] Unknown prefix '{spec.prefix}', falling back to heuristic path.")
if prefix in {"openai", "oai"}:
return OpenAIClient(model_name, prompts_dir, base_url)
if prefix in {"requests", "req"}:
return RequestsOpenAIClient(model_name, prompts_dir, base_url)
if prefix in {"responses", "oai-resp", "openai-responses"}:
return OpenAIResponsesClient(model_name, prompts_dir)
if prefix == "claude":
return ClaudeClient(model_name, prompts_dir)
if prefix == "gemini":
return GeminiClient(model_name, prompts_dir)
if prefix == "deepseek":
return DeepSeekClient(model_name, prompts_dir)
if prefix in {"openrouter", "or"}:
return OpenRouterClient(model_name, prompts_dir)
if prefix == "together":
return TogetherAIClient(model_name, prompts_dir)
logger.warning(f"[load_model_client] Unrecognised prefix '{prefix}', falling back.")
model_part, explicit_base = rest, base_url # keep analysing without prefix
# ------------------------------------------------------------------
# Heuristic fallback (unchanged apart from base_url support)
# ------------------------------------------------------------------
model_name, base_url = _split_model_and_base(model_part)
base_url = explicit_base or base_url
lower_id = model_name.lower()
# ------------------------------------------------------------------ #
# 2. Heuristic fallback path (identical to the original behaviour) #
# ------------------------------------------------------------------ #
lower_id = spec.model.lower()
if lower_id == "o3-pro":
return OpenAIResponsesClient(model_name, prompts_dir)
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
if model_name.startswith("together-"):
return TogetherAIClient(model_name.split("together-", 1)[1], prompts_dir)
if spec.model.startswith("together-"):
# e.g. "together-mixtral-8x7b"
return TogetherAIClient(spec.model.split("together-", 1)[1], prompts_dir)
if "openrouter" in lower_id:
return OpenRouterClient(model_name, prompts_dir)
return OpenRouterClient(spec.model, prompts_dir)
if "claude" in lower_id:
return ClaudeClient(model_name, prompts_dir)
return ClaudeClient(spec.model, prompts_dir)
if "gemini" in lower_id:
return GeminiClient(model_name, prompts_dir)
return GeminiClient(spec.model, prompts_dir)
if "deepseek" in lower_id:
return DeepSeekClient(model_name, prompts_dir)
return DeepSeekClient(spec.model, prompts_dir)
# Default: OpenAI-compatible async client
return OpenAIClient(
model_name = spec.model,
prompts_dir = prompts_dir,
base_url = spec.base,
api_key = inline_key,
)
# Default to OpenAI-compatible async client
return OpenAIClient(model_name, prompts_dir, base_url)
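Putting the pieces together, a spec string flows through `_parse_model_spec` into the matching client. A hedged usage sketch (placeholder URL and key; client classes as defined above):
```python
# Explicit prefix: 'requests:' selects RequestsOpenAIClient;
# '@' overrides the endpoint and '#' supplies the key inline.
client = load_model_client(
    "requests:llama-3.2-3b@http://localhost:8000#myapikey",
    prompts_dir="./prompts",
)

# No prefix: the heuristic path routes 'claude' model IDs to ClaudeClient.
fallback = load_model_client("claude-3-5-sonnet")
```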


@@ -296,9 +296,11 @@ async def initialize_new_game(
provided_models = [name.strip() for name in args.models.split(",")]
if len(provided_models) == len(powers_order):
game.power_model_map = dict(zip(powers_order, provided_models))
elif len(provided_models) == 1:
game.power_model_map = dict(zip(powers_order, provided_models * 7))
else:
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Using defaults.")
game.power_model_map = assign_models_to_powers()
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}.")
raise Exception("Invalid number of models: the --models list must contain either exactly 1 or exactly 7 comma-delimited models.")
else:
game.power_model_map = assign_models_to_powers()
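For clarity, the mapping behaviour this hunk implements, as a standalone sketch (the power order is assumed to match the README's fixed order):
```python
POWERS = ["AUSTRIA", "ENGLAND", "FRANCE", "GERMANY",
          "ITALY", "RUSSIA", "TURKEY"]

def build_power_model_map(models_arg: str) -> dict:
    provided = [m.strip() for m in models_arg.split(",")]
    if len(provided) == len(POWERS):    # one model per power, README order
        return dict(zip(POWERS, provided))
    if len(provided) == 1:              # same model for every power
        return dict(zip(POWERS, provided * len(POWERS)))
    raise ValueError("Pass exactly 1 or 7 comma-delimited models.")

build_power_model_map("openrouter:gpt-4o")  # every power -> 'openrouter:gpt-4o'
```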