Diplomacy trainer env (#227)

* minimal implementation, simplified challenge registry

* need game save logic

* fixed challenge gen, works with local test

* updated challenge gen with wider ranges, working with local script

* runs working correctly, wandb stats look ok

* linting

* Add diplomacy environment with AI_Diplomacy submodule

- Add diplomacy_env_minimal.py for diplomacy game environment
- Add atropos_client_minimal.py for client interface
- Add diplomacy_local_server.py for local game server
- Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy
- Fix import ordering and remove unused imports

* test file working, moving to cluster to test training

* updated gitignore

* removed logs

* minor fixes, training running now

* re-added proxy registration and queue system

* linting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed queue game-id bug, refactored

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleaned up configs & made OpenRouter models easy to use

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* linting

* Remove duplicate dependencies from diplomacy requirements.txt

Only keep AI_Diplomacy-specific dependencies that aren't already in the main project

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
shannonsands committed 2025-08-12 09:02:16 +10:00 via GitHub
parent 4fe67e698d · commit 46f0602227
13 changed files with 1317 additions and 4 deletions

@@ -0,0 +1 @@
Subproject commit 70d4ae2fe029ed38bf823ce70f5d45a7bcc7afbc


@@ -0,0 +1,54 @@
# Minimal Diplomacy Environment
A simplified Diplomacy RL training environment for Atropos that integrates with AI_Diplomacy.
## Overview
This minimal implementation provides:
- Basic game integration via AI_Diplomacy submodule
- Parallel rollouts with configurable group_size
- LLM request interception through AtroposClient proxy
- Simple supply-center-based scoring
- No complex features (no GRPO, memory systems, or advanced scoring)
## Architecture
```
Atropos Policy Server
        ↕
AtroposClientMinimal (proxy)
        ↕
AI_Diplomacy Game Engine
        ↕
Game Execution
```
## Quick Start
1. Install dependencies:
```bash
pip install -r requirements.txt
cd AI_Diplomacy
pip install -e .
```
2. Start your Atropos policy server on port 8000
3. Run the environment:
```bash
python diplomacy_env_minimal.py serve
```
## Configuration
Key settings in `DiplomacyEnvMinimalConfig`:
- `max_game_turns`: Maximum number of game turns to play, mapped to AI_Diplomacy's `--max_year` (default: 10)
- `training_power`: Which power the RL agent controls (default: "FRANCE")
- `group_size`: Number of parallel games per trajectory (default: 4)
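For instance, overriding these in code might look like the sketch below (field names come from `DiplomacyEnvMinimalConfig` in this commit; the flat import path and the omitted defaults are assumptions):
```python
from diplomacy_env_minimal import DiplomacyEnvMinimalConfig

# A minimal override sketch; any field not listed keeps its default.
config = DiplomacyEnvMinimalConfig(
    max_game_turns=5,         # shorter games for quick iteration
    training_power="FRANCE",  # the power the RL agent controls
    group_size=4,             # parallel games per trajectory group
)
```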
## How It Works
1. **Parallel Rollouts**: Each training step runs `group_size` games from a shared base seed (each game offsets the seed by its trajectory index)
2. **LLM Interception**: AtroposClientMinimal intercepts all LLM calls from AI_Diplomacy
3. **Trajectory Collection**: Game interactions are collected and scored
4. **Group Return**: All scored trajectories are returned as a `ScoredDataGroup` for training
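The supply-center scoring reduces to a linear combination (mirroring `_calculate_score` in `diplomacy_env_minimal.py`; the numbers below use the config defaults):
```python
# score = (final_centers - starting_centers) * supply_center_weight
#       + survival_bonus * turns_played
#       + (win_bonus if winner == training_power else 0.0)
# Worked example: FRANCE ends with 5 centers (started with 3)
# after 10 turns without winning:
score = (5 - 3) * 1.0 + 0.1 * 10 + 0.0  # = 3.0
```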


@@ -0,0 +1,7 @@
"""
Minimal Diplomacy Environment for Atropos RL Training
"""
from .diplomacy_env_minimal import DiplomacyEnvMinimal, DiplomacyEnvMinimalConfig
__all__ = ["DiplomacyEnvMinimal", "DiplomacyEnvMinimalConfig"]


@@ -0,0 +1,268 @@
"""
This is a queue-based proxy that:
- Intercepts LLM requests from AI_Diplomacy
- Puts them on a queue for the environment to process
- Waits for responses from the environment
- Returns responses to AI_Diplomacy
"""
import asyncio
import contextvars
import json
import logging
import os
import sys
import uuid
from typing import Dict, List, Optional
from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy import (
clients,
)
from environments.game_environments.diplomacy_environment.queue_manager import (
PolicyRequest,
QueueManager,
get_queue_manager,
)
sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))
from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy.clients import ( # noqa: E402
BaseModelClient,
)
logger = logging.getLogger(__name__)
current_game_context = contextvars.ContextVar("current_game_id", default=None)
_game_interactions: Dict[str, List[Dict]] = {}
class AtroposClientMinimal(BaseModelClient):
"""
Queue-based proxy client that forwards LLM requests through queues.
"""
def __init__(
self,
model_name: str,
queue_manager: Optional[QueueManager] = None,
):
super().__init__(model_name)
self.game_id = current_game_context.get()
if not self.game_id:
raise ValueError("AtroposClientMinimal created without game context set")
self.queue_manager = queue_manager or get_queue_manager()
self.interactions: List[Dict] = []
self.current_power: Optional[str] = None
self.current_phase: Optional[str] = None
logger.info(
f"Initialized AtroposClientMinimal for {model_name} in game {self.game_id}"
)
async def generate_response(self, prompt: str, temperature: float = 0.0) -> str:
"""
Put request on queue and wait for response from environment.
This is the main method AI_Diplomacy calls for all LLM interactions.
"""
task_type = self._infer_task_type(prompt)
power = self._extract_power(prompt)
phase = self._extract_phase(prompt)
if power:
self.current_power = power
if phase:
self.current_phase = phase
logger.debug(f"Generating response for {self.current_power}: {task_type}")
try:
request_id = str(uuid.uuid4())
request = PolicyRequest(
request_id=request_id,
game_id=self.game_id,
power=self.current_power or "UNKNOWN",
phase=self.current_phase or "UNKNOWN",
prompt=prompt,
temperature=temperature,
trajectory=self.interactions.copy(),
)
await self.queue_manager.put_request(self.game_id, request)
logger.debug(f"Put request {request_id} on queue for game {self.game_id}")
response = await self.queue_manager.get_response(self.game_id)
if response.request_id != request_id:
logger.warning(
f"Response ID mismatch: expected {request_id}, got {response.request_id}"
)
response_text = response.response
# Track interaction
interaction = {
"power": self.current_power,
"phase": self.current_phase,
"task_type": task_type,
"prompt": prompt,
"response": response_text,
"metadata": response.metadata, # Store any additional info from environment
}
self.interactions.append(interaction)
if self.game_id not in _game_interactions:
_game_interactions[self.game_id] = []
_game_interactions[self.game_id].append(interaction)
return response_text
except asyncio.TimeoutError:
logger.error("Timeout waiting for response from environment")
return self._generate_fallback_response(prompt)
except Exception as e:
logger.error(f"Error generating response: {e}")
return self._generate_fallback_response(prompt)
def _infer_task_type(self, prompt: str) -> str:
"""Infer the type of task from the prompt."""
prompt_lower = prompt.lower()
if "orders" in prompt_lower or "submit" in prompt_lower:
return "orders"
elif "message" in prompt_lower or "negotiate" in prompt_lower:
return "negotiation"
elif "plan" in prompt_lower or "strategy" in prompt_lower:
return "planning"
else:
return "general"
def _extract_power(self, prompt: str) -> Optional[str]:
"""Extract power name from prompt if mentioned."""
for power in [
"AUSTRIA",
"ENGLAND",
"FRANCE",
"GERMANY",
"ITALY",
"RUSSIA",
"TURKEY",
]:
if power in prompt.upper():
return power
return None
def _extract_phase(self, prompt: str) -> Optional[str]:
"""Extract game phase from prompt if mentioned."""
import re
phase_match = re.search(r"[SF]\d{4}[MRB]", prompt)
if phase_match:
return phase_match.group()
verbose_match = re.search(r"(Spring|Fall) \d{4}", prompt)
if verbose_match:
return verbose_match.group()
return None
def _generate_fallback_response(self, prompt: str) -> str:
"""Generate a simple fallback response if there's an issue."""
task_type = self._infer_task_type(prompt)
if task_type == "orders":
return json.dumps(
{
"orders": {},
"explanations": {"general": "Fallback - no server connected"},
}
)
elif task_type == "negotiation":
return json.dumps(
{
"messages": [],
"explanations": {"general": "Fallback - no server connected"},
}
)
else:
return "Fallback response - server not available"
def get_interactions(self) -> List[Dict]:
"""Get all tracked interactions for trajectory collection."""
return self.interactions
def clear_interactions(self):
"""Clear tracked interactions for a new game."""
self.interactions = []
self.current_power = None
self.current_phase = None
def get_game_interactions(game_id: str) -> List[Dict]:
"""Get all interactions for a specific game."""
return _game_interactions.get(game_id, [])
def clear_game_interactions(game_id: str):
"""Clear interactions for a specific game."""
if game_id in _game_interactions:
del _game_interactions[game_id]
def register_atropos_models_globally(queue_manager: Optional[QueueManager] = None):
"""
Register AtroposClientMinimal with AI_Diplomacy's model loading system globally.
This should be called ONCE during environment setup.
Args:
queue_manager: Optional queue manager (uses global if not provided)
"""
if hasattr(clients, "_atropos_registered"):
logger.info("AtroposClientMinimal already registered globally")
return
clients._original_load_model_client = clients.load_model_client
clients._atropos_queue_manager = queue_manager or get_queue_manager()
def load_model_client_with_atropos(
model_id: str, prompts_dir: Optional[str] = None
) -> BaseModelClient:
if model_id.startswith("atropos-"):
logger.info(f"Creating context-aware AtroposClientMinimal for {model_id}")
return AtroposClientMinimal(model_id, clients._atropos_queue_manager)
else:
logger.info(f"Falling back to original loader for {model_id}")
return clients._original_load_model_client(model_id, prompts_dir)
clients.load_model_client = load_model_client_with_atropos
clients._atropos_registered = True
logger.info("Registered AtroposClientMinimal globally with AI_Diplomacy")
if __name__ == "__main__":

    async def test_client():
        # Standalone smoke test: set a game context so the client can be
        # constructed, and use a fresh QueueManager with no environment
        # attached, so requests fall through to the fallback path.
        current_game_context.set("local-test-game")
        client = AtroposClientMinimal("atropos-test", QueueManager())
        test_prompts = [
            "You are FRANCE. What are your orders for Spring 1901?",
            "Send a message to ENGLAND about cooperation.",
            "What is your strategic plan?",
        ]
        for prompt in test_prompts:
            print(f"\nPrompt: {prompt[:50]}...")
            response = await client.generate_response(prompt)
            print(f"Response: {response[:100]}...")
        print(f"\nTracked {len(client.get_interactions())} interactions")

    asyncio.run(test_client())


@@ -0,0 +1,636 @@
"""
Minimal Diplomacy Environment for Atropos
A simplified implementation focusing on:
- Basic game integration with AI_Diplomacy
- Parallel rollouts with group_size
- LLM proxy interception via AtroposClient
- Simple supply center based scoring
"""
import asyncio
import json
import logging
import os
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
ScoredDataItem,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from environments.game_environments.diplomacy_environment.atropos_client_minimal import (
clear_game_interactions,
current_game_context,
get_game_interactions,
register_atropos_models_globally,
)
from environments.game_environments.diplomacy_environment.queue_manager import (
PolicyRequest,
PolicyResponse,
get_queue_manager,
)
sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))
import lm_game # noqa: E402
logger = logging.getLogger(__name__)
POWERS = ["AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"]
STARTING_SUPPLY_CENTERS = {
"AUSTRIA": 3,
"ENGLAND": 3,
"FRANCE": 3,
"GERMANY": 3,
"ITALY": 3,
"RUSSIA": 4,
"TURKEY": 3,
}
class DiplomacyEnvMinimalConfig(BaseEnvConfig):
"""Configuration for the minimal Diplomacy environment."""
env_name: str = "diplomacy_minimal"
# Game settings (minimal implementation, no randomisation stuff)
max_game_turns: int = 10
training_power: str = "FRANCE"
total_steps: int = 10
# Scoring
supply_center_weight: float = 1.0
survival_bonus: float = 0.1
win_bonus: float = 5.0
# Process management
diplomacy_server_port: int = 8432
start_diplomacy_server: bool = True
# Logging
save_game_logs: bool = True
game_logs_dir: str = "./game_logs"
# Evaluation
eval_episodes: int = 10
# Opponent models (None = use SGLang endpoints from server_configs)
opponent_models: Optional[List[str]] = None
class DiplomacyEnvMinimal(BaseEnv):
name = "diplomacy_minimal"
env_config_cls = DiplomacyEnvMinimalConfig
def __init__(
self,
config: DiplomacyEnvMinimalConfig,
server_configs: List[APIServerConfig],
slurm: bool = False,
testing: bool = False,
):
super().__init__(config, server_configs, slurm, testing)
self.config: DiplomacyEnvMinimalConfig = config
self.game_server_process: Optional[subprocess.Popen] = None
self.game_outcomes_buffer: List[Dict] = []
self.eval_metrics_custom: List[Tuple[str, float]] = []
self.queue_manager = get_queue_manager()
self.active_games: Dict[str, Dict] = {}
if config.save_game_logs:
Path(config.game_logs_dir).mkdir(exist_ok=True)
self.system_prompt = (
f"You are playing Diplomacy as {config.training_power}. "
"Analyze the game state and respond with your strategy and orders."
)
if config.opponent_models:
self.opponent_models = config.opponent_models
else:
self.opponent_models = []
for server_config in server_configs:
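            # AI_Diplomacy model-spec format: provider:model@base_url#api_key
            # (see the local test script for an OpenRouter example)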
model_spec = f"openai:{server_config.model_name}@{server_config.base_url}#{server_config.api_key}"
self.opponent_models.append(model_spec)
@classmethod
def config_init(cls) -> Tuple[DiplomacyEnvMinimalConfig, List[APIServerConfig]]:
"""Initialize default configuration."""
env_config = DiplomacyEnvMinimalConfig(
tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
group_size=4,
use_wandb=True,
rollout_server_url="http://localhost:8000",
max_token_length=4096,
wandb_name=cls.name,
steps_per_eval=20,
)
server_configs = [
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9005/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9006/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9007/v1",
api_key="x",
num_requests_for_eval=128,
),
]
return env_config, server_configs
async def setup(self):
"""Set up the environment."""
logger.info(f"Setting up {self.name} environment")
register_atropos_models_globally(self.queue_manager)
if self.config.start_diplomacy_server:
await self._start_diplomacy_server()
asyncio.create_task(self._poll_request_queues())
async def _poll_request_queues(self):
"""Poll request queues and handle policy requests."""
while True:
try:
for game_id in list(self.active_games.keys()):
queue_pair = self.queue_manager.get_queue_pair(game_id)
if not queue_pair:
continue
try:
request = queue_pair.request_queue.get_nowait()
asyncio.create_task(self._handle_policy_request(request))
except asyncio.QueueEmpty:
pass
await asyncio.sleep(0.01)
except Exception as e:
logger.error(f"Error in queue polling: {e}")
await asyncio.sleep(1)
async def _handle_policy_request(self, request: PolicyRequest):
"""Handle a single policy request by sampling from SGLang."""
try:
logger.info(
f"Handling request {request.request_id} for {request.power} in game {request.game_id}"
)
messages = [{"role": "system", "content": self.system_prompt}]
for interaction in request.trajectory:
messages.append({"role": "user", "content": interaction["prompt"]})
messages.append(
{"role": "assistant", "content": interaction["response"]}
)
messages.append({"role": "user", "content": request.prompt})
async with self.server.dedicated_server() as server:
response = await server.chat_completion(
messages=messages,
n=1,
temperature=request.temperature,
max_tokens=2000,
)
response_text = response.choices[0].message.content.strip()
policy_response = PolicyResponse(
request_id=request.request_id,
response=response_text,
metadata={
"power": request.power,
"phase": request.phase,
},
)
await self.queue_manager.put_response(request.game_id, policy_response)
logger.debug(f"Sent response for request {request.request_id}")
except Exception as e:
logger.error(f"Error handling policy request: {e}")
error_response = PolicyResponse(
request_id=request.request_id,
response="Error: Failed to generate response",
metadata={"error": str(e)},
)
await self.queue_manager.put_response(request.game_id, error_response)
async def _start_diplomacy_server(self):
"""Start the AI_Diplomacy game server."""
try:
logger.info(
f"Starting Diplomacy server on port {self.config.diplomacy_server_port}"
)
self.game_server_process = subprocess.Popen(
[
sys.executable,
"-m",
"diplomacy.server.run",
"--port",
str(self.config.diplomacy_server_port),
],
cwd=os.path.join(os.path.dirname(__file__), "AI_Diplomacy"),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# Give server time to start
await asyncio.sleep(3)
logger.info(
f"Diplomacy server started at http://localhost:{self.config.diplomacy_server_port}"
)
except Exception as e:
logger.error(f"Failed to start Diplomacy server: {e}")
async def collect_trajectories(
self, item: Item
) -> Tuple[ScoredDataGroup, List[Item]]:
"""
Run parallel Diplomacy games and collect all trajectories.
This implements the key RL training pattern:
1. Run group_size parallel games from a shared base seed (each game offsets it by its trajectory index)
2. Each game explores a different action sequence
3. Score each trajectory based on game outcome
4. Return all trajectories as a ScoredDataGroup for training
"""
logger.warning(
f"[DiplomacyEnvMinimal] collect_trajectories called with item: {item}"
)
base_game_id = item.get("game_id", f"game-{int(time.time())}")
seed = item.get("seed", random.randint(0, 1_000_000))
logger.info(
f"Starting {self.config.group_size} parallel games with seed {seed}"
)
# Run parallel games from the shared base seed (offset per trajectory)
game_tasks = []
for i in range(self.config.group_size):
game_id = f"{base_game_id}-{i}"
task = self._run_single_game(game_id, seed, trajectory_id=i)
game_tasks.append(task)
results = await asyncio.gather(*game_tasks, return_exceptions=True)
scored_items = []
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"Game {i} failed: {result}")
continue
if result and result[0]:
scored_items.append(result[0])
logger.warning(
f"[DiplomacyEnvMinimal] Collected {len(scored_items)} scored items"
)
if not scored_items:
logger.error("No valid trajectories collected")
return (
ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
advantages=None,
ref_logprobs=None,
group_overrides={},
overrides=None,
images=None,
),
[],
)
sdg = ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
advantages=None,
ref_logprobs=None,
group_overrides={},
overrides=None,
images=None,
)
for scored_item in scored_items:
sdg["tokens"].append(scored_item["tokens"])
sdg["masks"].append(scored_item["masks"])
sdg["scores"].append(scored_item["scores"])
if self.config.include_messages and scored_item.get("messages"):
sdg["messages"].append(scored_item["messages"])
logger.info(f"Collected {len(scored_items)} trajectories")
logger.warning(
f"[DiplomacyEnvMinimal] Returning ScoredDataGroup with {len(sdg['tokens'])} "
f"tokens, {len(sdg['scores'])} scores"
)
logger.warning(
f"[DiplomacyEnvMinimal] First few scores: {sdg['scores'][:5] if sdg['scores'] else 'None'}"
)
for i in range(self.config.group_size):
game_id = f"{base_game_id}-{i}"
if game_id in self.active_games:
del self.active_games[game_id]
try:
await self.queue_manager.remove_game_queues(game_id)
except Exception as e:
logger.debug(f"Error cleaning up queues for {game_id}: {e}")
return sdg, []
async def _run_single_game(
self, game_id: str, seed: int, trajectory_id: int
) -> Tuple[Optional[ScoredDataItem], None]:
"""
Run a single Diplomacy game and return scored trajectory.
"""
try:
queue_pair = await self.queue_manager.create_game_queues(game_id)
self.active_games[game_id] = {
"queue_pair": queue_pair,
"start_time": time.time(),
"interactions": [],
}
token = current_game_context.set(game_id)
try:
game_result = await self._run_diplomacy_game(
game_id, seed, trajectory_id
)
finally:
current_game_context.reset(token)
if not game_result:
logger.error(f"Game {game_id} failed to complete")
return None, None
score = self._calculate_score(game_result, self.config.training_power)
interactions = get_game_interactions(game_id)
training_interactions = [
i for i in interactions if i.get("power") == self.config.training_power
]
if training_interactions:
messages = [{"role": "system", "content": self.system_prompt}]
for interaction in training_interactions:
messages.append({"role": "user", "content": interaction["prompt"]})
messages.append(
{"role": "assistant", "content": interaction["response"]}
)
logger.info(
f"Collected {len(training_interactions)} interactions for {self.config.training_power}"
)
else:
logger.warning(
f"No interactions found for {self.config.training_power} in game {game_id}"
)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"Playing Diplomacy game {game_id}"},
{
"role": "assistant",
"content": f"Game completed with score {score:.2f}",
},
]
clear_game_interactions(game_id)
tokenization_result = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=messages,
train_on_all_assistant_turns=True,
)
scored_data_item = ScoredDataItem(
messages=messages if self.config.include_messages else None,
tokens=tokenization_result["tokens"],
masks=tokenization_result["masks"],
scores=score,
)
self.game_outcomes_buffer.append(
{
"game_id": game_id,
"score": score,
"winner": game_result.get("winner"),
"turns": game_result.get("turns_played", 0),
"final_centers": game_result.get("final_centers", {}),
}
)
return scored_data_item, None
except Exception as e:
logger.error(f"Error in game {game_id}: {e}", exc_info=True)
return None, None
async def _run_diplomacy_game(
self, game_id: str, seed: int, trajectory_id: int = 0
) -> Optional[Dict]:
"""Run a Diplomacy game using AI_Diplomacy's lm_game module."""
game_output_path = os.path.join(self.config.game_logs_dir, f"{game_id}.json")
models = []
opponent_idx = 0
for power in POWERS:
if power == self.config.training_power:
models.append("atropos-training-policy")
else:
models.append(
self.opponent_models[opponent_idx % len(self.opponent_models)]
)
opponent_idx += 1
original_argv = sys.argv
try:
sys.argv = [
"lm_game.py",
"--models",
",".join(models),
"--max_year",
str(1900 + self.config.max_game_turns),
"--output",
game_output_path,
"--seed",
str(seed + trajectory_id),
"--num_negotiation_rounds",
"0",
]
await lm_game.main()
actual_game_file = os.path.join(game_output_path, "lmvsgame.json")
if os.path.exists(actual_game_file):
with open(actual_game_file, "r") as f:
saved_game = json.load(f)
phases = saved_game.get("phases", [])
last_phase = phases[-1] if phases else {}
result = {
"winner": saved_game.get("winner"),
"turns_played": len(phases),
"final_centers": {},
}
for power in POWERS:
centers = (
last_phase.get("state", {}).get("centers", {}).get(power, [])
)
result["final_centers"][power] = len(centers)
return result
else:
logger.error(f"Game output not found: {actual_game_file}")
return None
finally:
sys.argv = original_argv
def _calculate_score(self, game_result: Dict, power: str) -> float:
"""Calculate score for a power based on game outcome."""
start_centers = STARTING_SUPPLY_CENTERS[power]
end_centers = game_result["final_centers"].get(power, 0)
center_score = (end_centers - start_centers) * self.config.supply_center_weight
survival_score = self.config.survival_bonus * game_result["turns_played"]
win_score = self.config.win_bonus if game_result["winner"] == power else 0.0
total_score = center_score + survival_score + win_score
logger.info(
f"{power} score: centers={center_score:.2f}, "
f"survival={survival_score:.2f}, win={win_score:.2f}, "
f"total={total_score:.2f}"
)
return total_score
async def get_next_item(self) -> Item:
"""Generate configuration for the next game."""
return {
"game_id": f"game-{int(time.time())}-{random.randint(1000, 9999)}",
"seed": random.randint(0, 1_000_000),
}
async def evaluate(self, *args, **kwargs):
"""Run evaluation games."""
logger.info(f"Starting evaluation with {self.config.eval_episodes} episodes")
eval_scores = []
wins = 0
for i in range(self.config.eval_episodes):
item = await self.get_next_item()
item["is_eval"] = True
scored_data_group, _ = await self.collect_trajectories(item)
if scored_data_group and scored_data_group["scores"]:
avg_score = sum(scored_data_group["scores"]) / len(
scored_data_group["scores"]
)
eval_scores.append(avg_score)
if (
self.game_outcomes_buffer
and self.game_outcomes_buffer[-1]["winner"]
== self.config.training_power
):
wins += 1
if eval_scores:
avg_score = sum(eval_scores) / len(eval_scores)
win_rate = wins / self.config.eval_episodes
self.eval_metrics_custom = [
(f"{self.name}_eval/avg_score", avg_score),
(f"{self.name}_eval/win_rate", win_rate),
(f"{self.name}_eval/num_completed", len(eval_scores)),
]
logger.info(
f"Evaluation completed: avg_score={avg_score:.2f}, win_rate={win_rate:.2%}"
)
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
"""Log metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
if self.game_outcomes_buffer:
total_games = len(self.game_outcomes_buffer)
avg_score = sum(g["score"] for g in self.game_outcomes_buffer) / total_games
wins = sum(
1
for g in self.game_outcomes_buffer
if g["winner"] == self.config.training_power
)
win_rate = wins / total_games
avg_turns = sum(g["turns"] for g in self.game_outcomes_buffer) / total_games
wandb_metrics.update(
{
f"{self.name}_train/avg_score": avg_score,
f"{self.name}_train/win_rate": win_rate,
f"{self.name}_train/avg_turns": avg_turns,
f"{self.name}_train/num_games": total_games,
}
)
self.game_outcomes_buffer = []
for key, value in self.eval_metrics_custom:
wandb_metrics[key] = value
self.eval_metrics_custom = []
await super().wandb_log(wandb_metrics)
def __del__(self):
"""Clean up server process on exit."""
if self.game_server_process:
self.game_server_process.terminate()
self.game_server_process.wait()
if __name__ == "__main__":
DiplomacyEnvMinimal.cli()


@@ -0,0 +1,218 @@
#!/usr/bin/env python3
"""
Local test server for the minimal Diplomacy environment.
This script runs the full AI_Diplomacy game with real OpenRouter-hosted models
to test the AtroposClient proxy integration.
"""
import asyncio
import logging
import os
from dotenv import load_dotenv
from atroposlib.envs.base import APIServerConfig, EvalHandlingEnum
from environments.game_environments.diplomacy_environment.diplomacy_env_minimal import (
DiplomacyEnvMinimal,
DiplomacyEnvMinimalConfig,
)
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def main():
"""Run Diplomacy games for testing the minimal environment."""
logger.info("Starting Diplomacy minimal environment local test runner")
# Check for OpenRouter API key
if not os.getenv("OPENROUTER_API_KEY"):
logger.error(
"OPENROUTER_API_KEY not found. Please set it in your environment or .env file"
)
return
# Configure environment - using OpenRouter model
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
openrouter_model = (
f"openai:openai/gpt-oss-120b@https://openrouter.ai/api/v1#{openrouter_api_key}"
)
# Create list of opponent models (6 powers besides training power)
opponent_models = [openrouter_model] * 6
env_config = DiplomacyEnvMinimalConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=2, # Run 2 parallel games
use_wandb=False,
wandb_name="diplomacy_minimal_local_test",
max_num_workers=1,
rollout_server_url="http://localhost:8000",
total_steps=1,
batch_size=1,
steps_per_eval=0,
max_token_length=4096,
inference_weight=1.0,
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.NONE,
eval_limit_ratio=0.0,
max_game_turns=5, # Short games for testing
training_power="FRANCE", # Which power we're training
include_messages=True, # Include messages for debugging
eval_episodes=0,
start_diplomacy_server=True, # Let the env start the server
save_game_logs=True,
game_logs_dir="./test_game_logs",
opponent_models=opponent_models, # Use OpenRouter for all opponents
)
# Configure server - using 4 servers to match SLURM setup
# For local testing, we'll simulate this with the same OpenRouter endpoint
server_configs = [
APIServerConfig(
model_name="openai/gpt-oss-120b", # Using the OpenRouter model
base_url="https://openrouter.ai/api/v1",
api_key=os.getenv("OPENROUTER_API_KEY"),
num_requests_for_eval=0,
),
APIServerConfig(
model_name="openai/gpt-oss-120b",
base_url="https://openrouter.ai/api/v1",
api_key=os.getenv("OPENROUTER_API_KEY"),
num_requests_for_eval=0,
),
APIServerConfig(
model_name="openai/gpt-oss-120b",
base_url="https://openrouter.ai/api/v1",
api_key=os.getenv("OPENROUTER_API_KEY"),
num_requests_for_eval=0,
),
APIServerConfig(
model_name="openai/gpt-oss-120b",
base_url="https://openrouter.ai/api/v1",
api_key=os.getenv("OPENROUTER_API_KEY"),
num_requests_for_eval=0,
),
]
logger.info("Using OpenRouter openai/gpt-oss-120b for Diplomacy test")
logger.debug(f"Env Config: {env_config}")
logger.debug(f"Server Configs: {server_configs}")
try:
env = DiplomacyEnvMinimal(
config=env_config,
server_configs=server_configs,
slurm=False,
testing=False,
)
except Exception as e:
logger.exception(f"Failed to initialize DiplomacyEnvMinimal: {e}")
return
logger.info("Running test games")
try:
await env.setup()
# Get number of episodes from command line or default
import sys
num_episodes = int(sys.argv[1]) if len(sys.argv) > 1 else 3
# Track statistics
episode_results = []
for episode_num in range(num_episodes):
logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====")
item = await env.get_next_item()
logger.info(f"Game ID: {item['game_id']}, Seed: {item['seed']}")
# Collect trajectories (will run group_size parallel games)
scored_data_group, _ = await env.collect_trajectories(item)
if scored_data_group and scored_data_group["scores"]:
avg_score = sum(scored_data_group["scores"]) / len(
scored_data_group["scores"]
)
logger.info(
f"Collected {len(scored_data_group['scores'])} trajectories with average score: {avg_score:.2f}"
)
# Get game outcomes from buffer
if env.game_outcomes_buffer:
latest_outcomes = env.game_outcomes_buffer[-env.config.group_size :]
for i, outcome in enumerate(latest_outcomes):
logger.info(
f" Game {i}: Score={outcome['score']:.2f}, "
f"Winner={outcome['winner']}, "
f"Turns={outcome['turns']}, "
f"Centers={outcome['final_centers'].get(env.config.training_power, 0)}"
)
episode_results.append(
{
"episode": episode_num + 1,
"score": avg_score,
"outcomes": latest_outcomes if env.game_outcomes_buffer else [],
}
)
else:
logger.error("Failed to collect trajectory")
episode_results.append(
{
"episode": episode_num + 1,
"score": 0.0,
"outcomes": [],
}
)
# Print overall statistics
logger.info("\n" + "=" * 60)
logger.info("OVERALL RESULTS SUMMARY")
logger.info("=" * 60)
logger.info(f"Total episodes: {num_episodes}")
logger.info(f"Group size: {env.config.group_size} games per episode")
logger.info(f"Training power: {env.config.training_power}")
# Calculate statistics
if episode_results:
avg_score = sum(ep["score"] for ep in episode_results) / len(
episode_results
)
logger.info(f"\nAverage trajectory score: {avg_score:.2f}")
# Count wins
total_games = 0
wins = 0
for ep in episode_results:
for outcome in ep["outcomes"]:
total_games += 1
if outcome["winner"] == env.config.training_power:
wins += 1
if total_games > 0:
logger.info(
f"Win rate: {wins}/{total_games} ({100*wins/total_games:.1f}%)"
)
# Average supply centers
total_centers = sum(
outcome["final_centers"].get(env.config.training_power, 0)
for ep in episode_results
for outcome in ep["outcomes"]
)
avg_centers = total_centers / total_games
logger.info(f"Average final supply centers: {avg_centers:.1f}")
logger.info("\nTest completed successfully!")
except Exception as e:
logger.exception(f"An error occurred during test: {e}")
if __name__ == "__main__":
asyncio.run(main())


@@ -0,0 +1,117 @@
"""
Queue Manager for Diplomacy Environment
Manages request/response queues between AtroposClient proxies and the environment.
Each game gets its own queue pair for isolation.
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import Dict, Optional
logger = logging.getLogger(__name__)
@dataclass
class QueuePair:
"""A pair of queues for a single game."""
game_id: str
request_queue: asyncio.Queue
response_queue: asyncio.Queue
@dataclass
class PolicyRequest:
"""Request from proxy to environment for policy sampling."""
request_id: str
game_id: str
power: str
phase: str
prompt: str
temperature: float
trajectory: list
@dataclass
class PolicyResponse:
"""Response from environment back to proxy."""
request_id: str
response: str
metadata: dict
class QueueManager:
"""Manages queues for all parallel games."""
def __init__(self):
self.queue_pairs: Dict[str, QueuePair] = {}
self._lock = asyncio.Lock()
async def create_game_queues(self, game_id: str) -> QueuePair:
"""Create a new queue pair for a game."""
async with self._lock:
if game_id in self.queue_pairs:
logger.warning(f"Queue pair already exists for game {game_id}")
return self.queue_pairs[game_id]
queue_pair = QueuePair(
game_id=game_id,
request_queue=asyncio.Queue(),
response_queue=asyncio.Queue(),
)
self.queue_pairs[game_id] = queue_pair
logger.info(f"Created queue pair for game {game_id}")
return queue_pair
def get_queue_pair(self, game_id: str) -> Optional[QueuePair]:
"""Get queue pair for a game."""
return self.queue_pairs.get(game_id)
async def remove_game_queues(self, game_id: str):
"""Remove queues for a completed game."""
async with self._lock:
if game_id in self.queue_pairs:
del self.queue_pairs[game_id]
logger.info(f"Removed queue pair for game {game_id}")
def get_all_request_queues(self) -> Dict[str, asyncio.Queue]:
"""Get all request queues for polling."""
return {
game_id: pair.request_queue for game_id, pair in self.queue_pairs.items()
}
async def put_request(self, game_id: str, request: PolicyRequest):
"""Put a request on the appropriate queue."""
queue_pair = self.get_queue_pair(game_id)
if queue_pair:
await queue_pair.request_queue.put(request)
else:
raise ValueError(f"No queue pair found for game {game_id}")
async def get_response(self, game_id: str) -> PolicyResponse:
"""Get a response from the appropriate queue."""
queue_pair = self.get_queue_pair(game_id)
if queue_pair:
return await queue_pair.response_queue.get()
else:
raise ValueError(f"No queue pair found for game {game_id}")
async def put_response(self, game_id: str, response: PolicyResponse):
"""Put a response on the appropriate queue."""
queue_pair = self.get_queue_pair(game_id)
if queue_pair:
await queue_pair.response_queue.put(response)
else:
raise ValueError(f"No queue pair found for game {game_id}")
_queue_manager = QueueManager()
def get_queue_manager() -> QueueManager:
"""Get the global queue manager instance."""
return _queue_manager
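
As a standalone illustration (an editor's sketch, not part of this commit; the flat import path is assumed), a full request/response round trip through the queue manager looks like:
```python
import asyncio

from queue_manager import PolicyRequest, PolicyResponse, get_queue_manager


async def round_trip():
    qm = get_queue_manager()
    await qm.create_game_queues("demo-game")

    # Proxy side: enqueue a policy request.
    await qm.put_request(
        "demo-game",
        PolicyRequest(
            request_id="req-1",
            game_id="demo-game",
            power="FRANCE",
            phase="S1901M",
            prompt="What are your orders?",
            temperature=0.0,
            trajectory=[],
        ),
    )

    # Environment side: pop the request and answer it.
    pending = qm.get_queue_pair("demo-game").request_queue.get_nowait()
    await qm.put_response(
        "demo-game",
        PolicyResponse(request_id=pending.request_id, response="A PAR H", metadata={}),
    )

    # Proxy side: receive the matched response.
    response = await qm.get_response("demo-game")
    print(response.response)  # -> A PAR H
    await qm.remove_game_queues("demo-game")


asyncio.run(round_trip())
```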


@@ -0,0 +1,8 @@
# Requirements for minimal Diplomacy environment
# Only includes dependencies specific to AI_Diplomacy that aren't in the main project
# AI_Diplomacy specific dependencies
tornado>=6.0
ujson>=5.0
websockets>=10.0
coloredlogs>=15.0.1


@@ -49,7 +49,7 @@ async def main():
# Configure server - using same model as blackjack example
server_configs = [
APIServerConfig(
model_name="gpt-4.1-nano",
model_name="gpt-4.1",
base_url="https://api.openai.com/v1",
api_key=os.getenv("OPENAI_API_KEY"),
num_requests_for_eval=0,
@@ -99,7 +99,6 @@ async def main():
item = await env.get_next_item()
challenge_name = item["challenge_name"]
challenge_counts[challenge_name] += 1
logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====")
logger.info(f"Using game: {item}")


@@ -44,7 +44,6 @@ class TextWorldChallengeRegistry:
def __init__(self, seed: Optional[int] = None):
self._challenges = self.CHALLENGES.copy()
self.rng = random.Random(seed)
# Cache for all possible combinations
self._all_combinations = None
self._combination_index = 0