mirror of https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00

allow specification of base url + api key for each model

This commit is contained in:
parent f29ac9c1c5
commit 4fc1f370be

3 changed files with 177 additions and 135 deletions
README.md (+25)
````diff
@@ -273,6 +273,31 @@ python lm_game.py --run_dir results/game_run_004 \
 python lm_game.py --run_dir results/game_run_005 --prompts_dir ./prompts/my_variants
 ```
 
+### Setting `--models` (quick guide)
+
+* Pass **one comma-separated list of up to seven model IDs** in this fixed order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY.
+
+* **Model-ID syntax**
+
+  ```
+  <prefix:>model[@base_url][#api_key]
+  ```
+
+  * `prefix:` – optional client (`openai`, `requests`, `claude`, `together`, …).
+  * `@base_url` – hit a proxy / alt endpoint.
+  * `#api_key` – inline key (overrides env vars).
+
+* **Examples**
+
+  ```bash
+  # gpt-4o on openrouter for all powers:
+  --models "openrouter:gpt-4o"
+
+  # custom URL+apikey for Austria only:
+  --models "openai:llama-3.2-3b@http://localhost:8000#myapikey,openai:gpt-4o,openai:gpt-4o,openai:gpt-4o,openai:gpt-4o,openai:gpt-4o,openai:gpt-4o"
+  ```
+
+
 ### Running Batch Experiments with **`experiment_runner.py`**
 
 `experiment_runner.py` is a lightweight orchestrator: it spins up many `lm_game.py` runs in parallel, gathers their artefacts under one *experiment directory*, and then executes the analysis modules you specify.
````
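The grammar above splits cleanly with three `str.partition` calls, in the order inline key, then base URL, then prefix. A minimal standalone sketch of that decomposition, mirroring the `_parse_model_spec` helper this commit adds further down (`split_model_id` is an illustrative name, not the repo's):

```python
def split_model_id(raw: str):
    """Decompose '<prefix:>model[@base_url][#api_key]' into its parts."""
    rest, _, api_key = raw.partition("#")     # inline key, if any, comes last
    rest, _, base_url = rest.partition("@")   # then an optional base URL
    prefix, sep, model = rest.partition(":")  # leading client prefix is optional
    if not sep:
        prefix, model = None, rest
    return prefix, model, base_url or None, api_key or None

# The README's Austria example decomposes like this:
assert split_model_id("openai:llama-3.2-3b@http://localhost:8000#myapikey") == (
    "openai", "llama-3.2-3b", "http://localhost:8000", "myapikey"
)
assert split_model_id("gpt-4o") == (None, "gpt-4o", None, None)
```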
```diff
@@ -6,7 +6,7 @@ import logging
 import ast  # For literal_eval in JSON fallback parsing
 import aiohttp  # For direct HTTP requests to Responses API
 
-from typing import List, Dict, Optional, Any, Tuple
+from typing import List, Dict, Optional, Any, Tuple, NamedTuple
 from dotenv import load_dotenv
 
 # Use Async versions of clients
```
```diff
@@ -801,66 +801,64 @@ class BaseModelClient:
 
 
 class OpenAIClient(BaseModelClient):
-    """
-    Async client for OpenAI-compatible chat-completion endpoints.
-    Accepts an optional base_url override.
-    """
+    """Async client for OpenAI-compatible chat-completion endpoints."""
 
     def __init__(
         self,
         model_name: str,
         prompts_dir: Optional[str] = None,
         base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
     ):
         super().__init__(model_name, prompts_dir=prompts_dir)
 
         # Allow env var or constructor arg
         self.base_url = (
             base_url
             or os.environ.get("OPENAI_BASE_URL")
             or "https://api.openai.com/v1"
         )
 
-        self.client = AsyncOpenAI(
-            api_key=os.environ.get("OPENAI_API_KEY"),
-            base_url=self.base_url,
-        )
+        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
+        if not self.api_key:
+            raise ValueError("OPENAI_API_KEY missing and no inline key provided")
+
+        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
-    async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
-        # Updated to new API format
+    async def generate_response(
+        self,
+        prompt: str,
+        temperature: float = 0.0,
+        inject_random_seed: bool = True,
+    ) -> str:
         try:
-            # Append the call to action to the user's prompt
-            prompt_with_cta = prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
-
-            system_prompt_content = self.system_prompt
-            if inject_random_seed:
-                random_seed = generate_random_seed()
-                system_prompt_content = f"{random_seed}\n\n{self.system_prompt}"
+            system_prompt_content = (
+                f"{generate_random_seed()}\n\n{self.system_prompt}"
+                if inject_random_seed
+                else self.system_prompt
+            )
+            prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
 
             response = await self.client.chat.completions.create(
                 model=self.model_name,
                 messages=[
                     {"role": "system", "content": system_prompt_content},
                     {"role": "user", "content": prompt_with_cta},
                 ],
                 temperature=temperature,
                 max_tokens=self.max_tokens,
             )
-            if not response or not hasattr(response, "choices") or not response.choices:
-                logger.warning(
-                    f"[{self.model_name}] Empty or invalid result in generate_response. Returning empty."
-                )
+
+            if not response or not response.choices:
+                logger.warning(f"[{self.model_name}] Empty result in generate_response.")
                 return ""
+
             return response.choices[0].message.content.strip()
+
         except json.JSONDecodeError as json_err:
-            logger.error(
-                f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
-            )
+            logger.error(f"[{self.model_name}] JSON decode error: {json_err}")
             return ""
         except Exception as e:
-            logger.error(
-                f"[{self.model_name}] Unexpected error in generate_response: {e}"
-            )
+            logger.error(f"[{self.model_name}] Unexpected error: {e}", exc_info=True)
             return ""
```
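With `api_key` now a constructor argument, one power can point at a local proxy while the others keep the official endpoint. A minimal usage sketch, assuming the class is importable from a module named `clients` (the diff does not show the file name) and that an OpenAI-compatible server is running locally:

```python
import asyncio

from clients import OpenAIClient  # hypothetical import path

async def main() -> None:
    client = OpenAIClient(
        model_name="llama-3.2-3b",
        base_url="http://localhost:8000",  # proxy / alt endpoint, as in the README example
        api_key="myapikey",                # inline key; overrides OPENAI_API_KEY
    )
    # Illustrative prompt only; generate_response prepends the system prompt
    # (plus an optional random seed) and appends the call to action.
    reply = await client.generate_response("Propose Austria's Spring 1901 orders.")
    print(reply)

asyncio.run(main())
```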
```diff
@@ -1243,11 +1241,15 @@ class TogetherAIClient(BaseModelClient):
         return f"Error: Unexpected error - {str(e)}"  # Return a string with error info
 
 
+##############################################################################
+# RequestsOpenAIClient – sync requests, wrapped async (original + api_key)
+##############################################################################
+import requests, asyncio
+
 class RequestsOpenAIClient(BaseModelClient):
     """
     Synchronous `requests`-based client for any OpenAI-compatible API.
-    Wrapped in `async` via `asyncio.to_thread` so it can live beside
-    the aiohttp / SDK implementations without changing call-sites.
+    Wrapped in `asyncio.to_thread` so call-sites remain async.
     """
 
     def __init__(
```
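The docstring mentions `asyncio.to_thread`, while the method bodies below keep the older `loop.run_in_executor(None, ...)` spelling; on Python 3.9+ the two are equivalent ways to push a blocking `requests` call onto the default thread pool. A standalone sketch of the pattern (not the repo's code):

```python
import asyncio
import requests

def post_sync(url: str, payload: dict) -> dict:
    # Blocking I/O; must never run directly on the event-loop thread.
    r = requests.post(url, json=payload, timeout=60)
    r.raise_for_status()
    return r.json()

async def post_async(url: str, payload: dict) -> dict:
    # Equivalent to:
    #   await asyncio.get_running_loop().run_in_executor(
    #       None, functools.partial(post_sync, url, payload))
    return await asyncio.to_thread(post_sync, url, payload)
```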
```diff
@@ -1255,173 +1257,186 @@ class RequestsOpenAIClient(BaseModelClient):
         model_name: str,
         prompts_dir: Optional[str] = None,
         base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
     ):
         super().__init__(model_name, prompts_dir=prompts_dir)
 
-        self.api_key = os.environ.get("OPENAI_API_KEY")
+        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
         if not self.api_key:
-            raise ValueError("OPENAI_API_KEY environment variable is required")
+            raise ValueError("OPENAI_API_KEY missing and no inline key provided")
 
         # Constructor arg > env var > default official endpoint
         self.base_url = (
-            (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1")
-            .rstrip("/")
-        )
+            base_url
+            or os.environ.get("OPENAI_BASE_URL")
+            or "https://api.openai.com/v1"
+        ).rstrip("/")
 
         # Pre-build the endpoint used by Chat Completions
         self.endpoint = f"{self.base_url}/chat/completions"
 
-    # ------------------------------------------------------------------ #
-    # internal synchronous helper
-    # ------------------------------------------------------------------ #
+    # ---------------- internal blocking helper ---------------- #
     def _post_sync(self, payload: dict) -> dict:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.api_key}",
         }
         r = requests.post(self.endpoint, headers=headers, json=payload, timeout=60)
         r.raise_for_status()
         return r.json()
 
-    # ------------------------------------------------------------------ #
-    # async public API
-    # ------------------------------------------------------------------ #
+    # ---------------- public async API ---------------- #
     async def generate_response(
         self,
         prompt: str,
         temperature: float = 0.0,
         inject_random_seed: bool = True,
     ) -> str:
-        system_prompt_content = self.system_prompt
-        if inject_random_seed:
-            system_prompt_content = f"{generate_random_seed()}\n\n{system_prompt_content}"
+        system_prompt_content = (
+            f"{generate_random_seed()}\n\n{self.system_prompt}"
+            if inject_random_seed
+            else self.system_prompt
+        )
 
         # Standard OpenAI JSON payload
         payload = {
             "model": self.model_name,
             "messages": [
                 {"role": "system", "content": system_prompt_content},
                 {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
             ],
             "temperature": temperature,
             "max_tokens": self.max_tokens,
         }
 
         # Run blocking call in threadpool
         loop = asyncio.get_running_loop()
         try:
-            response_data = await loop.run_in_executor(None, self._post_sync, payload)
-        except Exception as e:
+            data = await loop.run_in_executor(None, self._post_sync, payload)
+            return data["choices"][0]["message"]["content"].strip()
+        except (KeyError, IndexError, TypeError) as e:
+            logger.error(f"[{self.model_name}] Bad response format: {e}", exc_info=True)
+            return ""
+        except requests.RequestException as e:
             logger.error(f"[{self.model_name}] HTTP error: {e}", exc_info=True)
             return ""
-
-        # Extract assistant content
-        try:
-            return (
-                response_data["choices"][0]["message"]["content"].strip()
-                if response_data.get("choices")
-                else ""
-            )
-        except (KeyError, IndexError, TypeError) as e:
-            logger.error(f"[{self.model_name}] Bad response format: {e}")
+        except Exception as e:
+            logger.error(f"[{self.model_name}] Unexpected error: {e}", exc_info=True)
             return ""
 
 
 ##############################################################################
 # 3) Factory to Load Model Client
 ##############################################################################
-def _split_model_and_base(spec: str) -> Tuple[str, Optional[str]]:
-    """
-    Accepts strings like 'gpt-4o'            -> ('gpt-4o', None)
-                         'gpt-4o@http://foo' -> ('gpt-4o', 'http://foo')
-    """
-    if "@" in spec:
-        model, base = spec.split("@", 1)
-        return model, base.rstrip("/")
-    return spec, None
+class ModelSpec(NamedTuple):
+    prefix: Optional[str]  # 'openai', 'requests', …
+    model: str             # 'gpt-4o'
+    base: Optional[str]    # 'https://proxy.foo'
+    key: Optional[str]     # 'sk-…' (may be None)
+
+
+def _parse_model_spec(raw: str) -> ModelSpec:
+    """
+    Splits once on '#' (API key) and once on '@' (base URL). A leading
+    '<prefix>:' is optional. Nothing else is interpreted.
+    """
+    raw = raw.strip()
+
+    pre_hash, _, key_part = raw.partition("#")
+    pre_at, _, base_part = pre_hash.partition("@")
+
+    maybe_pref, sep, model_part = pre_at.partition(":")
+    if sep:  # explicit prefix was present
+        prefix, model = maybe_pref.lower(), model_part
+    else:
+        prefix, model = None, maybe_pref
+
+    return ModelSpec(prefix, model, base_part or None, key_part or None)
+
+
+##############################################################################
+# Factory – load_model_client
+##############################################################################
 def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseModelClient:
     """
-    Factory supporting:
-      • '<prefix>:<model>[@base_url]'
-      • '<model>[@base_url]'            (heuristic fallback)
+    Recognises strings like
+        gpt-4o
+        gpt-4o@https://proxy
+        gpt-4o#sk-123
+        openai:gpt-4o@https://proxy#sk-ABC
+    and returns the appropriate client.
+
+    Recognised prefixes → client class
+        openai, oai          → OpenAIClient
+        requests, req        → RequestsOpenAIClient
+        responses, oai-resp  → OpenAIResponsesClient
+        claude               → ClaudeClient
+        gemini               → GeminiClient
+        deepseek             → DeepSeekClient
+        openrouter, or       → OpenRouterClient
+        together             → TogetherAIClient
+
+    • If a prefix is omitted the function falls back to the original
+      heuristic mapping exactly as before.
+    • If an inline API-key ('#…') is present it overrides environment vars.
     """
-    # First: pull any '@base_url' suffix off the full string
-    model_part, explicit_base = _split_model_and_base(model_id)
+    spec = _parse_model_spec(model_id)
 
-    # Next: check for explicit '<prefix>:...' pattern
-    if ":" in model_part:
-        prefix, rest = model_part.split(":", 1)
-        prefix = prefix.lower()
-        model_name, alt_base = _split_model_and_base(rest)
-        base_url = explicit_base or alt_base  # explicit wins
-
-        if prefix in {"openai", "oai"}:
-            return OpenAIClient(model_name, prompts_dir, base_url)
-        if prefix in {"requests", "req"}:
-            return RequestsOpenAIClient(model_name, prompts_dir, base_url)
-        if prefix in {"responses", "oai-resp", "openai-responses"}:
-            return OpenAIResponsesClient(model_name, prompts_dir)
-        if prefix == "claude":
-            return ClaudeClient(model_name, prompts_dir)
-        if prefix == "gemini":
-            return GeminiClient(model_name, prompts_dir)
-        if prefix == "deepseek":
-            return DeepSeekClient(model_name, prompts_dir)
-        if prefix in {"openrouter", "or"}:
-            return OpenRouterClient(model_name, prompts_dir)
-        if prefix == "together":
-            return TogetherAIClient(model_name, prompts_dir)
-
-        logger.warning(f"[load_model_client] Unrecognised prefix '{prefix}', falling back.")
-        model_part, explicit_base = rest, base_url  # keep analysing without prefix
+    # Inline key overrides env; otherwise fall back as usual *per client*
+    inline_key = spec.key
 
-    # ------------------------------------------------------------------
-    # Heuristic fallback (unchanged apart from base_url support)
-    # ------------------------------------------------------------------
-    model_name, base_url = _split_model_and_base(model_part)
-    base_url = explicit_base or base_url
-
-    lower_id = model_name.lower()
+    # ------------------------------------------------------------------ #
+    # 1. Explicit prefix path                                            #
+    # ------------------------------------------------------------------ #
+    if spec.prefix:
+        match spec.prefix:
+            case "openai" | "oai":
+                return OpenAIClient(
+                    model_name=spec.model,
+                    prompts_dir=prompts_dir,
+                    base_url=spec.base,
+                    api_key=inline_key,
+                )
+            case "requests" | "req":
+                return RequestsOpenAIClient(
+                    model_name=spec.model,
+                    prompts_dir=prompts_dir,
+                    base_url=spec.base,
+                    api_key=inline_key,
+                )
+            case "responses" | "oai-resp" | "openai-responses":
+                return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
+            case "claude":
+                return ClaudeClient(spec.model, prompts_dir)
+            case "gemini":
+                return GeminiClient(spec.model, prompts_dir)
+            case "deepseek":
+                return DeepSeekClient(spec.model, prompts_dir)
+            case "openrouter" | "or":
+                return OpenRouterClient(spec.model, prompts_dir)
+            case "together":
+                return TogetherAIClient(spec.model, prompts_dir)
+            case _:
+                logger.warning(f"[load_model_client] Unknown prefix '{spec.prefix}', falling back to heuristic path.")
 
+    # ------------------------------------------------------------------ #
+    # 2. Heuristic fallback path (identical to the original behaviour)   #
+    # ------------------------------------------------------------------ #
+    lower_id = spec.model.lower()
+
     if lower_id == "o3-pro":
-        return OpenAIResponsesClient(model_name, prompts_dir)
+        return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
 
-    if model_name.startswith("together-"):
-        return TogetherAIClient(model_name.split("together-", 1)[1], prompts_dir)
+    if spec.model.startswith("together-"):
+        # e.g. "together-mixtral-8x7b"
+        return TogetherAIClient(spec.model.split("together-", 1)[1], prompts_dir)
 
     if "openrouter" in lower_id:
-        return OpenRouterClient(model_name, prompts_dir)
+        return OpenRouterClient(spec.model, prompts_dir)
 
     if "claude" in lower_id:
-        return ClaudeClient(model_name, prompts_dir)
+        return ClaudeClient(spec.model, prompts_dir)
 
     if "gemini" in lower_id:
-        return GeminiClient(model_name, prompts_dir)
+        return GeminiClient(spec.model, prompts_dir)
 
     if "deepseek" in lower_id:
-        return DeepSeekClient(model_name, prompts_dir)
+        return DeepSeekClient(spec.model, prompts_dir)
 
-    # Default to OpenAI-compatible async client
-    return OpenAIClient(model_name, prompts_dir, base_url)
+    # Default: OpenAI-compatible async client
+    return OpenAIClient(
+        model_name=spec.model,
+        prompts_dir=prompts_dir,
+        base_url=spec.base,
+        api_key=inline_key,
+    )
```
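Putting the pieces together: the factory tries the explicit prefix first, then falls through to the name heuristics, and inline credentials beat the environment. A hedged sketch of the expected dispatch (import path assumed, as before; the clients are only constructed, not called):

```python
from clients import load_model_client, RequestsOpenAIClient  # hypothetical path

# Explicit prefix → requests-based client; '#myapikey' overrides OPENAI_API_KEY,
# and the base URL is normalised into the chat-completions endpoint.
client = load_model_client("requests:llama-3.2-3b@http://localhost:8000#myapikey")
assert isinstance(client, RequestsOpenAIClient)
assert client.endpoint == "http://localhost:8000/chat/completions"

# No prefix → the heuristic path keys off the model name instead.
claude = load_model_client("claude-3-5-sonnet")  # → ClaudeClient (name is illustrative)
```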
```diff
@@ -296,9 +296,11 @@ async def initialize_new_game(
         provided_models = [name.strip() for name in args.models.split(",")]
         if len(provided_models) == len(powers_order):
             game.power_model_map = dict(zip(powers_order, provided_models))
+        elif len(provided_models) == 1:
+            game.power_model_map = dict(zip(powers_order, provided_models * 7))
         else:
-            logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Using defaults.")
-            game.power_model_map = assign_models_to_powers()
+            logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}.")
+            raise Exception("Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited.")
     else:
         game.power_model_map = assign_models_to_powers()
```
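The new behaviour accepts exactly one model ID (replicated across all seven powers) or exactly seven; any other count now fails fast instead of silently using defaults. A small standalone sketch of that rule:

```python
POWERS = ["AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"]

def map_models(models_arg: str) -> dict:
    """Mirror of the --models handling above: exactly 1 or 7 entries."""
    models = [m.strip() for m in models_arg.split(",")]
    if len(models) == len(POWERS):
        return dict(zip(POWERS, models))
    if len(models) == 1:
        return dict(zip(POWERS, models * len(POWERS)))
    raise ValueError("Models list must be either exactly 1 or 7 models, comma delimited.")

# One ID fans out to every power:
assert map_models("openrouter:gpt-4o")["TURKEY"] == "openrouter:gpt-4o"
```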