Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
Diplomacy trainer env (#227)
* minimal implementation, simplified challenge registry
* need game save logic
* fixed challenge gen, works with local test
* updated challenge gen with wider ranges, working with local script
* runs working correctly, wandb stats look ok
* linting
* Add diplomacy environment with AI_Diplomacy submodule
  - Add diplomacy_env_minimal.py for diplomacy game environment
  - Add atropos_client_minimal.py for client interface
  - Add diplomacy_local_server.py for local game server
  - Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy
  - Fix import ordering and remove unused imports
* test file working, moving to cluster to test training
* updated gitignore
* removed logs
* minor fixes, training running now
* readded proxy reg and queue system
* linting
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* queue gameid bug, refactored
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* cleaned up configs & allowed for openrouter models to be easily used
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* linting
* Remove duplicate dependencies from diplomacy requirements.txt. Only keep AI_Diplomacy-specific dependencies that aren't already in the main project.
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
4fe67e698d
commit
46f0602227
13 changed files with 1317 additions and 4 deletions
3  .gitignore (vendored)

@@ -205,3 +205,6 @@ environments/community/word_hunt/word_hunt_rollouts*.html
 *.z8
 *.ni
 *.z5
+
+# Diplomacy artefacts
+environments/game_environments/diplomacy_environment/logs/
3  .gitmodules (vendored)

@@ -4,3 +4,6 @@
 [submodule "environments/reasoning_gym_environment/reasoning-gym"]
 	path = environments/reasoning_gym_environment/reasoning-gym
 	url = https://github.com/open-thought/reasoning-gym
+[submodule "environments/game_environments/diplomacy_environment/AI_Diplomacy"]
+	path = environments/game_environments/diplomacy_environment/AI_Diplomacy
+	url = https://github.com/GoodStartLabs/AI_Diplomacy.git
environments/game_environments/diplomacy_environment/AI_Diplomacy (new submodule)

@@ -0,0 +1 @@
Subproject commit 70d4ae2fe029ed38bf823ce70f5d45a7bcc7afbc
environments/game_environments/diplomacy_environment/README.md (new file)

@@ -0,0 +1,54 @@
# Minimal Diplomacy Environment

A simplified Diplomacy RL training environment for Atropos that integrates with AI_Diplomacy.

## Overview

This minimal implementation provides:
- Basic game integration via AI_Diplomacy submodule
- Parallel rollouts with configurable group_size
- LLM request interception through AtroposClient proxy
- Simple supply center based scoring
- No complex features (no GRPO, memory systems, or advanced scoring)

## Architecture

```
Atropos Policy Server
        ↓
AtroposClientMinimal (proxy)
        ↓
AI_Diplomacy Game Engine
        ↓
Game Execution
```

## Quick Start

1. Install dependencies:
```bash
pip install -r requirements.txt
cd AI_Diplomacy
pip install -e .
```

2. Start your Atropos policy server on port 8000

3. Run the environment:
```bash
python diplomacy_env_minimal.py serve
```

## Configuration

Key settings in `DiplomacyEnvMinimalConfig`:
- `max_game_turns`: Number of game turns (default: 10)
- `training_power`: Which power the RL agent controls (default: "FRANCE")
- `group_size`: Number of parallel games per trajectory (default: 4)

## How It Works

1. **Parallel Rollouts**: Each training step runs `group_size` games with the same initial seed
2. **LLM Interception**: AtroposClientMinimal intercepts all LLM calls from AI_Diplomacy
3. **Trajectory Collection**: Game interactions are collected and scored
4. **Best Selection**: The highest scoring trajectory is returned for training
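The "simple supply center based scoring" mentioned above is implemented in `DiplomacyEnvMinimal._calculate_score` further down in this diff. As a rough standalone sketch (the helper name `score_power` and the two-entry center table are illustrative only; the weights are the config defaults `supply_center_weight=1.0`, `survival_bonus=0.1`, `win_bonus=5.0`):

```python
# Sketch of the supply-center scoring rule, mirroring _calculate_score below.
STARTING_SUPPLY_CENTERS = {"FRANCE": 3, "RUSSIA": 4}  # subset, for illustration


def score_power(power, final_centers, turns_played, winner,
                supply_center_weight=1.0, survival_bonus=0.1, win_bonus=5.0):
    center_score = (final_centers - STARTING_SUPPLY_CENTERS[power]) * supply_center_weight
    survival_score = survival_bonus * turns_played
    win_score = win_bonus if winner == power else 0.0
    return center_score + survival_score + win_score


# Example: FRANCE ends a 10-turn game with 5 centers and no solo win:
# (5 - 3) * 1.0 + 0.1 * 10 + 0.0 = 3.0
print(score_power("FRANCE", final_centers=5, turns_played=10, winner=None))
```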
environments/game_environments/diplomacy_environment/__init__.py (new file)

@@ -0,0 +1,7 @@
"""
Minimal Diplomacy Environment for Atropos RL Training
"""

from .diplomacy_env_minimal import DiplomacyEnvMinimal, DiplomacyEnvMinimalConfig

__all__ = ["DiplomacyEnvMinimal", "DiplomacyEnvMinimalConfig"]
environments/game_environments/diplomacy_environment/atropos_client_minimal.py (new file)

@@ -0,0 +1,268 @@
"""
This is a queue-based proxy that:
- Intercepts LLM requests from AI_Diplomacy
- Puts them on a queue for the environment to process
- Waits for responses from the environment
- Returns responses to AI_Diplomacy
"""

import asyncio
import contextvars
import json
import logging
import os
import sys
import uuid
from typing import Dict, List, Optional

from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy import (
    clients,
)
from environments.game_environments.diplomacy_environment.queue_manager import (
    PolicyRequest,
    QueueManager,
    get_queue_manager,
)

sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))

from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy.clients import (  # noqa: E402
    BaseModelClient,
)

logger = logging.getLogger(__name__)

current_game_context = contextvars.ContextVar("current_game_id", default=None)
_game_interactions = {}


class AtroposClientMinimal(BaseModelClient):
    """
    Queue-based proxy client that forwards LLM requests through queues.
    """

    def __init__(
        self,
        model_name: str,
        queue_manager: Optional[QueueManager] = None,
    ):
        super().__init__(model_name)
        self.game_id = current_game_context.get()
        if not self.game_id:
            raise ValueError("AtroposClientMinimal created without game context set")

        self.queue_manager = queue_manager or get_queue_manager()

        self.interactions: List[Dict] = []
        self.current_power: Optional[str] = None
        self.current_phase: Optional[str] = None

        logger.info(
            f"Initialized AtroposClientMinimal for {model_name} in game {self.game_id}"
        )

    async def generate_response(self, prompt: str, temperature: float = 0.0) -> str:
        """
        Put request on queue and wait for response from environment.
        This is the main method AI_Diplomacy calls for all LLM interactions.
        """
        task_type = self._infer_task_type(prompt)
        power = self._extract_power(prompt)
        phase = self._extract_phase(prompt)

        if power:
            self.current_power = power
        if phase:
            self.current_phase = phase

        logger.debug(f"Generating response for {self.current_power}: {task_type}")

        try:
            request_id = str(uuid.uuid4())
            request = PolicyRequest(
                request_id=request_id,
                game_id=self.game_id,
                power=self.current_power or "UNKNOWN",
                phase=self.current_phase or "UNKNOWN",
                prompt=prompt,
                temperature=temperature,
                trajectory=self.interactions.copy(),
            )

            await self.queue_manager.put_request(self.game_id, request)
            logger.debug(f"Put request {request_id} on queue for game {self.game_id}")

            response = await self.queue_manager.get_response(self.game_id)

            if response.request_id != request_id:
                logger.warning(
                    f"Response ID mismatch: expected {request_id}, got {response.request_id}"
                )

            response_text = response.response

            # Track interaction
            interaction = {
                "power": self.current_power,
                "phase": self.current_phase,
                "task_type": task_type,
                "prompt": prompt,
                "response": response_text,
                "metadata": response.metadata,  # Store any additional info from environment
            }
            self.interactions.append(interaction)

            if self.game_id not in _game_interactions:
                _game_interactions[self.game_id] = []
            _game_interactions[self.game_id].append(interaction)

            return response_text

        except asyncio.TimeoutError:
            logger.error("Timeout waiting for response from environment")
            return self._generate_fallback_response(prompt)
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return self._generate_fallback_response(prompt)

    def _infer_task_type(self, prompt: str) -> str:
        """Infer the type of task from the prompt."""
        prompt_lower = prompt.lower()

        if "orders" in prompt_lower or "submit" in prompt_lower:
            return "orders"
        elif "message" in prompt_lower or "negotiate" in prompt_lower:
            return "negotiation"
        elif "plan" in prompt_lower or "strategy" in prompt_lower:
            return "planning"
        else:
            return "general"

    def _extract_power(self, prompt: str) -> Optional[str]:
        """Extract power name from prompt if mentioned."""
        for power in [
            "AUSTRIA",
            "ENGLAND",
            "FRANCE",
            "GERMANY",
            "ITALY",
            "RUSSIA",
            "TURKEY",
        ]:
            if power in prompt.upper():
                return power
        return None

    def _extract_phase(self, prompt: str) -> Optional[str]:
        """Extract game phase from prompt if mentioned."""
        import re

        phase_match = re.search(r"[SF]\d{4}[MRB]", prompt)
        if phase_match:
            return phase_match.group()

        verbose_match = re.search(r"(Spring|Fall) \d{4}", prompt)
        if verbose_match:
            return verbose_match.group()

        return None

    def _generate_fallback_response(self, prompt: str) -> str:
        """Generate a simple fallback response if there's an issue."""
        task_type = self._infer_task_type(prompt)

        if task_type == "orders":
            return json.dumps(
                {
                    "orders": {},
                    "explanations": {"general": "Fallback - no server connected"},
                }
            )
        elif task_type == "negotiation":
            return json.dumps(
                {
                    "messages": [],
                    "explanations": {"general": "Fallback - no server connected"},
                }
            )
        else:
            return "Fallback response - server not available"

    def get_interactions(self) -> List[Dict]:
        """Get all tracked interactions for trajectory collection."""
        return self.interactions

    def clear_interactions(self):
        """Clear tracked interactions for a new game."""
        self.interactions = []
        self.current_power = None
        self.current_phase = None


def get_game_interactions(game_id: str) -> List[Dict]:
    """Get all interactions for a specific game."""
    return _game_interactions.get(game_id, [])


def clear_game_interactions(game_id: str):
    """Clear interactions for a specific game."""
    if game_id in _game_interactions:
        del _game_interactions[game_id]


def register_atropos_models_globally(queue_manager: Optional[QueueManager] = None):
    """
    Register AtroposClientMinimal with AI_Diplomacy's model loading system globally.
    This should be called ONCE during environment setup.

    Args:
        queue_manager: Optional queue manager (uses global if not provided)
    """

    if hasattr(clients, "_atropos_registered"):
        logger.info("AtroposClientMinimal already registered globally")
        return

    clients._original_load_model_client = clients.load_model_client
    clients._atropos_queue_manager = queue_manager or get_queue_manager()

    def load_model_client_with_atropos(
        model_id: str, prompts_dir: Optional[str] = None
    ) -> BaseModelClient:
        if model_id.startswith("atropos-"):
            logger.info(f"Creating context-aware AtroposClientMinimal for {model_id}")
            return AtroposClientMinimal(model_id, clients._atropos_queue_manager)
        else:
            logger.info(f"Falling back to original loader for {model_id}")
            return clients._original_load_model_client(model_id, prompts_dir)

    clients.load_model_client = load_model_client_with_atropos
    clients._atropos_registered = True

    logger.info("Registered AtroposClientMinimal globally with AI_Diplomacy")


if __name__ == "__main__":

    async def test_client():
        client = AtroposClientMinimal(
            "atropos-test",
            {"base_url": "http://localhost:8000", "model_name": "test-model"},
        )

        test_prompts = [
            "You are FRANCE. What are your orders for Spring 1901?",
            "Send a message to ENGLAND about cooperation.",
            "What is your strategic plan?",
        ]

        for prompt in test_prompts:
            print(f"\nPrompt: {prompt[:50]}...")
            response = await client.generate_response(prompt)
            print(f"Response: {response[:100]}...")

        print(f"\nTracked {len(client.get_interactions())} interactions")
        await client.close()

    asyncio.run(test_client())
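For orientation, the proxy above only constructs successfully when `current_game_context` has been set by the caller before AI_Diplomacy loads its model clients. A minimal sketch of that handshake from the caller's side (the game id and prompt are placeholders; the real environment answers the queued request from a separate polling task, so this coroutine would wait until such a task responds):

```python
from environments.game_environments.diplomacy_environment.atropos_client_minimal import (
    AtroposClientMinimal,
    current_game_context,
)
from environments.game_environments.diplomacy_environment.queue_manager import (
    get_queue_manager,
)


async def proxy_round_trip(game_id: str = "game-demo-0") -> str:
    # Create the per-game queues, then set the context var so that any
    # AtroposClientMinimal constructed for this game picks up game_id.
    await get_queue_manager().create_game_queues(game_id)
    token = current_game_context.set(game_id)
    try:
        client = AtroposClientMinimal("atropos-training-policy")
        # generate_response() enqueues a PolicyRequest and waits on the
        # response queue until the environment posts a PolicyResponse.
        return await client.generate_response("You are FRANCE. Submit your orders.")
    finally:
        current_game_context.reset(token)
```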
environments/game_environments/diplomacy_environment/diplomacy_env_minimal.py (new file)

@@ -0,0 +1,636 @@
"""
Minimal Diplomacy Environment for Atropos

A simplified implementation focusing on:
- Basic game integration with AI_Diplomacy
- Parallel rollouts with group_size
- LLM proxy interception via AtroposClient
- Simple supply center based scoring
"""

import asyncio
import json
import logging
import os
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
    ScoredDataItem,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from environments.game_environments.diplomacy_environment.atropos_client_minimal import (
    clear_game_interactions,
    current_game_context,
    get_game_interactions,
    register_atropos_models_globally,
)
from environments.game_environments.diplomacy_environment.queue_manager import (
    PolicyRequest,
    PolicyResponse,
    get_queue_manager,
)

sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))
import lm_game  # noqa: E402

logger = logging.getLogger(__name__)

POWERS = ["AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"]
STARTING_SUPPLY_CENTERS = {
    "AUSTRIA": 3,
    "ENGLAND": 3,
    "FRANCE": 3,
    "GERMANY": 3,
    "ITALY": 3,
    "RUSSIA": 4,
    "TURKEY": 3,
}


class DiplomacyEnvMinimalConfig(BaseEnvConfig):
    """Configuration for the minimal Diplomacy environment."""

    env_name: str = "diplomacy_minimal"

    # Game settings (minimal implementation, no randomisation stuff)
    max_game_turns: int = 10
    training_power: str = "FRANCE"
    total_steps: int = 10

    # Scoring
    supply_center_weight: float = 1.0
    survival_bonus: float = 0.1
    win_bonus: float = 5.0

    # Process management
    diplomacy_server_port: int = 8432
    start_diplomacy_server: bool = True

    # Logging
    save_game_logs: bool = True
    game_logs_dir: str = "./game_logs"

    # Evaluation
    eval_episodes: int = 10

    # Opponent models (None = use SGLang endpoints from server_configs)
    opponent_models: Optional[List[str]] = None


class DiplomacyEnvMinimal(BaseEnv):
    name = "diplomacy_minimal"
    env_config_cls = DiplomacyEnvMinimalConfig

    def __init__(
        self,
        config: DiplomacyEnvMinimalConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = False,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config: DiplomacyEnvMinimalConfig = config
        self.game_server_process: Optional[subprocess.Popen] = None
        self.game_outcomes_buffer: List[Dict] = []
        self.eval_metrics_custom: List[Tuple[str, float]] = []

        self.queue_manager = get_queue_manager()
        self.active_games: Dict[str, Dict] = {}

        if config.save_game_logs:
            Path(config.game_logs_dir).mkdir(exist_ok=True)

        self.system_prompt = (
            f"You are playing Diplomacy as {config.training_power}. "
            "Analyze the game state and respond with your strategy and orders."
        )

        if config.opponent_models:
            self.opponent_models = config.opponent_models
        else:
            self.opponent_models = []
            for server_config in server_configs:
                model_spec = f"openai:{server_config.model_name}@{server_config.base_url}#{server_config.api_key}"
                self.opponent_models.append(model_spec)

    @classmethod
    def config_init(cls) -> Tuple[DiplomacyEnvMinimalConfig, List[APIServerConfig]]:
        """Initialize default configuration."""
        env_config = DiplomacyEnvMinimalConfig(
            tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
            group_size=4,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            max_token_length=4096,
            wandb_name=cls.name,
            steps_per_eval=20,
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_requests_for_eval=128,
            ),
            APIServerConfig(
                model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
                base_url="http://localhost:9005/v1",
                api_key="x",
                num_requests_for_eval=128,
            ),
            APIServerConfig(
                model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
                base_url="http://localhost:9006/v1",
                api_key="x",
                num_requests_for_eval=128,
            ),
            APIServerConfig(
                model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
                base_url="http://localhost:9007/v1",
                api_key="x",
                num_requests_for_eval=128,
            ),
        ]
        return env_config, server_configs

    async def setup(self):
        """Set up the environment."""
        logger.info(f"Setting up {self.name} environment")

        register_atropos_models_globally(self.queue_manager)

        if self.config.start_diplomacy_server:
            await self._start_diplomacy_server()

        asyncio.create_task(self._poll_request_queues())

    async def _poll_request_queues(self):
        """Poll request queues and handle policy requests."""
        while True:
            try:
                for game_id in list(self.active_games.keys()):
                    queue_pair = self.queue_manager.get_queue_pair(game_id)
                    if not queue_pair:
                        continue

                    try:
                        request = queue_pair.request_queue.get_nowait()
                        asyncio.create_task(self._handle_policy_request(request))
                    except asyncio.QueueEmpty:
                        pass

                await asyncio.sleep(0.01)

            except Exception as e:
                logger.error(f"Error in queue polling: {e}")
                await asyncio.sleep(1)

    async def _handle_policy_request(self, request: PolicyRequest):
        """Handle a single policy request by sampling from SGLang."""
        try:
            logger.info(
                f"Handling request {request.request_id} for {request.power} in game {request.game_id}"
            )

            messages = [{"role": "system", "content": self.system_prompt}]

            for interaction in request.trajectory:
                messages.append({"role": "user", "content": interaction["prompt"]})
                messages.append(
                    {"role": "assistant", "content": interaction["response"]}
                )

            messages.append({"role": "user", "content": request.prompt})

            async with self.server.dedicated_server() as server:
                response = await server.chat_completion(
                    messages=messages,
                    n=1,
                    temperature=request.temperature,
                    max_tokens=2000,
                )

            response_text = response.choices[0].message.content.strip()

            policy_response = PolicyResponse(
                request_id=request.request_id,
                response=response_text,
                metadata={
                    "power": request.power,
                    "phase": request.phase,
                },
            )

            await self.queue_manager.put_response(request.game_id, policy_response)
            logger.debug(f"Sent response for request {request.request_id}")

        except Exception as e:
            logger.error(f"Error handling policy request: {e}")
            error_response = PolicyResponse(
                request_id=request.request_id,
                response="Error: Failed to generate response",
                metadata={"error": str(e)},
            )
            await self.queue_manager.put_response(request.game_id, error_response)

    async def _start_diplomacy_server(self):
        """Start the AI_Diplomacy game server."""
        try:
            logger.info(
                f"Starting Diplomacy server on port {self.config.diplomacy_server_port}"
            )
            self.game_server_process = subprocess.Popen(
                [
                    sys.executable,
                    "-m",
                    "diplomacy.server.run",
                    "--port",
                    str(self.config.diplomacy_server_port),
                ],
                cwd=os.path.join(os.path.dirname(__file__), "AI_Diplomacy"),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            # Give server time to start
            await asyncio.sleep(3)
            logger.info(
                f"Diplomacy server started at http://localhost:{self.config.diplomacy_server_port}"
            )
        except Exception as e:
            logger.error(f"Failed to start Diplomacy server: {e}")

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[ScoredDataGroup, List[Item]]:
        """
        Run parallel Diplomacy games and collect all trajectories.

        This implements the key RL training pattern:
        1. Run group_size parallel games with the same seed
        2. Each game explores different action sequences
        3. Score each trajectory based on game outcome
        4. Return all trajectories as a ScoredDataGroup for training
        """
        logger.warning(
            f"[DiplomacyEnvMinimal] collect_trajectories called with item: {item}"
        )
        base_game_id = item.get("game_id", f"game-{int(time.time())}")
        seed = item.get("seed", random.randint(0, 1_000_000))

        logger.info(
            f"Starting {self.config.group_size} parallel games with seed {seed}"
        )

        # Run parallel games w/ same seed
        game_tasks = []
        for i in range(self.config.group_size):
            game_id = f"{base_game_id}-{i}"
            task = self._run_single_game(game_id, seed, trajectory_id=i)
            game_tasks.append(task)

        results = await asyncio.gather(*game_tasks, return_exceptions=True)

        scored_items = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                logger.error(f"Game {i} failed: {result}")
                continue

            if result and result[0]:
                scored_items.append(result[0])

        logger.warning(
            f"[DiplomacyEnvMinimal] Collected {len(scored_items)} scored items"
        )
        if not scored_items:
            logger.error("No valid trajectories collected")
            return (
                ScoredDataGroup(
                    tokens=[],
                    masks=[],
                    scores=[],
                    messages=[],
                    advantages=None,
                    ref_logprobs=None,
                    group_overrides={},
                    overrides=None,
                    images=None,
                ),
                [],
            )

        sdg = ScoredDataGroup(
            tokens=[],
            masks=[],
            scores=[],
            messages=[],
            advantages=None,
            ref_logprobs=None,
            group_overrides={},
            overrides=None,
            images=None,
        )

        for scored_item in scored_items:
            sdg["tokens"].append(scored_item["tokens"])
            sdg["masks"].append(scored_item["masks"])
            sdg["scores"].append(scored_item["scores"])
            if self.config.include_messages and scored_item.get("messages"):
                sdg["messages"].append(scored_item["messages"])

        logger.info(f"Collected {len(scored_items)} trajectories")
        logger.warning(
            f"[DiplomacyEnvMinimal] Returning ScoredDataGroup with {len(sdg['tokens'])} "
            f"tokens, {len(sdg['scores'])} scores"
        )
        logger.warning(
            f"[DiplomacyEnvMinimal] First few scores: {sdg['scores'][:5] if sdg['scores'] else 'None'}"
        )

        for i in range(self.config.group_size):
            game_id = f"{base_game_id}-{i}"
            if game_id in self.active_games:
                del self.active_games[game_id]
            try:
                await self.queue_manager.remove_game_queues(game_id)
            except Exception as e:
                logger.debug(f"Error cleaning up queues for {game_id}: {e}")

        return sdg, []

    async def _run_single_game(
        self, game_id: str, seed: int, trajectory_id: int
    ) -> Tuple[Optional[ScoredDataItem], None]:
        """
        Run a single Diplomacy game and return scored trajectory.
        """
        try:
            queue_pair = await self.queue_manager.create_game_queues(game_id)

            self.active_games[game_id] = {
                "queue_pair": queue_pair,
                "start_time": time.time(),
                "interactions": [],
            }

            token = current_game_context.set(game_id)

            try:
                game_result = await self._run_diplomacy_game(
                    game_id, seed, trajectory_id
                )
            finally:
                current_game_context.reset(token)

            if not game_result:
                logger.error(f"Game {game_id} failed to complete")
                return None, None

            score = self._calculate_score(game_result, self.config.training_power)

            interactions = get_game_interactions(game_id)

            training_interactions = [
                i for i in interactions if i.get("power") == self.config.training_power
            ]

            if training_interactions:
                messages = [{"role": "system", "content": self.system_prompt}]

                for interaction in training_interactions:
                    messages.append({"role": "user", "content": interaction["prompt"]})
                    messages.append(
                        {"role": "assistant", "content": interaction["response"]}
                    )

                logger.info(
                    f"Collected {len(training_interactions)} interactions for {self.config.training_power}"
                )
            else:
                logger.warning(
                    f"No interactions found for {self.config.training_power} in game {game_id}"
                )
                messages = [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": f"Playing Diplomacy game {game_id}"},
                    {
                        "role": "assistant",
                        "content": f"Game completed with score {score:.2f}",
                    },
                ]

            clear_game_interactions(game_id)

            tokenization_result = tokenize_for_trainer(
                tokenizer=self.tokenizer,
                chat=messages,
                train_on_all_assistant_turns=True,
            )

            scored_data_item = ScoredDataItem(
                messages=messages if self.config.include_messages else None,
                tokens=tokenization_result["tokens"],
                masks=tokenization_result["masks"],
                scores=score,
            )

            self.game_outcomes_buffer.append(
                {
                    "game_id": game_id,
                    "score": score,
                    "winner": game_result.get("winner"),
                    "turns": game_result.get("turns_played", 0),
                    "final_centers": game_result.get("final_centers", {}),
                }
            )

            return scored_data_item, None

        except Exception as e:
            logger.error(f"Error in game {game_id}: {e}", exc_info=True)
            return None, None

    async def _run_diplomacy_game(
        self, game_id: str, seed: int, trajectory_id: int = 0
    ) -> Optional[Dict]:
        """Run a Diplomacy game using AI_Diplomacy's lm_game module."""

        game_output_path = os.path.join(self.config.game_logs_dir, f"{game_id}.json")

        models = []
        opponent_idx = 0

        for power in POWERS:
            if power == self.config.training_power:
                models.append("atropos-training-policy")
            else:
                models.append(
                    self.opponent_models[opponent_idx % len(self.opponent_models)]
                )
                opponent_idx += 1

        original_argv = sys.argv

        try:
            sys.argv = [
                "lm_game.py",
                "--models",
                ",".join(models),
                "--max_year",
                str(1900 + self.config.max_game_turns),
                "--output",
                game_output_path,
                "--seed",
                str(seed + trajectory_id),
                "--num_negotiation_rounds",
                "0",
            ]

            await lm_game.main()

            actual_game_file = os.path.join(game_output_path, "lmvsgame.json")
            if os.path.exists(actual_game_file):
                with open(actual_game_file, "r") as f:
                    saved_game = json.load(f)

                phases = saved_game.get("phases", [])
                last_phase = phases[-1] if phases else {}

                result = {
                    "winner": saved_game.get("winner"),
                    "turns_played": len(phases),
                    "final_centers": {},
                }

                for power in POWERS:
                    centers = (
                        last_phase.get("state", {}).get("centers", {}).get(power, [])
                    )
                    result["final_centers"][power] = len(centers)

                return result
            else:
                logger.error(f"Game output not found: {actual_game_file}")
                return None

        finally:
            sys.argv = original_argv

    def _calculate_score(self, game_result: Dict, power: str) -> float:
        """Calculate score for a power based on game outcome."""
        start_centers = STARTING_SUPPLY_CENTERS[power]
        end_centers = game_result["final_centers"].get(power, 0)
        center_score = (end_centers - start_centers) * self.config.supply_center_weight

        survival_score = self.config.survival_bonus * game_result["turns_played"]

        win_score = self.config.win_bonus if game_result["winner"] == power else 0.0

        total_score = center_score + survival_score + win_score

        logger.info(
            f"{power} score: centers={center_score:.2f}, "
            f"survival={survival_score:.2f}, win={win_score:.2f}, "
            f"total={total_score:.2f}"
        )

        return total_score

    async def get_next_item(self) -> Item:
        """Generate configuration for the next game."""
        return {
            "game_id": f"game-{int(time.time())}-{random.randint(1000, 9999)}",
            "seed": random.randint(0, 1_000_000),
        }

    async def evaluate(self, *args, **kwargs):
        """Run evaluation games."""
        logger.info(f"Starting evaluation with {self.config.eval_episodes} episodes")

        eval_scores = []
        wins = 0

        for i in range(self.config.eval_episodes):
            item = await self.get_next_item()
            item["is_eval"] = True

            scored_data_group, _ = await self.collect_trajectories(item)
            if scored_data_group and scored_data_group["scores"]:
                avg_score = sum(scored_data_group["scores"]) / len(
                    scored_data_group["scores"]
                )
                eval_scores.append(avg_score)

                if (
                    self.game_outcomes_buffer
                    and self.game_outcomes_buffer[-1]["winner"]
                    == self.config.training_power
                ):
                    wins += 1

        if eval_scores:
            avg_score = sum(eval_scores) / len(eval_scores)
            win_rate = wins / self.config.eval_episodes

            self.eval_metrics_custom = [
                (f"{self.name}_eval/avg_score", avg_score),
                (f"{self.name}_eval/win_rate", win_rate),
                (f"{self.name}_eval/num_completed", len(eval_scores)),
            ]

            logger.info(
                f"Evaluation completed: avg_score={avg_score:.2f}, win_rate={win_rate:.2%}"
            )

    async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
        """Log metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        if self.game_outcomes_buffer:
            total_games = len(self.game_outcomes_buffer)
            avg_score = sum(g["score"] for g in self.game_outcomes_buffer) / total_games
            wins = sum(
                1
                for g in self.game_outcomes_buffer
                if g["winner"] == self.config.training_power
            )
            win_rate = wins / total_games
            avg_turns = sum(g["turns"] for g in self.game_outcomes_buffer) / total_games

            wandb_metrics.update(
                {
                    f"{self.name}_train/avg_score": avg_score,
                    f"{self.name}_train/win_rate": win_rate,
                    f"{self.name}_train/avg_turns": avg_turns,
                    f"{self.name}_train/num_games": total_games,
                }
            )

            self.game_outcomes_buffer = []

        for key, value in self.eval_metrics_custom:
            wandb_metrics[key] = value
        self.eval_metrics_custom = []

        await super().wandb_log(wandb_metrics)

    def __del__(self):
        """Clean up server process on exit."""
        if self.game_server_process:
            self.game_server_process.terminate()
            self.game_server_process.wait()


if __name__ == "__main__":
    DiplomacyEnvMinimal.cli()
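One configuration detail worth calling out: the opponent model strings that `__init__` builds (and that the local test server below hard-codes) use a `provider:model@base_url#api_key` shape that AI_Diplomacy's model loader appears to expect. A hedged sketch of constructing one for an OpenAI-compatible endpoint (the endpoint and key below are placeholders, not real credentials):

```python
# Sketch: opponent model spec in the same format the environment constructs.
model_name = "NousResearch/Hermes-4-Qwen3-14B-1-e3"
base_url = "http://localhost:9004/v1"  # placeholder endpoint
api_key = "x"                          # placeholder key
opponent_spec = f"openai:{model_name}@{base_url}#{api_key}"
# -> "openai:NousResearch/Hermes-4-Qwen3-14B-1-e3@http://localhost:9004/v1#x"
```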
218  environments/game_environments/diplomacy_environment/diplomacy_local_server.py (Executable file)
@@ -0,0 +1,218 @@
#!/usr/bin/env python3
"""
Local test server for the minimal Diplomacy environment.

This script runs the full AI_Diplomacy game with real OpenAI models
to test the AtroposClient proxy integration.
"""

import asyncio
import logging
import os

from dotenv import load_dotenv

from atroposlib.envs.base import APIServerConfig, EvalHandlingEnum
from environments.game_environments.diplomacy_environment.diplomacy_env_minimal import (
    DiplomacyEnvMinimal,
    DiplomacyEnvMinimalConfig,
)

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def main():
    """Run Diplomacy games for testing the minimal environment."""
    logger.info("Starting Diplomacy minimal environment local test runner")

    # Check for OpenRouter API key
    if not os.getenv("OPENROUTER_API_KEY"):
        logger.error(
            "OPENROUTER_API_KEY not found. Please set it in your environment or .env file"
        )
        return

    # Configure environment - using OpenRouter model
    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
    openrouter_model = (
        f"openai:openai/gpt-oss-120b@https://openrouter.ai/api/v1#{openrouter_api_key}"
    )

    # Create list of opponent models (6 powers besides training power)
    opponent_models = [openrouter_model] * 6

    env_config = DiplomacyEnvMinimalConfig(
        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        group_size=2,  # Run 2 parallel games
        use_wandb=False,
        wandb_name="diplomacy_minimal_local_test",
        max_num_workers=1,
        rollout_server_url="http://localhost:8000",
        total_steps=1,
        batch_size=1,
        steps_per_eval=0,
        max_token_length=4096,
        inference_weight=1.0,
        data_path_to_save_groups=None,
        eval_handling=EvalHandlingEnum.NONE,
        eval_limit_ratio=0.0,
        max_game_turns=5,  # Short games for testing
        training_power="FRANCE",  # Which power we're training
        include_messages=True,  # Include messages for debugging
        eval_episodes=0,
        start_diplomacy_server=True,  # Let the env start the server
        save_game_logs=True,
        game_logs_dir="./test_game_logs",
        opponent_models=opponent_models,  # Use OpenRouter for all opponents
    )

    # Configure server - using 4 servers to match SLURM setup
    # For local testing, we'll simulate this with the same OpenRouter endpoint
    server_configs = [
        APIServerConfig(
            model_name="openai/gpt-oss-120b",  # Using the OpenRouter model
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
            num_requests_for_eval=0,
        ),
        APIServerConfig(
            model_name="openai/gpt-oss-120b",
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
            num_requests_for_eval=0,
        ),
        APIServerConfig(
            model_name="openai/gpt-oss-120b",
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
            num_requests_for_eval=0,
        ),
        APIServerConfig(
            model_name="openai/gpt-oss-120b",
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
            num_requests_for_eval=0,
        ),
    ]

    logger.info("Using OpenRouter openai/gpt-oss-120b for Diplomacy test")
    logger.debug(f"Env Config: {env_config}")
    logger.debug(f"Server Configs: {server_configs}")

    try:
        env = DiplomacyEnvMinimal(
            config=env_config,
            server_configs=server_configs,
            slurm=False,
            testing=False,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize DiplomacyEnvMinimal: {e}")
        return

    logger.info("Running test games")
    try:
        await env.setup()

        # Get number of episodes from command line or default
        import sys

        num_episodes = int(sys.argv[1]) if len(sys.argv) > 1 else 3

        # Track statistics
        episode_results = []

        for episode_num in range(num_episodes):
            logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====")

            item = await env.get_next_item()
            logger.info(f"Game ID: {item['game_id']}, Seed: {item['seed']}")

            # Collect trajectories (will run group_size parallel games)
            scored_data_group, _ = await env.collect_trajectories(item)

            if scored_data_group and scored_data_group["scores"]:
                avg_score = sum(scored_data_group["scores"]) / len(
                    scored_data_group["scores"]
                )
                logger.info(
                    f"Collected {len(scored_data_group['scores'])} trajectories with average score: {avg_score:.2f}"
                )

                # Get game outcomes from buffer
                if env.game_outcomes_buffer:
                    latest_outcomes = env.game_outcomes_buffer[-env.config.group_size :]
                    for i, outcome in enumerate(latest_outcomes):
                        logger.info(
                            f"  Game {i}: Score={outcome['score']:.2f}, "
                            f"Winner={outcome['winner']}, "
                            f"Turns={outcome['turns']}, "
                            f"Centers={outcome['final_centers'].get(env.config.training_power, 0)}"
                        )

                episode_results.append(
                    {
                        "episode": episode_num + 1,
                        "score": avg_score,
                        "outcomes": latest_outcomes if env.game_outcomes_buffer else [],
                    }
                )
            else:
                logger.error("Failed to collect trajectory")
                episode_results.append(
                    {
                        "episode": episode_num + 1,
                        "score": 0.0,
                        "outcomes": [],
                    }
                )

        # Print overall statistics
        logger.info("\n" + "=" * 60)
        logger.info("OVERALL RESULTS SUMMARY")
        logger.info("=" * 60)
        logger.info(f"Total episodes: {num_episodes}")
        logger.info(f"Group size: {env.config.group_size} games per episode")
        logger.info(f"Training power: {env.config.training_power}")

        # Calculate statistics
        if episode_results:
            avg_score = sum(ep["score"] for ep in episode_results) / len(
                episode_results
            )
            logger.info(f"\nAverage trajectory score: {avg_score:.2f}")

            # Count wins
            total_games = 0
            wins = 0
            for ep in episode_results:
                for outcome in ep["outcomes"]:
                    total_games += 1
                    if outcome["winner"] == env.config.training_power:
                        wins += 1

            if total_games > 0:
                logger.info(
                    f"Win rate: {wins}/{total_games} ({100*wins/total_games:.1f}%)"
                )

                # Average supply centers
                total_centers = sum(
                    outcome["final_centers"].get(env.config.training_power, 0)
                    for ep in episode_results
                    for outcome in ep["outcomes"]
                )
                avg_centers = total_centers / total_games
                logger.info(f"Average final supply centers: {avg_centers:.1f}")

        logger.info("\nTest completed successfully!")

    except Exception as e:
        logger.exception(f"An error occurred during test: {e}")


if __name__ == "__main__":
    asyncio.run(main())
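Since the runner above reads an optional episode count from `sys.argv`, invoking it as `python diplomacy_local_server.py 5` should run five episodes (with `OPENROUTER_API_KEY` set in the environment or a `.env` file); with no argument it defaults to 3.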
environments/game_environments/diplomacy_environment/queue_manager.py (new file)

@@ -0,0 +1,117 @@
"""
Queue Manager for Diplomacy Environment

Manages request/response queues between AtroposClient proxies and the environment.
Each game gets its own queue pair for isolation.
"""

import asyncio
import logging
from dataclasses import dataclass
from typing import Dict, Optional

logger = logging.getLogger(__name__)


@dataclass
class QueuePair:
    """A pair of queues for a single game."""

    game_id: str
    request_queue: asyncio.Queue
    response_queue: asyncio.Queue


@dataclass
class PolicyRequest:
    """Request from proxy to environment for policy sampling."""

    request_id: str
    game_id: str
    power: str
    phase: str
    prompt: str
    temperature: float
    trajectory: list


@dataclass
class PolicyResponse:
    """Response from environment back to proxy."""

    request_id: str
    response: str
    metadata: dict


class QueueManager:
    """Manages queues for all parallel games."""

    def __init__(self):
        self.queue_pairs: Dict[str, QueuePair] = {}
        self._lock = asyncio.Lock()

    async def create_game_queues(self, game_id: str) -> QueuePair:
        """Create a new queue pair for a game."""
        async with self._lock:
            if game_id in self.queue_pairs:
                logger.warning(f"Queue pair already exists for game {game_id}")
                return self.queue_pairs[game_id]

            queue_pair = QueuePair(
                game_id=game_id,
                request_queue=asyncio.Queue(),
                response_queue=asyncio.Queue(),
            )
            self.queue_pairs[game_id] = queue_pair
            logger.info(f"Created queue pair for game {game_id}")
            return queue_pair

    def get_queue_pair(self, game_id: str) -> Optional[QueuePair]:
        """Get queue pair for a game."""
        return self.queue_pairs.get(game_id)

    async def remove_game_queues(self, game_id: str):
        """Remove queues for a completed game."""
        async with self._lock:
            if game_id in self.queue_pairs:
                del self.queue_pairs[game_id]
                logger.info(f"Removed queue pair for game {game_id}")

    def get_all_request_queues(self) -> Dict[str, asyncio.Queue]:
        """Get all request queues for polling."""
        return {
            game_id: pair.request_queue for game_id, pair in self.queue_pairs.items()
        }

    async def put_request(self, game_id: str, request: PolicyRequest):
        """Put a request on the appropriate queue."""
        queue_pair = self.get_queue_pair(game_id)
        if queue_pair:
            await queue_pair.request_queue.put(request)
        else:
            raise ValueError(f"No queue pair found for game {game_id}")

    async def get_response(self, game_id: str) -> PolicyResponse:
        """Get a response from the appropriate queue."""
        queue_pair = self.get_queue_pair(game_id)
        if queue_pair:
            return await queue_pair.response_queue.get()
        else:
            raise ValueError(f"No queue pair found for game {game_id}")

    async def put_response(self, game_id: str, response: PolicyResponse):
        """Put a response on the appropriate queue."""
        queue_pair = self.get_queue_pair(game_id)
        if queue_pair:
            await queue_pair.response_queue.put(response)
        else:
            raise ValueError(f"No queue pair found for game {game_id}")


_queue_manager = QueueManager()


def get_queue_manager() -> QueueManager:
    """Get the global queue manager instance."""
    return _queue_manager
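As a quick usage sketch, here is a hypothetical single-coroutine round trip through the `QueueManager`; in the real environment the request and response sides live in different tasks (the proxy client on one side, `_poll_request_queues` on the other), and the ids, phase, and order string below are made up:

```python
import asyncio

from environments.game_environments.diplomacy_environment.queue_manager import (
    PolicyRequest,
    PolicyResponse,
    QueueManager,
)


async def demo():
    qm = QueueManager()
    await qm.create_game_queues("game-demo-0")

    await qm.put_request(
        "game-demo-0",
        PolicyRequest(
            request_id="req-1", game_id="game-demo-0", power="FRANCE",
            phase="S1901M", prompt="Submit your orders.", temperature=0.7,
            trajectory=[],
        ),
    )

    # Normally the environment's polling task consumes the request and posts a
    # PolicyResponse; here we answer it inline to keep the sketch self-contained.
    pending = await qm.get_queue_pair("game-demo-0").request_queue.get()
    await qm.put_response(
        "game-demo-0",
        PolicyResponse(request_id=pending.request_id, response="A PAR H", metadata={}),
    )

    print((await qm.get_response("game-demo-0")).response)  # "A PAR H"
    await qm.remove_game_queues("game-demo-0")


asyncio.run(demo())
```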
environments/game_environments/diplomacy_environment/requirements.txt (new file)

@@ -0,0 +1,8 @@
# Requirements for minimal Diplomacy environment
# Only includes dependencies specific to AI_Diplomacy that aren't in main project

# AI_Diplomacy specific dependencies
tornado>=6.0
ujson>=5.0
websockets>=10.0
coloredlogs>=15.0.1
@@ -49,7 +49,7 @@ async def main():
     # Configure server - using same model as blackjack example
     server_configs = [
         APIServerConfig(
-            model_name="gpt-4.1-nano",
+            model_name="gpt-4.1",
             base_url="https://api.openai.com/v1",
             api_key=os.getenv("OPENAI_API_KEY"),
             num_requests_for_eval=0,
@@ -99,7 +99,6 @@ async def main():
             item = await env.get_next_item()
             challenge_name = item["challenge_name"]
             challenge_counts[challenge_name] += 1
-
             logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====")
             logger.info(f"Using game: {item}")
@@ -44,7 +44,6 @@ class TextWorldChallengeRegistry:
     def __init__(self, seed: Optional[int] = None):
         self._challenges = self.CHALLENGES.copy()
         self.rng = random.Random(seed)
-
         # Cache for all possible combinations
         self._all_combinations = None
        self._combination_index = 0
@@ -1 +1 @@
-Subproject commit 0699e2f5074379bb19924b7f64221ccfc8cef6ef
+Subproject commit 9e79fc84b614b69c04325854335e308cd61d731d