mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Diplomacy trainer env (#227)
* minimal implementation, simplified challenge registry * need game save logic * fixed challenge gen, works with local test * updated challenge gen with wider ranges, working with local script * runs working correctly, wandb stats look ok * linting * Add diplomacy environment with AI_Diplomacy submodule - Add diplomacy_env_minimal.py for diplomacy game environment - Add atropos_client_minimal.py for client interface - Add diplomacy_local_server.py for local game server - Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy - Fix import ordering and remove unused imports * test file working, moving to cluster to test training * updated gitignore * removed logs * minor fixes, training running now * readded proxy reg and queue system * linting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * queue gameid bug, refactored * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleaned up configs & allowed for openrouter models to be easily used * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * linting * Remove duplicate dependencies from diplomacy requirements.txt Only keep AI_Diplomacy-specific dependencies that aren't already in the main project --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
4fe67e698d
commit
46f0602227
13 changed files with 1317 additions and 4 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -205,3 +205,6 @@ environments/community/word_hunt/word_hunt_rollouts*.html
|
|||
*.z8
|
||||
*.ni
|
||||
*.z5
|
||||
|
||||
# Diplomacy artefacts
|
||||
environments/game_environments/diplomacy_environment/logs/
|
||||
|
|
|
|||
3
.gitmodules
vendored
3
.gitmodules
vendored
|
|
@ -4,3 +4,6 @@
|
|||
[submodule "environments/reasoning_gym_environment/reasoning-gym"]
|
||||
path = environments/reasoning_gym_environment/reasoning-gym
|
||||
url = https://github.com/open-thought/reasoning-gym
|
||||
[submodule "environments/game_environments/diplomacy_environment/AI_Diplomacy"]
|
||||
path = environments/game_environments/diplomacy_environment/AI_Diplomacy
|
||||
url = https://github.com/GoodStartLabs/AI_Diplomacy.git
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 70d4ae2fe029ed38bf823ce70f5d45a7bcc7afbc
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
# Minimal Diplomacy Environment
|
||||
|
||||
A simplified Diplomacy RL training environment for Atropos that integrates with AI_Diplomacy.
|
||||
|
||||
## Overview
|
||||
|
||||
This minimal implementation provides:
|
||||
- Basic game integration via AI_Diplomacy submodule
|
||||
- Parallel rollouts with configurable group_size
|
||||
- LLM request interception through AtroposClient proxy
|
||||
- Simple supply center based scoring
|
||||
- No complex features (no GRPO, memory systems, or advanced scoring)
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Atropos Policy Server
|
||||
↓
|
||||
AtroposClientMinimal (proxy)
|
||||
↓
|
||||
AI_Diplomacy Game Engine
|
||||
↓
|
||||
Game Execution
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Install dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
cd AI_Diplomacy
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
2. Start your Atropos policy server on port 8000
|
||||
|
||||
3. Run the environment:
|
||||
```bash
|
||||
python diplomacy_env_minimal.py serve
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Key settings in `DiplomacyEnvMinimalConfig`:
|
||||
- `max_game_turns`: Number of game turns (default: 10)
|
||||
- `training_power`: Which power the RL agent controls (default: "FRANCE")
|
||||
- `group_size`: Number of parallel games per trajectory (default: 4)
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Parallel Rollouts**: Each training step runs `group_size` games with the same initial seed
|
||||
2. **LLM Interception**: AtroposClientMinimal intercepts all LLM calls from AI_Diplomacy
|
||||
3. **Trajectory Collection**: Game interactions are collected and scored
|
||||
4. **Best Selection**: The highest scoring trajectory is returned for training
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""
|
||||
Minimal Diplomacy Environment for Atropos RL Training
|
||||
"""
|
||||
|
||||
from .diplomacy_env_minimal import DiplomacyEnvMinimal, DiplomacyEnvMinimalConfig
|
||||
|
||||
__all__ = ["DiplomacyEnvMinimal", "DiplomacyEnvMinimalConfig"]
|
||||
|
|
@ -0,0 +1,268 @@
|
|||
"""
|
||||
|
||||
This is a queue-based proxy that:
|
||||
- Intercepts LLM requests from AI_Diplomacy
|
||||
- Puts them on a queue for the environment to process
|
||||
- Waits for responses from the environment
|
||||
- Returns responses to AI_Diplomacy
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy import (
|
||||
clients,
|
||||
)
|
||||
from environments.game_environments.diplomacy_environment.queue_manager import (
|
||||
PolicyRequest,
|
||||
QueueManager,
|
||||
get_queue_manager,
|
||||
)
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))
|
||||
|
||||
from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy.clients import ( # noqa: E402
|
||||
BaseModelClient,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
current_game_context = contextvars.ContextVar("current_game_id", default=None)
|
||||
_game_interactions = {}
|
||||
|
||||
|
||||
class AtroposClientMinimal(BaseModelClient):
    """
    Queue-based proxy client that forwards LLM requests through queues.

    AI_Diplomacy calls :meth:`generate_response` for every LLM interaction;
    the request is placed on a per-game queue, the environment answers it,
    and the exchange is recorded (per-client and per-game) for trajectory
    collection.
    """

    def __init__(
        self,
        model_name: str,
        queue_manager: Optional[QueueManager] = None,
    ):
        """Bind this client to the game id in ``current_game_context``.

        Args:
            model_name: Model identifier passed through to BaseModelClient.
            queue_manager: Queue manager to use; falls back to the global one.

        Raises:
            ValueError: if no game id is set in the context var — the client
                cannot route requests without one.
        """
        super().__init__(model_name)
        self.game_id = current_game_context.get()
        if not self.game_id:
            raise ValueError("AtroposClientMinimal created without game context set")

        self.queue_manager = queue_manager or get_queue_manager()

        # Per-client interaction log plus last-seen power/phase metadata,
        # updated opportunistically as prompts are parsed.
        self.interactions: List[Dict] = []
        self.current_power: Optional[str] = None
        self.current_phase: Optional[str] = None

        logger.info(
            f"Initialized AtroposClientMinimal for {model_name} in game {self.game_id}"
        )

    async def generate_response(self, prompt: str, temperature: float = 0.0) -> str:
        """
        Put request on queue and wait for response from environment.
        This is the main method AI_Diplomacy calls for all LLM interactions.

        On timeout or any other error, returns a canned fallback response
        instead of raising, so the game can continue.
        """
        task_type = self._infer_task_type(prompt)
        power = self._extract_power(prompt)
        phase = self._extract_phase(prompt)

        if power:
            self.current_power = power
        if phase:
            self.current_phase = phase

        logger.debug(f"Generating response for {self.current_power}: {task_type}")

        try:
            request_id = str(uuid.uuid4())
            request = PolicyRequest(
                request_id=request_id,
                game_id=self.game_id,
                power=self.current_power or "UNKNOWN",
                phase=self.current_phase or "UNKNOWN",
                prompt=prompt,
                temperature=temperature,
                trajectory=self.interactions.copy(),
            )

            await self.queue_manager.put_request(self.game_id, request)
            logger.debug(f"Put request {request_id} on queue for game {self.game_id}")

            response = await self.queue_manager.get_response(self.game_id)

            # Responses are taken in queue order; a mismatch means requests
            # interleaved on this game's queue. Logged but the response is
            # still used (original behaviour preserved).
            if response.request_id != request_id:
                logger.warning(
                    f"Response ID mismatch: expected {request_id}, got {response.request_id}"
                )

            response_text = response.response

            # Track interaction
            interaction = {
                "power": self.current_power,
                "phase": self.current_phase,
                "task_type": task_type,
                "prompt": prompt,
                "response": response_text,
                "metadata": response.metadata,  # Store any additional info from environment
            }
            self.interactions.append(interaction)

            # Also record in the module-level registry so the environment can
            # collect a game's interactions across client instances.
            _game_interactions.setdefault(self.game_id, []).append(interaction)

            return response_text

        except asyncio.TimeoutError:
            logger.error("Timeout waiting for response from environment")
            return self._generate_fallback_response(prompt)
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return self._generate_fallback_response(prompt)

    def _infer_task_type(self, prompt: str) -> str:
        """Infer the type of task from the prompt (keyword heuristic)."""
        prompt_lower = prompt.lower()

        if "orders" in prompt_lower or "submit" in prompt_lower:
            return "orders"
        elif "message" in prompt_lower or "negotiate" in prompt_lower:
            return "negotiation"
        elif "plan" in prompt_lower or "strategy" in prompt_lower:
            return "planning"
        else:
            return "general"

    def _extract_power(self, prompt: str) -> Optional[str]:
        """Extract power name from prompt if mentioned (first match wins)."""
        # Uppercase once instead of once per candidate power.
        prompt_upper = prompt.upper()
        for power in [
            "AUSTRIA",
            "ENGLAND",
            "FRANCE",
            "GERMANY",
            "ITALY",
            "RUSSIA",
            "TURKEY",
        ]:
            if power in prompt_upper:
                return power
        return None

    def _extract_phase(self, prompt: str) -> Optional[str]:
        """Extract game phase from prompt if mentioned.

        Recognizes the short form (e.g. "S1901M") and the verbose form
        (e.g. "Spring 1901"); returns None when neither appears.
        """
        import re  # local import: module does not import re at top level

        phase_match = re.search(r"[SF]\d{4}[MRB]", prompt)
        if phase_match:
            return phase_match.group()

        verbose_match = re.search(r"(Spring|Fall) \d{4}", prompt)
        if verbose_match:
            return verbose_match.group()

        return None

    def _generate_fallback_response(self, prompt: str) -> str:
        """Generate a simple fallback response if there's an issue."""
        task_type = self._infer_task_type(prompt)

        if task_type == "orders":
            return json.dumps(
                {
                    "orders": {},
                    "explanations": {"general": "Fallback - no server connected"},
                }
            )
        elif task_type == "negotiation":
            return json.dumps(
                {
                    "messages": [],
                    "explanations": {"general": "Fallback - no server connected"},
                }
            )
        else:
            return "Fallback response - server not available"

    def get_interactions(self) -> List[Dict]:
        """Get all tracked interactions for trajectory collection."""
        return self.interactions

    def clear_interactions(self):
        """Clear tracked interactions for a new game."""
        self.interactions = []
        self.current_power = None
        self.current_phase = None
|
||||
def get_game_interactions(game_id: str) -> List[Dict]:
    """Return the interaction records tracked for *game_id* (empty if none)."""
    try:
        return _game_interactions[game_id]
    except KeyError:
        return []
|
||||
def clear_game_interactions(game_id: str):
    """Drop all tracked interactions for *game_id*; no-op if unknown."""
    _game_interactions.pop(game_id, None)
|
||||
def register_atropos_models_globally(queue_manager: Optional[QueueManager] = None):
    """
    Register AtroposClientMinimal with AI_Diplomacy's model loading system globally.
    This should be called ONCE during environment setup.

    Monkey-patches ``clients.load_model_client`` so that model ids prefixed
    with ``atropos-`` are served by AtroposClientMinimal while everything
    else is delegated to the original loader. Idempotent: guarded by an
    ``_atropos_registered`` flag stashed on the clients module.

    Args:
        queue_manager: Optional queue manager (uses global if not provided)
    """
    if hasattr(clients, "_atropos_registered"):
        logger.info("AtroposClientMinimal already registered globally")
        return

    # Stash the original loader and the chosen queue manager on the module
    # so the patched function can reach them later.
    clients._original_load_model_client = clients.load_model_client
    clients._atropos_queue_manager = queue_manager or get_queue_manager()

    def load_model_client_with_atropos(
        model_id: str, prompts_dir: Optional[str] = None
    ) -> BaseModelClient:
        if model_id.startswith("atropos-"):
            logger.info(f"Creating context-aware AtroposClientMinimal for {model_id}")
            return AtroposClientMinimal(model_id, clients._atropos_queue_manager)
        logger.info(f"Falling back to original loader for {model_id}")
        return clients._original_load_model_client(model_id, prompts_dir)

    clients.load_model_client = load_model_client_with_atropos
    clients._atropos_registered = True

    logger.info("Registered AtroposClientMinimal globally with AI_Diplomacy")
|
||||
if __name__ == "__main__":
    # Lightweight manual smoke test. Requires the environment side to be
    # polling the request queues; otherwise each call ends in the fallback
    # path after the queue wait fails.
    async def test_client():
        # The client resolves its game id from the context var, so the
        # context must be set BEFORE construction. (The previous version
        # omitted this — the constructor would raise ValueError — and also
        # passed a plain dict where a QueueManager was expected, and called
        # a nonexistent client.close().)
        token = current_game_context.set("local-test-game")
        try:
            client = AtroposClientMinimal("atropos-test")

            test_prompts = [
                "You are FRANCE. What are your orders for Spring 1901?",
                "Send a message to ENGLAND about cooperation.",
                "What is your strategic plan?",
            ]

            for prompt in test_prompts:
                print(f"\nPrompt: {prompt[:50]}...")
                response = await client.generate_response(prompt)
                print(f"Response: {response[:100]}...")

            print(f"\nTracked {len(client.get_interactions())} interactions")
        finally:
            current_game_context.reset(token)

    asyncio.run(test_client())
|
|
@ -0,0 +1,636 @@
|
|||
"""
|
||||
Minimal Diplomacy Environment for Atropos
|
||||
|
||||
A simplified implementation focusing on:
|
||||
- Basic game integration with AI_Diplomacy
|
||||
- Parallel rollouts with group_size
|
||||
- LLM proxy interception via AtroposClient
|
||||
- Simple supply center based scoring
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
ScoredDataItem,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
from environments.game_environments.diplomacy_environment.atropos_client_minimal import (
|
||||
clear_game_interactions,
|
||||
current_game_context,
|
||||
get_game_interactions,
|
||||
register_atropos_models_globally,
|
||||
)
|
||||
from environments.game_environments.diplomacy_environment.queue_manager import (
|
||||
PolicyRequest,
|
||||
PolicyResponse,
|
||||
get_queue_manager,
|
||||
)
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))
|
||||
import lm_game # noqa: E402
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The seven great powers, in AI_Diplomacy's canonical order.
POWERS = ["AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"]

# Supply centers each power owns at game start: Russia begins with four,
# every other power with three. Used as the baseline for scoring.
STARTING_SUPPLY_CENTERS = {power: 3 for power in POWERS}
STARTING_SUPPLY_CENTERS["RUSSIA"] = 4
|
||||
|
||||
class DiplomacyEnvMinimalConfig(BaseEnvConfig):
    """Configuration for the minimal Diplomacy environment."""

    env_name: str = "diplomacy_minimal"

    # Game settings (minimal implementation, no randomisation stuff)
    max_game_turns: int = 10  # game length; mapped to --max_year = 1900 + turns
    training_power: str = "FRANCE"  # power the RL policy controls
    total_steps: int = 10

    # Scoring (see _calculate_score):
    supply_center_weight: float = 1.0  # weight per supply center gained/lost
    survival_bonus: float = 0.1  # added once per turn played
    win_bonus: float = 5.0  # flat bonus when training_power wins

    # Process management
    diplomacy_server_port: int = 8432
    start_diplomacy_server: bool = True  # spawn diplomacy.server.run in setup()

    # Logging
    save_game_logs: bool = True
    game_logs_dir: str = "./game_logs"  # lm_game output directories go here

    # Evaluation
    eval_episodes: int = 10

    # Opponent models (None = use SGLang endpoints from server_configs)
    opponent_models: Optional[List[str]] = None
||||
|
||||
class DiplomacyEnvMinimal(BaseEnv):
|
||||
name = "diplomacy_minimal"
|
||||
env_config_cls = DiplomacyEnvMinimalConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DiplomacyEnvMinimalConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: DiplomacyEnvMinimalConfig = config
|
||||
self.game_server_process: Optional[subprocess.Popen] = None
|
||||
self.game_outcomes_buffer: List[Dict] = []
|
||||
self.eval_metrics_custom: List[Tuple[str, float]] = []
|
||||
|
||||
self.queue_manager = get_queue_manager()
|
||||
self.active_games: Dict[str, Dict] = {}
|
||||
|
||||
if config.save_game_logs:
|
||||
Path(config.game_logs_dir).mkdir(exist_ok=True)
|
||||
|
||||
self.system_prompt = (
|
||||
f"You are playing Diplomacy as {config.training_power}. "
|
||||
"Analyze the game state and respond with your strategy and orders."
|
||||
)
|
||||
|
||||
if config.opponent_models:
|
||||
self.opponent_models = config.opponent_models
|
||||
else:
|
||||
self.opponent_models = []
|
||||
for server_config in server_configs:
|
||||
model_spec = f"openai:{server_config.model_name}@{server_config.base_url}#{server_config.api_key}"
|
||||
self.opponent_models.append(model_spec)
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[DiplomacyEnvMinimalConfig, List[APIServerConfig]]:
|
||||
"""Initialize default configuration."""
|
||||
env_config = DiplomacyEnvMinimalConfig(
|
||||
tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
|
||||
group_size=4,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
max_token_length=4096,
|
||||
wandb_name=cls.name,
|
||||
steps_per_eval=20,
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
|
||||
base_url="http://localhost:9004/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=128,
|
||||
),
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
|
||||
base_url="http://localhost:9005/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=128,
|
||||
),
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
|
||||
base_url="http://localhost:9006/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=128,
|
||||
),
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
|
||||
base_url="http://localhost:9007/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=128,
|
||||
),
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Set up the environment."""
|
||||
logger.info(f"Setting up {self.name} environment")
|
||||
|
||||
register_atropos_models_globally(self.queue_manager)
|
||||
|
||||
if self.config.start_diplomacy_server:
|
||||
await self._start_diplomacy_server()
|
||||
|
||||
asyncio.create_task(self._poll_request_queues())
|
||||
|
||||
async def _poll_request_queues(self):
|
||||
"""Poll request queues and handle policy requests."""
|
||||
while True:
|
||||
try:
|
||||
for game_id in list(self.active_games.keys()):
|
||||
queue_pair = self.queue_manager.get_queue_pair(game_id)
|
||||
if not queue_pair:
|
||||
continue
|
||||
|
||||
try:
|
||||
request = queue_pair.request_queue.get_nowait()
|
||||
asyncio.create_task(self._handle_policy_request(request))
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in queue polling: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _handle_policy_request(self, request: PolicyRequest):
|
||||
"""Handle a single policy request by sampling from SGLang."""
|
||||
try:
|
||||
logger.info(
|
||||
f"Handling request {request.request_id} for {request.power} in game {request.game_id}"
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": self.system_prompt}]
|
||||
|
||||
for interaction in request.trajectory:
|
||||
messages.append({"role": "user", "content": interaction["prompt"]})
|
||||
messages.append(
|
||||
{"role": "assistant", "content": interaction["response"]}
|
||||
)
|
||||
|
||||
messages.append({"role": "user", "content": request.prompt})
|
||||
|
||||
async with self.server.dedicated_server() as server:
|
||||
response = await server.chat_completion(
|
||||
messages=messages,
|
||||
n=1,
|
||||
temperature=request.temperature,
|
||||
max_tokens=2000,
|
||||
)
|
||||
|
||||
response_text = response.choices[0].message.content.strip()
|
||||
|
||||
policy_response = PolicyResponse(
|
||||
request_id=request.request_id,
|
||||
response=response_text,
|
||||
metadata={
|
||||
"power": request.power,
|
||||
"phase": request.phase,
|
||||
},
|
||||
)
|
||||
|
||||
await self.queue_manager.put_response(request.game_id, policy_response)
|
||||
logger.debug(f"Sent response for request {request.request_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling policy request: {e}")
|
||||
error_response = PolicyResponse(
|
||||
request_id=request.request_id,
|
||||
response="Error: Failed to generate response",
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
await self.queue_manager.put_response(request.game_id, error_response)
|
||||
|
||||
async def _start_diplomacy_server(self):
|
||||
"""Start the AI_Diplomacy game server."""
|
||||
try:
|
||||
logger.info(
|
||||
f"Starting Diplomacy server on port {self.config.diplomacy_server_port}"
|
||||
)
|
||||
self.game_server_process = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"diplomacy.server.run",
|
||||
"--port",
|
||||
str(self.config.diplomacy_server_port),
|
||||
],
|
||||
cwd=os.path.join(os.path.dirname(__file__), "AI_Diplomacy"),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
# Give server time to start
|
||||
await asyncio.sleep(3)
|
||||
logger.info(
|
||||
f"Diplomacy server started at http://localhost:{self.config.diplomacy_server_port}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start Diplomacy server: {e}")
|
||||
|
||||
    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[ScoredDataGroup, List[Item]]:
        """
        Run parallel Diplomacy games and collect all trajectories.

        This implements the key RL training pattern:
        1. Run group_size parallel games with the same seed
        2. Each game explores different action sequences
        3. Score each trajectory based on game outcome
        4. Return all trajectories as a ScoredDataGroup for training
        """
        logger.warning(
            f"[DiplomacyEnvMinimal] collect_trajectories called with item: {item}"
        )
        # Fall back to fresh ids/seeds when the item omits them.
        base_game_id = item.get("game_id", f"game-{int(time.time())}")
        seed = item.get("seed", random.randint(0, 1_000_000))

        logger.info(
            f"Starting {self.config.group_size} parallel games with seed {seed}"
        )

        # Run parallel games w/ same seed
        game_tasks = []
        for i in range(self.config.group_size):
            game_id = f"{base_game_id}-{i}"
            task = self._run_single_game(game_id, seed, trajectory_id=i)
            game_tasks.append(task)

        # return_exceptions=True so one crashed game doesn't void the group.
        results = await asyncio.gather(*game_tasks, return_exceptions=True)

        scored_items = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                logger.error(f"Game {i} failed: {result}")
                continue

            # _run_single_game returns (ScoredDataItem | None, None).
            if result and result[0]:
                scored_items.append(result[0])

        logger.warning(
            f"[DiplomacyEnvMinimal] Collected {len(scored_items)} scored items"
        )
        if not scored_items:
            logger.error("No valid trajectories collected")
            # Return an empty group rather than raising so training proceeds.
            return (
                ScoredDataGroup(
                    tokens=[],
                    masks=[],
                    scores=[],
                    messages=[],
                    advantages=None,
                    ref_logprobs=None,
                    group_overrides={},
                    overrides=None,
                    images=None,
                ),
                [],
            )

        sdg = ScoredDataGroup(
            tokens=[],
            masks=[],
            scores=[],
            messages=[],
            advantages=None,
            ref_logprobs=None,
            group_overrides={},
            overrides=None,
            images=None,
        )

        # Flatten per-game ScoredDataItems into the group's parallel lists.
        for scored_item in scored_items:
            sdg["tokens"].append(scored_item["tokens"])
            sdg["masks"].append(scored_item["masks"])
            sdg["scores"].append(scored_item["scores"])
            if self.config.include_messages and scored_item.get("messages"):
                sdg["messages"].append(scored_item["messages"])

        logger.info(f"Collected {len(scored_items)} trajectories")
        logger.warning(
            f"[DiplomacyEnvMinimal] Returning ScoredDataGroup with {len(sdg['tokens'])} "
            f"tokens, {len(sdg['scores'])} scores"
        )
        logger.warning(
            f"[DiplomacyEnvMinimal] First few scores: {sdg['scores'][:5] if sdg['scores'] else 'None'}"
        )

        # Tear down per-game bookkeeping and queues; failures are non-fatal.
        for i in range(self.config.group_size):
            game_id = f"{base_game_id}-{i}"
            if game_id in self.active_games:
                del self.active_games[game_id]
            try:
                await self.queue_manager.remove_game_queues(game_id)
            except Exception as e:
                logger.debug(f"Error cleaning up queues for {game_id}: {e}")

        return sdg, []
|
||||
    async def _run_single_game(
        self, game_id: str, seed: int, trajectory_id: int
    ) -> Tuple[Optional[ScoredDataItem], None]:
        """
        Run a single Diplomacy game and return scored trajectory.

        Creates the game's queues, binds the game id into the context var so
        AtroposClientMinimal instances attach to the right queues, runs the
        game, then tokenizes the training power's interactions into a
        ScoredDataItem. Returns (None, None) on any failure.
        """
        try:
            queue_pair = await self.queue_manager.create_game_queues(game_id)

            self.active_games[game_id] = {
                "queue_pair": queue_pair,
                "start_time": time.time(),
                "interactions": [],
            }

            # Scope the context var to this game; reset even on failure so
            # concurrent games never see a stale id.
            token = current_game_context.set(game_id)

            try:
                game_result = await self._run_diplomacy_game(
                    game_id, seed, trajectory_id
                )
            finally:
                current_game_context.reset(token)

            if not game_result:
                logger.error(f"Game {game_id} failed to complete")
                return None, None

            score = self._calculate_score(game_result, self.config.training_power)

            interactions = get_game_interactions(game_id)

            # Only the training power's turns become training data.
            training_interactions = [
                i for i in interactions if i.get("power") == self.config.training_power
            ]

            if training_interactions:
                messages = [{"role": "system", "content": self.system_prompt}]

                for interaction in training_interactions:
                    messages.append({"role": "user", "content": interaction["prompt"]})
                    messages.append(
                        {"role": "assistant", "content": interaction["response"]}
                    )

                logger.info(
                    f"Collected {len(training_interactions)} interactions for {self.config.training_power}"
                )
            else:
                logger.warning(
                    f"No interactions found for {self.config.training_power} in game {game_id}"
                )
                # Minimal placeholder conversation so tokenization still works.
                messages = [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": f"Playing Diplomacy game {game_id}"},
                    {
                        "role": "assistant",
                        "content": f"Game completed with score {score:.2f}",
                    },
                ]

            clear_game_interactions(game_id)

            tokenization_result = tokenize_for_trainer(
                tokenizer=self.tokenizer,
                chat=messages,
                train_on_all_assistant_turns=True,
            )

            scored_data_item = ScoredDataItem(
                messages=messages if self.config.include_messages else None,
                tokens=tokenization_result["tokens"],
                masks=tokenization_result["masks"],
                scores=score,
            )

            # Buffer outcome stats for wandb logging and eval win counting.
            self.game_outcomes_buffer.append(
                {
                    "game_id": game_id,
                    "score": score,
                    "winner": game_result.get("winner"),
                    "turns": game_result.get("turns_played", 0),
                    "final_centers": game_result.get("final_centers", {}),
                }
            )

            return scored_data_item, None

        except Exception as e:
            logger.error(f"Error in game {game_id}: {e}", exc_info=True)
            return None, None
|
||||
async def _run_diplomacy_game(
|
||||
self, game_id: str, seed: int, trajectory_id: int = 0
|
||||
) -> Optional[Dict]:
|
||||
"""Run a Diplomacy game using AI_Diplomacy's lm_game module."""
|
||||
|
||||
game_output_path = os.path.join(self.config.game_logs_dir, f"{game_id}.json")
|
||||
|
||||
models = []
|
||||
opponent_idx = 0
|
||||
|
||||
for power in POWERS:
|
||||
if power == self.config.training_power:
|
||||
models.append("atropos-training-policy")
|
||||
else:
|
||||
models.append(
|
||||
self.opponent_models[opponent_idx % len(self.opponent_models)]
|
||||
)
|
||||
opponent_idx += 1
|
||||
|
||||
original_argv = sys.argv
|
||||
|
||||
try:
|
||||
sys.argv = [
|
||||
"lm_game.py",
|
||||
"--models",
|
||||
",".join(models),
|
||||
"--max_year",
|
||||
str(1900 + self.config.max_game_turns),
|
||||
"--output",
|
||||
game_output_path,
|
||||
"--seed",
|
||||
str(seed + trajectory_id),
|
||||
"--num_negotiation_rounds",
|
||||
"0",
|
||||
]
|
||||
|
||||
await lm_game.main()
|
||||
|
||||
actual_game_file = os.path.join(game_output_path, "lmvsgame.json")
|
||||
if os.path.exists(actual_game_file):
|
||||
with open(actual_game_file, "r") as f:
|
||||
saved_game = json.load(f)
|
||||
|
||||
phases = saved_game.get("phases", [])
|
||||
last_phase = phases[-1] if phases else {}
|
||||
|
||||
result = {
|
||||
"winner": saved_game.get("winner"),
|
||||
"turns_played": len(phases),
|
||||
"final_centers": {},
|
||||
}
|
||||
|
||||
for power in POWERS:
|
||||
centers = (
|
||||
last_phase.get("state", {}).get("centers", {}).get(power, [])
|
||||
)
|
||||
result["final_centers"][power] = len(centers)
|
||||
|
||||
return result
|
||||
else:
|
||||
logger.error(f"Game output not found: {actual_game_file}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
def _calculate_score(self, game_result: Dict, power: str) -> float:
|
||||
"""Calculate score for a power based on game outcome."""
|
||||
start_centers = STARTING_SUPPLY_CENTERS[power]
|
||||
end_centers = game_result["final_centers"].get(power, 0)
|
||||
center_score = (end_centers - start_centers) * self.config.supply_center_weight
|
||||
|
||||
survival_score = self.config.survival_bonus * game_result["turns_played"]
|
||||
|
||||
win_score = self.config.win_bonus if game_result["winner"] == power else 0.0
|
||||
|
||||
total_score = center_score + survival_score + win_score
|
||||
|
||||
logger.info(
|
||||
f"{power} score: centers={center_score:.2f}, "
|
||||
f"survival={survival_score:.2f}, win={win_score:.2f}, "
|
||||
f"total={total_score:.2f}"
|
||||
)
|
||||
|
||||
return total_score
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""Generate configuration for the next game."""
|
||||
return {
|
||||
"game_id": f"game-{int(time.time())}-{random.randint(1000, 9999)}",
|
||||
"seed": random.randint(0, 1_000_000),
|
||||
}
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
    """Play ``eval_episodes`` held-out games and record summary metrics.

    Populates ``self.eval_metrics_custom`` with the average trajectory
    score, win rate, and the number of episodes that produced scored
    trajectories. Episodes whose collection fails are skipped.
    """
    num_episodes = self.config.eval_episodes
    logger.info(f"Starting evaluation with {num_episodes} episodes")

    completed_scores = []
    win_count = 0

    for _ in range(num_episodes):
        item = await self.get_next_item()
        item["is_eval"] = True

        group, _ = await self.collect_trajectories(item)
        if not (group and group["scores"]):
            continue

        scores = group["scores"]
        completed_scores.append(sum(scores) / len(scores))

        # The most recent buffered outcome corresponds to this episode.
        latest = self.game_outcomes_buffer[-1] if self.game_outcomes_buffer else None
        if latest and latest["winner"] == self.config.training_power:
            win_count += 1

    if completed_scores:
        mean_score = sum(completed_scores) / len(completed_scores)
        # Win rate is measured against the configured episode count, so
        # failed episodes count as losses.
        win_rate = win_count / num_episodes

        self.eval_metrics_custom = [
            (f"{self.name}_eval/avg_score", mean_score),
            (f"{self.name}_eval/win_rate", win_rate),
            (f"{self.name}_eval/num_completed", len(completed_scores)),
        ]

        logger.info(
            f"Evaluation completed: avg_score={mean_score:.2f}, win_rate={win_rate:.2%}"
        )
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
    """Aggregate buffered game outcomes and forward all metrics to wandb.

    Training statistics are averaged over the outcome buffer (which is
    cleared afterwards); any pending eval metrics are appended as well
    before delegating to the base class logger.
    """
    metrics = {} if wandb_metrics is None else wandb_metrics

    outcomes = self.game_outcomes_buffer
    if outcomes:
        n = len(outcomes)
        train_power = self.config.training_power
        wins = len([g for g in outcomes if g["winner"] == train_power])
        metrics.update(
            {
                f"{self.name}_train/avg_score": sum(g["score"] for g in outcomes) / n,
                f"{self.name}_train/win_rate": wins / n,
                f"{self.name}_train/avg_turns": sum(g["turns"] for g in outcomes) / n,
                f"{self.name}_train/num_games": n,
            }
        )

        # Outcomes have been consumed; start a fresh buffer for the next window.
        self.game_outcomes_buffer = []

    metrics.update(dict(self.eval_metrics_custom))
    self.eval_metrics_custom = []

    await super().wandb_log(metrics)
|
||||
|
||||
def __del__(self):
    """Clean up the spawned Diplomacy game server process on exit.

    Best-effort only: ``__del__`` may run during interpreter shutdown,
    where ``terminate``/``wait`` can fail (process already gone, module
    globals torn down). Exceptions here would only surface as noisy
    "unraisable" warnings, so they are deliberately swallowed.
    """
    proc = getattr(self, "game_server_process", None)
    if proc:
        try:
            proc.terminate()
            proc.wait()
        except Exception:
            # Nothing useful can be done at this point in object teardown.
            pass
|
||||
|
||||
|
||||
# Entry point: delegate to the environment's standard Atropos CLI runner.
if __name__ == "__main__":
    DiplomacyEnvMinimal.cli()
|
||||
218
environments/game_environments/diplomacy_environment/diplomacy_local_server.py
Executable file
218
environments/game_environments/diplomacy_environment/diplomacy_local_server.py
Executable file
|
|
@ -0,0 +1,218 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Local test server for the minimal Diplomacy environment.
|
||||
|
||||
This script runs the full AI_Diplomacy game with real OpenAI models
|
||||
to test the AtroposClient proxy integration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, EvalHandlingEnum
|
||||
from environments.game_environments.diplomacy_environment.diplomacy_env_minimal import (
|
||||
DiplomacyEnvMinimal,
|
||||
DiplomacyEnvMinimalConfig,
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _run_episode(env, episode_num: int, num_episodes: int) -> dict:
    """Play one episode (``group_size`` parallel games) and return its summary.

    Returns a dict with keys ``episode`` (1-based index), ``score`` (average
    trajectory score, 0.0 on failure) and ``outcomes`` (per-game outcome
    dicts pulled from the environment's buffer; empty on failure).
    """
    logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====")

    item = await env.get_next_item()
    logger.info(f"Game ID: {item['game_id']}, Seed: {item['seed']}")

    # Collect trajectories (will run group_size parallel games).
    scored_data_group, _ = await env.collect_trajectories(item)

    if not (scored_data_group and scored_data_group["scores"]):
        logger.error("Failed to collect trajectory")
        return {"episode": episode_num + 1, "score": 0.0, "outcomes": []}

    avg_score = sum(scored_data_group["scores"]) / len(scored_data_group["scores"])
    logger.info(
        f"Collected {len(scored_data_group['scores'])} trajectories with average score: {avg_score:.2f}"
    )

    # The most recent group_size entries in the buffer belong to this episode.
    latest_outcomes = []
    if env.game_outcomes_buffer:
        latest_outcomes = env.game_outcomes_buffer[-env.config.group_size :]
        for i, outcome in enumerate(latest_outcomes):
            logger.info(
                f"  Game {i}: Score={outcome['score']:.2f}, "
                f"Winner={outcome['winner']}, "
                f"Turns={outcome['turns']}, "
                f"Centers={outcome['final_centers'].get(env.config.training_power, 0)}"
            )

    return {
        "episode": episode_num + 1,
        "score": avg_score,
        "outcomes": latest_outcomes,
    }


def _log_summary(env, episode_results: list, num_episodes: int) -> None:
    """Log aggregate statistics (score, win rate, centers) over all episodes."""
    logger.info("\n" + "=" * 60)
    logger.info("OVERALL RESULTS SUMMARY")
    logger.info("=" * 60)
    logger.info(f"Total episodes: {num_episodes}")
    logger.info(f"Group size: {env.config.group_size} games per episode")
    logger.info(f"Training power: {env.config.training_power}")

    if not episode_results:
        return

    avg_score = sum(ep["score"] for ep in episode_results) / len(episode_results)
    logger.info(f"\nAverage trajectory score: {avg_score:.2f}")

    all_outcomes = [o for ep in episode_results for o in ep["outcomes"]]
    total_games = len(all_outcomes)
    wins = sum(1 for o in all_outcomes if o["winner"] == env.config.training_power)

    if total_games > 0:
        logger.info(f"Win rate: {wins}/{total_games} ({100*wins/total_games:.1f}%)")
        total_centers = sum(
            o["final_centers"].get(env.config.training_power, 0) for o in all_outcomes
        )
        avg_centers = total_centers / total_games
        logger.info(f"Average final supply centers: {avg_centers:.1f}")


async def main():
    """Run a handful of local Diplomacy test games against OpenRouter models.

    Builds a minimal environment configuration, constructs the environment
    (which starts its own Diplomacy server), plays ``argv[1]`` episodes
    (default 3), and logs per-episode and aggregate statistics.
    """
    logger.info("Starting Diplomacy minimal environment local test runner")

    # Check for OpenRouter API key.
    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
    if not openrouter_api_key:
        logger.error(
            "OPENROUTER_API_KEY not found. Please set it in your environment or .env file"
        )
        return

    # AI_Diplomacy model spec format: provider:model@base_url#api_key
    openrouter_model = (
        f"openai:openai/gpt-oss-120b@https://openrouter.ai/api/v1#{openrouter_api_key}"
    )

    # Six opponents (every power besides the training power) share one model.
    opponent_models = [openrouter_model] * 6

    env_config = DiplomacyEnvMinimalConfig(
        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        group_size=2,  # Run 2 parallel games
        use_wandb=False,
        wandb_name="diplomacy_minimal_local_test",
        max_num_workers=1,
        rollout_server_url="http://localhost:8000",
        total_steps=1,
        batch_size=1,
        steps_per_eval=0,
        max_token_length=4096,
        inference_weight=1.0,
        data_path_to_save_groups=None,
        eval_handling=EvalHandlingEnum.NONE,
        eval_limit_ratio=0.0,
        max_game_turns=5,  # Short games for testing
        training_power="FRANCE",  # Which power we're training
        include_messages=True,  # Include messages for debugging
        eval_episodes=0,
        start_diplomacy_server=True,  # Let the env start the server
        save_game_logs=True,
        game_logs_dir="./test_game_logs",
        opponent_models=opponent_models,  # Use OpenRouter for all opponents
    )

    # Four identical API servers mirror the SLURM setup; locally they all
    # point at the same OpenRouter endpoint.
    server_configs = [
        APIServerConfig(
            model_name="openai/gpt-oss-120b",
            base_url="https://openrouter.ai/api/v1",
            api_key=openrouter_api_key,
            num_requests_for_eval=0,
        )
        for _ in range(4)
    ]

    logger.info("Using OpenRouter openai/gpt-oss-120b for Diplomacy test")
    logger.debug(f"Env Config: {env_config}")
    logger.debug(f"Server Configs: {server_configs}")

    try:
        env = DiplomacyEnvMinimal(
            config=env_config,
            server_configs=server_configs,
            slurm=False,
            testing=False,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize DiplomacyEnvMinimal: {e}")
        return

    logger.info("Running test games")
    try:
        await env.setup()

        # Number of episodes from the command line, defaulting to 3.
        import sys

        num_episodes = int(sys.argv[1]) if len(sys.argv) > 1 else 3

        episode_results = []
        for episode_num in range(num_episodes):
            episode_results.append(
                await _run_episode(env, episode_num, num_episodes)
            )

        _log_summary(env, episode_results, num_episodes)

        logger.info("\nTest completed successfully!")

    except Exception as e:
        logger.exception(f"An error occurred during test: {e}")
|
||||
|
||||
|
||||
# Script entry point: run the async test driver.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
"""
|
||||
Queue Manager for Diplomacy Environment
|
||||
|
||||
Manages request/response queues between AtroposClient proxies and the environment.
|
||||
Each game gets its own queue pair for isolation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class QueuePair:
    """A pair of queues for a single game.

    Requests flow from the game's proxy to the environment on
    ``request_queue``; responses travel back on ``response_queue``.
    """

    game_id: str  # Identifier of the game these queues belong to
    request_queue: asyncio.Queue  # proxy -> environment (PolicyRequest items)
    response_queue: asyncio.Queue  # environment -> proxy (PolicyResponse items)
|
||||
|
||||
|
||||
@dataclass
class PolicyRequest:
    """Request from proxy to environment for policy sampling."""

    request_id: str  # Correlates this request with its PolicyResponse
    game_id: str  # Game whose queue pair should carry the reply
    power: str  # Diplomacy power the sample is requested for
    phase: str  # Current game phase identifier — exact format set by caller, TODO confirm
    prompt: str  # Prompt text to sample the policy with
    temperature: float  # Sampling temperature requested by the proxy
    trajectory: list  # Accumulated conversation/trajectory context — assumption, verify with proxy
|
||||
|
||||
|
||||
@dataclass
class PolicyResponse:
    """Response from environment back to proxy."""

    request_id: str  # Matches the PolicyRequest.request_id it answers
    response: str  # Sampled completion text returned to the proxy
    metadata: dict  # Extra info about the sample; schema defined by the environment
|
||||
|
||||
|
||||
class QueueManager:
    """Manages request/response queue pairs for all parallel games.

    Creation and removal of queue pairs are serialized with an asyncio
    lock; lookups are plain dict reads (safe within one event loop).
    """

    def __init__(self):
        # game_id -> QueuePair; mutated only under self._lock.
        self.queue_pairs: Dict[str, QueuePair] = {}
        self._lock = asyncio.Lock()

    async def create_game_queues(self, game_id: str) -> QueuePair:
        """Create a new queue pair for a game (idempotent per game_id)."""
        async with self._lock:
            if game_id in self.queue_pairs:
                logger.warning(f"Queue pair already exists for game {game_id}")
                return self.queue_pairs[game_id]

            queue_pair = QueuePair(
                game_id=game_id,
                request_queue=asyncio.Queue(),
                response_queue=asyncio.Queue(),
            )
            self.queue_pairs[game_id] = queue_pair
            logger.info(f"Created queue pair for game {game_id}")
            return queue_pair

    def get_queue_pair(self, game_id: str) -> Optional[QueuePair]:
        """Get queue pair for a game, or None if it does not exist."""
        return self.queue_pairs.get(game_id)

    def _require_queue_pair(self, game_id: str) -> QueuePair:
        """Return the queue pair for a game, raising ValueError when missing.

        Shared by put_request/get_response/put_response so the missing-pair
        error is defined in exactly one place.
        """
        queue_pair = self.get_queue_pair(game_id)
        if queue_pair is None:
            raise ValueError(f"No queue pair found for game {game_id}")
        return queue_pair

    async def remove_game_queues(self, game_id: str):
        """Remove queues for a completed game (no-op if already removed)."""
        async with self._lock:
            if game_id in self.queue_pairs:
                del self.queue_pairs[game_id]
                logger.info(f"Removed queue pair for game {game_id}")

    def get_all_request_queues(self) -> Dict[str, asyncio.Queue]:
        """Get all request queues, keyed by game id, for polling."""
        return {
            game_id: pair.request_queue for game_id, pair in self.queue_pairs.items()
        }

    async def put_request(self, game_id: str, request: PolicyRequest):
        """Put a request on the given game's request queue.

        Raises ValueError if no queue pair exists for ``game_id``.
        """
        await self._require_queue_pair(game_id).request_queue.put(request)

    async def get_response(self, game_id: str) -> PolicyResponse:
        """Await the next response from the given game's response queue.

        Raises ValueError if no queue pair exists for ``game_id``.
        """
        return await self._require_queue_pair(game_id).response_queue.get()

    async def put_response(self, game_id: str, response: PolicyResponse):
        """Put a response on the given game's response queue.

        Raises ValueError if no queue pair exists for ``game_id``.
        """
        await self._require_queue_pair(game_id).response_queue.put(response)
|
||||
|
||||
|
||||
# Module-level singleton so proxies and the environment share one manager.
_queue_manager = QueueManager()


def get_queue_manager() -> QueueManager:
    """Get the global queue manager instance."""
    return _queue_manager
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# Requirements for minimal Diplomacy environment
|
||||
# Only includes dependencies specific to AI_Diplomacy that aren't in main project
|
||||
|
||||
# AI_Diplomacy specific dependencies
|
||||
tornado>=6.0
|
||||
ujson>=5.0
|
||||
websockets>=10.0
|
||||
coloredlogs>=15.0.1
|
||||
|
|
@ -49,7 +49,7 @@ async def main():
|
|||
# Configure server - using same model as blackjack example
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="gpt-4.1-nano",
|
||||
model_name="gpt-4.1",
|
||||
base_url="https://api.openai.com/v1",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=0,
|
||||
|
|
@ -99,7 +99,6 @@ async def main():
|
|||
item = await env.get_next_item()
|
||||
challenge_name = item["challenge_name"]
|
||||
challenge_counts[challenge_name] += 1
|
||||
|
||||
logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====")
|
||||
logger.info(f"Using game: {item}")
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ class TextWorldChallengeRegistry:
|
|||
def __init__(self, seed: Optional[int] = None):
|
||||
self._challenges = self.CHALLENGES.copy()
|
||||
self.rng = random.Random(seed)
|
||||
|
||||
# Cache for all possible combinations
|
||||
self._all_combinations = None
|
||||
self._combination_index = 0
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit 0699e2f5074379bb19924b7f64221ccfc8cef6ef
|
||||
Subproject commit 9e79fc84b614b69c04325854335e308cd61d731d
|
||||
Loading…
Add table
Add a link
Reference in a new issue