From 46f06022278196003e3894c7874b1bebd8d908ca Mon Sep 17 00:00:00 2001 From: shannonsands Date: Tue, 12 Aug 2025 09:02:16 +1000 Subject: [PATCH] Diplomacy trainer env (#227) * minimal implementation, simplified challenge registry * need game save logic * fixed challenge gen, works with local test * updated challenge gen with wider ranges, working with local script * runs working correctly, wandb stats look ok * linting * Add diplomacy environment with AI_Diplomacy submodule - Add diplomacy_env_minimal.py for diplomacy game environment - Add atropos_client_minimal.py for client interface - Add diplomacy_local_server.py for local game server - Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy - Fix import ordering and remove unused imports * test file working, moving to cluster to test training * updated gitignore * removed logs * minor fixes, training running now * readded proxy reg and queue system * linting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * queue gameid bug, refactored * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleaned up configs & allowed for openrouter models to be easily used * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * linting * Remove duplicate dependencies from diplomacy requirements.txt Only keep AI_Diplomacy-specific dependencies that aren't already in the main project --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .gitignore | 3 + .gitmodules | 3 + .../diplomacy_environment/AI_Diplomacy | 1 + .../diplomacy_environment/README.md | 54 ++ .../diplomacy_environment/__init__.py | 7 + .../atropos_client_minimal.py | 268 ++++++++ .../diplomacy_env_minimal.py | 636 ++++++++++++++++++ .../diplomacy_local_server.py | 218 ++++++ .../diplomacy_environment/queue_manager.py | 117 ++++ 
.../diplomacy_environment/requirements.txt | 8 + .../textworld_env/textworld_local_server.py | 3 +- .../textworld_env/textworld_registry.py | 1 - .../reasoning_gym_environment/reasoning-gym | 2 +- 13 files changed, 1317 insertions(+), 4 deletions(-) create mode 160000 environments/game_environments/diplomacy_environment/AI_Diplomacy create mode 100644 environments/game_environments/diplomacy_environment/README.md create mode 100644 environments/game_environments/diplomacy_environment/__init__.py create mode 100644 environments/game_environments/diplomacy_environment/atropos_client_minimal.py create mode 100644 environments/game_environments/diplomacy_environment/diplomacy_env_minimal.py create mode 100755 environments/game_environments/diplomacy_environment/diplomacy_local_server.py create mode 100644 environments/game_environments/diplomacy_environment/queue_manager.py create mode 100644 environments/game_environments/diplomacy_environment/requirements.txt diff --git a/.gitignore b/.gitignore index ce854945..e12cea2b 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,6 @@ environments/community/word_hunt/word_hunt_rollouts*.html *.z8 *.ni *.z5 + +# Diplomacy artefacts +environments/game_environments/diplomacy_environment/logs/ diff --git a/.gitmodules b/.gitmodules index ef77a7f2..a13aa697 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "environments/reasoning_gym_environment/reasoning-gym"] path = environments/reasoning_gym_environment/reasoning-gym url = https://github.com/open-thought/reasoning-gym +[submodule "environments/game_environments/diplomacy_environment/AI_Diplomacy"] + path = environments/game_environments/diplomacy_environment/AI_Diplomacy + url = https://github.com/GoodStartLabs/AI_Diplomacy.git diff --git a/environments/game_environments/diplomacy_environment/AI_Diplomacy b/environments/game_environments/diplomacy_environment/AI_Diplomacy new file mode 160000 index 00000000..70d4ae2f --- /dev/null +++ 
b/environments/game_environments/diplomacy_environment/AI_Diplomacy @@ -0,0 +1 @@ +Subproject commit 70d4ae2fe029ed38bf823ce70f5d45a7bcc7afbc diff --git a/environments/game_environments/diplomacy_environment/README.md b/environments/game_environments/diplomacy_environment/README.md new file mode 100644 index 00000000..eed2074a --- /dev/null +++ b/environments/game_environments/diplomacy_environment/README.md @@ -0,0 +1,54 @@ +# Minimal Diplomacy Environment + +A simplified Diplomacy RL training environment for Atropos that integrates with AI_Diplomacy. + +## Overview + +This minimal implementation provides: +- Basic game integration via AI_Diplomacy submodule +- Parallel rollouts with configurable group_size +- LLM request interception through AtroposClient proxy +- Simple supply center based scoring +- No complex features (no GRPO, memory systems, or advanced scoring) + +## Architecture + +``` +Atropos Policy Server + ↓ +AtroposClientMinimal (proxy) + ↓ +AI_Diplomacy Game Engine + ↓ +Game Execution +``` + +## Quick Start + +1. Install dependencies: +```bash +pip install -r requirements.txt +cd AI_Diplomacy +pip install -e . +``` + +2. Start your Atropos policy server on port 8000 + +3. Run the environment: +```bash +python diplomacy_env_minimal.py serve +``` + +## Configuration + +Key settings in `DiplomacyEnvMinimalConfig`: +- `max_game_turns`: Number of game turns (default: 10) +- `training_power`: Which power the RL agent controls (default: "FRANCE") +- `group_size`: Number of parallel games run per training item (default: 4) + +## How It Works + +1. **Parallel Rollouts**: Each training item runs `group_size` games in parallel from a shared base seed (game `i` uses `seed + i`) +2. **LLM Interception**: AtroposClientMinimal intercepts LLM calls for models whose ID starts with `atropos-` (the training power); other models use AI_Diplomacy's original loader +3. **Trajectory Collection**: Game interactions are collected and scored +4. 
**Group Return**: All scored trajectories in the group are returned together as a `ScoredDataGroup` for training diff --git a/environments/game_environments/diplomacy_environment/__init__.py b/environments/game_environments/diplomacy_environment/__init__.py new file mode 100644 index 00000000..6b43211b --- /dev/null +++ b/environments/game_environments/diplomacy_environment/__init__.py @@ -0,0 +1,7 @@ +""" +Minimal Diplomacy Environment for Atropos RL Training +""" + +from .diplomacy_env_minimal import DiplomacyEnvMinimal, DiplomacyEnvMinimalConfig + +__all__ = ["DiplomacyEnvMinimal", "DiplomacyEnvMinimalConfig"] diff --git a/environments/game_environments/diplomacy_environment/atropos_client_minimal.py b/environments/game_environments/diplomacy_environment/atropos_client_minimal.py new file mode 100644 index 00000000..65f16e58 --- /dev/null +++ b/environments/game_environments/diplomacy_environment/atropos_client_minimal.py @@ -0,0 +1,268 @@ +""" + +This is a queue-based proxy that: +- Intercepts LLM requests from AI_Diplomacy +- Puts them on a queue for the environment to process +- Waits for responses from the environment +- Returns responses to AI_Diplomacy +""" + +import asyncio +import contextvars +import json +import logging +import os +import sys +import uuid +from typing import Dict, List, Optional + +from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy import ( + clients, +) +from environments.game_environments.diplomacy_environment.queue_manager import ( + PolicyRequest, + QueueManager, + get_queue_manager, +) + +sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy")) + +from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy.clients import ( # noqa: E402 + BaseModelClient, +) + +logger = logging.getLogger(__name__) + +current_game_context = contextvars.ContextVar("current_game_id", default=None) +_game_interactions = {} + + +class AtroposClientMinimal(BaseModelClient): + """ + Queue-based proxy client that forwards 
LLM requests through queues. + """ + + def __init__( + self, + model_name: str, + queue_manager: Optional[QueueManager] = None, + ): + super().__init__(model_name) + self.game_id = current_game_context.get() + if not self.game_id: + raise ValueError("AtroposClientMinimal created without game context set") + + self.queue_manager = queue_manager or get_queue_manager() + + self.interactions: List[Dict] = [] + self.current_power: Optional[str] = None + self.current_phase: Optional[str] = None + + logger.info( + f"Initialized AtroposClientMinimal for {model_name} in game {self.game_id}" + ) + + async def generate_response(self, prompt: str, temperature: float = 0.0) -> str: + """ + Put request on queue and wait for response from environment. + This is the main method AI_Diplomacy calls for all LLM interactions. + """ + task_type = self._infer_task_type(prompt) + power = self._extract_power(prompt) + phase = self._extract_phase(prompt) + + if power: + self.current_power = power + if phase: + self.current_phase = phase + + logger.debug(f"Generating response for {self.current_power}: {task_type}") + + try: + request_id = str(uuid.uuid4()) + request = PolicyRequest( + request_id=request_id, + game_id=self.game_id, + power=self.current_power or "UNKNOWN", + phase=self.current_phase or "UNKNOWN", + prompt=prompt, + temperature=temperature, + trajectory=self.interactions.copy(), + ) + + await self.queue_manager.put_request(self.game_id, request) + logger.debug(f"Put request {request_id} on queue for game {self.game_id}") + + response = await self.queue_manager.get_response(self.game_id) + + if response.request_id != request_id: + logger.warning( + f"Response ID mismatch: expected {request_id}, got {response.request_id}" + ) + + response_text = response.response + + # Track interaction + interaction = { + "power": self.current_power, + "phase": self.current_phase, + "task_type": task_type, + "prompt": prompt, + "response": response_text, + "metadata": response.metadata, # Store 
any additional info from environment + } + self.interactions.append(interaction) + + if self.game_id not in _game_interactions: + _game_interactions[self.game_id] = [] + _game_interactions[self.game_id].append(interaction) + + return response_text + + except asyncio.TimeoutError: + logger.error("Timeout waiting for response from environment") + return self._generate_fallback_response(prompt) + except Exception as e: + logger.error(f"Error generating response: {e}") + return self._generate_fallback_response(prompt) + + def _infer_task_type(self, prompt: str) -> str: + """Infer the type of task from the prompt.""" + prompt_lower = prompt.lower() + + if "orders" in prompt_lower or "submit" in prompt_lower: + return "orders" + elif "message" in prompt_lower or "negotiate" in prompt_lower: + return "negotiation" + elif "plan" in prompt_lower or "strategy" in prompt_lower: + return "planning" + else: + return "general" + + def _extract_power(self, prompt: str) -> Optional[str]: + """Extract power name from prompt if mentioned.""" + for power in [ + "AUSTRIA", + "ENGLAND", + "FRANCE", + "GERMANY", + "ITALY", + "RUSSIA", + "TURKEY", + ]: + if power in prompt.upper(): + return power + return None + + def _extract_phase(self, prompt: str) -> Optional[str]: + """Extract game phase from prompt if mentioned.""" + import re + + phase_match = re.search(r"[SF]\d{4}[MRB]", prompt) + if phase_match: + return phase_match.group() + + verbose_match = re.search(r"(Spring|Fall) \d{4}", prompt) + if verbose_match: + return verbose_match.group() + + return None + + def _generate_fallback_response(self, prompt: str) -> str: + """Generate a simple fallback response if there's an issue.""" + task_type = self._infer_task_type(prompt) + + if task_type == "orders": + return json.dumps( + { + "orders": {}, + "explanations": {"general": "Fallback - no server connected"}, + } + ) + elif task_type == "negotiation": + return json.dumps( + { + "messages": [], + "explanations": {"general": "Fallback - 
no server connected"}, + } + ) + else: + return "Fallback response - server not available" + + def get_interactions(self) -> List[Dict]: + """Get all tracked interactions for trajectory collection.""" + return self.interactions + + def clear_interactions(self): + """Clear tracked interactions for a new game.""" + self.interactions = [] + self.current_power = None + self.current_phase = None + + +def get_game_interactions(game_id: str) -> List[Dict]: + """Get all interactions for a specific game.""" + return _game_interactions.get(game_id, []) + + +def clear_game_interactions(game_id: str): + """Clear interactions for a specific game.""" + if game_id in _game_interactions: + del _game_interactions[game_id] + + +def register_atropos_models_globally(queue_manager: Optional[QueueManager] = None): + """ + Register AtroposClientMinimal with AI_Diplomacy's model loading system globally. + This should be called ONCE during environment setup. + + Args: + queue_manager: Optional queue manager (uses global if not provided) + """ + + if hasattr(clients, "_atropos_registered"): + logger.info("AtroposClientMinimal already registered globally") + return + + clients._original_load_model_client = clients.load_model_client + clients._atropos_queue_manager = queue_manager or get_queue_manager() + + def load_model_client_with_atropos( + model_id: str, prompts_dir: Optional[str] = None + ) -> BaseModelClient: + if model_id.startswith("atropos-"): + logger.info(f"Creating context-aware AtroposClientMinimal for {model_id}") + return AtroposClientMinimal(model_id, clients._atropos_queue_manager) + else: + logger.info(f"Falling back to original loader for {model_id}") + return clients._original_load_model_client(model_id, prompts_dir) + + clients.load_model_client = load_model_client_with_atropos + clients._atropos_registered = True + + logger.info("Registered AtroposClientMinimal globally with AI_Diplomacy") + + +if __name__ == "__main__": + + async def test_client(): + client = 
AtroposClientMinimal( + "atropos-test", + {"base_url": "http://localhost:8000", "model_name": "test-model"}, + ) + + test_prompts = [ + "You are FRANCE. What are your orders for Spring 1901?", + "Send a message to ENGLAND about cooperation.", + "What is your strategic plan?", + ] + + for prompt in test_prompts: + print(f"\nPrompt: {prompt[:50]}...") + response = await client.generate_response(prompt) + print(f"Response: {response[:100]}...") + + print(f"\nTracked {len(client.get_interactions())} interactions") + await client.close() + + asyncio.run(test_client()) diff --git a/environments/game_environments/diplomacy_environment/diplomacy_env_minimal.py b/environments/game_environments/diplomacy_environment/diplomacy_env_minimal.py new file mode 100644 index 00000000..5ba83f7f --- /dev/null +++ b/environments/game_environments/diplomacy_environment/diplomacy_env_minimal.py @@ -0,0 +1,636 @@ +""" +Minimal Diplomacy Environment for Atropos + +A simplified implementation focusing on: +- Basic game integration with AI_Diplomacy +- Parallel rollouts with group_size +- LLM proxy interception via AtroposClient +- Simple supply center based scoring +""" + +import asyncio +import json +import logging +import os +import random +import subprocess +import sys +import time +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, + ScoredDataItem, +) +from atroposlib.type_definitions import Item +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer +from environments.game_environments.diplomacy_environment.atropos_client_minimal import ( + clear_game_interactions, + current_game_context, + get_game_interactions, + register_atropos_models_globally, +) +from environments.game_environments.diplomacy_environment.queue_manager import ( + PolicyRequest, + PolicyResponse, + get_queue_manager, +) + 
+sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy")) +import lm_game # noqa: E402 + +logger = logging.getLogger(__name__) + +POWERS = ["AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"] +STARTING_SUPPLY_CENTERS = { + "AUSTRIA": 3, + "ENGLAND": 3, + "FRANCE": 3, + "GERMANY": 3, + "ITALY": 3, + "RUSSIA": 4, + "TURKEY": 3, +} + + +class DiplomacyEnvMinimalConfig(BaseEnvConfig): + """Configuration for the minimal Diplomacy environment.""" + + env_name: str = "diplomacy_minimal" + + # Game settings (minimal implementation, no randomisation stuff) + max_game_turns: int = 10 + training_power: str = "FRANCE" + total_steps: int = 10 + + # Scoring + supply_center_weight: float = 1.0 + survival_bonus: float = 0.1 + win_bonus: float = 5.0 + + # Process management + diplomacy_server_port: int = 8432 + start_diplomacy_server: bool = True + + # Logging + save_game_logs: bool = True + game_logs_dir: str = "./game_logs" + + # Evaluation + eval_episodes: int = 10 + + # Opponent models (None = use SGLang endpoints from server_configs) + opponent_models: Optional[List[str]] = None + + +class DiplomacyEnvMinimal(BaseEnv): + name = "diplomacy_minimal" + env_config_cls = DiplomacyEnvMinimalConfig + + def __init__( + self, + config: DiplomacyEnvMinimalConfig, + server_configs: List[APIServerConfig], + slurm: bool = False, + testing: bool = False, + ): + super().__init__(config, server_configs, slurm, testing) + self.config: DiplomacyEnvMinimalConfig = config + self.game_server_process: Optional[subprocess.Popen] = None + self.game_outcomes_buffer: List[Dict] = [] + self.eval_metrics_custom: List[Tuple[str, float]] = [] + + self.queue_manager = get_queue_manager() + self.active_games: Dict[str, Dict] = {} + + if config.save_game_logs: + Path(config.game_logs_dir).mkdir(exist_ok=True) + + self.system_prompt = ( + f"You are playing Diplomacy as {config.training_power}. " + "Analyze the game state and respond with your strategy and orders." 
+ ) + + if config.opponent_models: + self.opponent_models = config.opponent_models + else: + self.opponent_models = [] + for server_config in server_configs: + model_spec = f"openai:{server_config.model_name}@{server_config.base_url}#{server_config.api_key}" + self.opponent_models.append(model_spec) + + @classmethod + def config_init(cls) -> Tuple[DiplomacyEnvMinimalConfig, List[APIServerConfig]]: + """Initialize default configuration.""" + env_config = DiplomacyEnvMinimalConfig( + tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3", + group_size=4, + use_wandb=True, + rollout_server_url="http://localhost:8000", + max_token_length=4096, + wandb_name=cls.name, + steps_per_eval=20, + ) + server_configs = [ + APIServerConfig( + model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3", + base_url="http://localhost:9004/v1", + api_key="x", + num_requests_for_eval=128, + ), + APIServerConfig( + model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3", + base_url="http://localhost:9005/v1", + api_key="x", + num_requests_for_eval=128, + ), + APIServerConfig( + model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3", + base_url="http://localhost:9006/v1", + api_key="x", + num_requests_for_eval=128, + ), + APIServerConfig( + model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3", + base_url="http://localhost:9007/v1", + api_key="x", + num_requests_for_eval=128, + ), + ] + return env_config, server_configs + + async def setup(self): + """Set up the environment.""" + logger.info(f"Setting up {self.name} environment") + + register_atropos_models_globally(self.queue_manager) + + if self.config.start_diplomacy_server: + await self._start_diplomacy_server() + + asyncio.create_task(self._poll_request_queues()) + + async def _poll_request_queues(self): + """Poll request queues and handle policy requests.""" + while True: + try: + for game_id in list(self.active_games.keys()): + queue_pair = self.queue_manager.get_queue_pair(game_id) + if not queue_pair: + continue + + try: + request = 
queue_pair.request_queue.get_nowait() + asyncio.create_task(self._handle_policy_request(request)) + except asyncio.QueueEmpty: + pass + + await asyncio.sleep(0.01) + + except Exception as e: + logger.error(f"Error in queue polling: {e}") + await asyncio.sleep(1) + + async def _handle_policy_request(self, request: PolicyRequest): + """Handle a single policy request by sampling from SGLang.""" + try: + logger.info( + f"Handling request {request.request_id} for {request.power} in game {request.game_id}" + ) + + messages = [{"role": "system", "content": self.system_prompt}] + + for interaction in request.trajectory: + messages.append({"role": "user", "content": interaction["prompt"]}) + messages.append( + {"role": "assistant", "content": interaction["response"]} + ) + + messages.append({"role": "user", "content": request.prompt}) + + async with self.server.dedicated_server() as server: + response = await server.chat_completion( + messages=messages, + n=1, + temperature=request.temperature, + max_tokens=2000, + ) + + response_text = response.choices[0].message.content.strip() + + policy_response = PolicyResponse( + request_id=request.request_id, + response=response_text, + metadata={ + "power": request.power, + "phase": request.phase, + }, + ) + + await self.queue_manager.put_response(request.game_id, policy_response) + logger.debug(f"Sent response for request {request.request_id}") + + except Exception as e: + logger.error(f"Error handling policy request: {e}") + error_response = PolicyResponse( + request_id=request.request_id, + response="Error: Failed to generate response", + metadata={"error": str(e)}, + ) + await self.queue_manager.put_response(request.game_id, error_response) + + async def _start_diplomacy_server(self): + """Start the AI_Diplomacy game server.""" + try: + logger.info( + f"Starting Diplomacy server on port {self.config.diplomacy_server_port}" + ) + self.game_server_process = subprocess.Popen( + [ + sys.executable, + "-m", + "diplomacy.server.run", 
+ "--port", + str(self.config.diplomacy_server_port), + ], + cwd=os.path.join(os.path.dirname(__file__), "AI_Diplomacy"), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # Give server time to start + await asyncio.sleep(3) + logger.info( + f"Diplomacy server started at http://localhost:{self.config.diplomacy_server_port}" + ) + except Exception as e: + logger.error(f"Failed to start Diplomacy server: {e}") + + async def collect_trajectories( + self, item: Item + ) -> Tuple[ScoredDataGroup, List[Item]]: + """ + Run parallel Diplomacy games and collect all trajectories. + + This implements the key RL training pattern: + 1. Run group_size parallel games with the same seed + 2. Each game explores different action sequences + 3. Score each trajectory based on game outcome + 4. Return all trajectories as a ScoredDataGroup for training + """ + logger.warning( + f"[DiplomacyEnvMinimal] collect_trajectories called with item: {item}" + ) + base_game_id = item.get("game_id", f"game-{int(time.time())}") + seed = item.get("seed", random.randint(0, 1_000_000)) + + logger.info( + f"Starting {self.config.group_size} parallel games with seed {seed}" + ) + + # Run parallel games w/ same seed + game_tasks = [] + for i in range(self.config.group_size): + game_id = f"{base_game_id}-{i}" + task = self._run_single_game(game_id, seed, trajectory_id=i) + game_tasks.append(task) + + results = await asyncio.gather(*game_tasks, return_exceptions=True) + + scored_items = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Game {i} failed: {result}") + continue + + if result and result[0]: + scored_items.append(result[0]) + + logger.warning( + f"[DiplomacyEnvMinimal] Collected {len(scored_items)} scored items" + ) + if not scored_items: + logger.error("No valid trajectories collected") + return ( + ScoredDataGroup( + tokens=[], + masks=[], + scores=[], + messages=[], + advantages=None, + ref_logprobs=None, + group_overrides={}, + 
overrides=None, + images=None, + ), + [], + ) + + sdg = ScoredDataGroup( + tokens=[], + masks=[], + scores=[], + messages=[], + advantages=None, + ref_logprobs=None, + group_overrides={}, + overrides=None, + images=None, + ) + + for scored_item in scored_items: + sdg["tokens"].append(scored_item["tokens"]) + sdg["masks"].append(scored_item["masks"]) + sdg["scores"].append(scored_item["scores"]) + if self.config.include_messages and scored_item.get("messages"): + sdg["messages"].append(scored_item["messages"]) + + logger.info(f"Collected {len(scored_items)} trajectories") + logger.warning( + f"[DiplomacyEnvMinimal] Returning ScoredDataGroup with {len(sdg['tokens'])} " + f"tokens, {len(sdg['scores'])} scores" + ) + logger.warning( + f"[DiplomacyEnvMinimal] First few scores: {sdg['scores'][:5] if sdg['scores'] else 'None'}" + ) + + for i in range(self.config.group_size): + game_id = f"{base_game_id}-{i}" + if game_id in self.active_games: + del self.active_games[game_id] + try: + await self.queue_manager.remove_game_queues(game_id) + except Exception as e: + logger.debug(f"Error cleaning up queues for {game_id}: {e}") + + return sdg, [] + + async def _run_single_game( + self, game_id: str, seed: int, trajectory_id: int + ) -> Tuple[Optional[ScoredDataItem], None]: + """ + Run a single Diplomacy game and return scored trajectory. 
+ """ + try: + queue_pair = await self.queue_manager.create_game_queues(game_id) + + self.active_games[game_id] = { + "queue_pair": queue_pair, + "start_time": time.time(), + "interactions": [], + } + + token = current_game_context.set(game_id) + + try: + game_result = await self._run_diplomacy_game( + game_id, seed, trajectory_id + ) + finally: + current_game_context.reset(token) + + if not game_result: + logger.error(f"Game {game_id} failed to complete") + return None, None + + score = self._calculate_score(game_result, self.config.training_power) + + interactions = get_game_interactions(game_id) + + training_interactions = [ + i for i in interactions if i.get("power") == self.config.training_power + ] + + if training_interactions: + messages = [{"role": "system", "content": self.system_prompt}] + + for interaction in training_interactions: + messages.append({"role": "user", "content": interaction["prompt"]}) + messages.append( + {"role": "assistant", "content": interaction["response"]} + ) + + logger.info( + f"Collected {len(training_interactions)} interactions for {self.config.training_power}" + ) + else: + logger.warning( + f"No interactions found for {self.config.training_power} in game {game_id}" + ) + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": f"Playing Diplomacy game {game_id}"}, + { + "role": "assistant", + "content": f"Game completed with score {score:.2f}", + }, + ] + + clear_game_interactions(game_id) + + tokenization_result = tokenize_for_trainer( + tokenizer=self.tokenizer, + chat=messages, + train_on_all_assistant_turns=True, + ) + + scored_data_item = ScoredDataItem( + messages=messages if self.config.include_messages else None, + tokens=tokenization_result["tokens"], + masks=tokenization_result["masks"], + scores=score, + ) + + self.game_outcomes_buffer.append( + { + "game_id": game_id, + "score": score, + "winner": game_result.get("winner"), + "turns": game_result.get("turns_played", 0), + 
"final_centers": game_result.get("final_centers", {}), + } + ) + + return scored_data_item, None + + except Exception as e: + logger.error(f"Error in game {game_id}: {e}", exc_info=True) + return None, None + + async def _run_diplomacy_game( + self, game_id: str, seed: int, trajectory_id: int = 0 + ) -> Optional[Dict]: + """Run a Diplomacy game using AI_Diplomacy's lm_game module.""" + + game_output_path = os.path.join(self.config.game_logs_dir, f"{game_id}.json") + + models = [] + opponent_idx = 0 + + for power in POWERS: + if power == self.config.training_power: + models.append("atropos-training-policy") + else: + models.append( + self.opponent_models[opponent_idx % len(self.opponent_models)] + ) + opponent_idx += 1 + + original_argv = sys.argv + + try: + sys.argv = [ + "lm_game.py", + "--models", + ",".join(models), + "--max_year", + str(1900 + self.config.max_game_turns), + "--output", + game_output_path, + "--seed", + str(seed + trajectory_id), + "--num_negotiation_rounds", + "0", + ] + + await lm_game.main() + + actual_game_file = os.path.join(game_output_path, "lmvsgame.json") + if os.path.exists(actual_game_file): + with open(actual_game_file, "r") as f: + saved_game = json.load(f) + + phases = saved_game.get("phases", []) + last_phase = phases[-1] if phases else {} + + result = { + "winner": saved_game.get("winner"), + "turns_played": len(phases), + "final_centers": {}, + } + + for power in POWERS: + centers = ( + last_phase.get("state", {}).get("centers", {}).get(power, []) + ) + result["final_centers"][power] = len(centers) + + return result + else: + logger.error(f"Game output not found: {actual_game_file}") + return None + + finally: + sys.argv = original_argv + + def _calculate_score(self, game_result: Dict, power: str) -> float: + """Calculate score for a power based on game outcome.""" + start_centers = STARTING_SUPPLY_CENTERS[power] + end_centers = game_result["final_centers"].get(power, 0) + center_score = (end_centers - start_centers) * 
self.config.supply_center_weight + + survival_score = self.config.survival_bonus * game_result["turns_played"] + + win_score = self.config.win_bonus if game_result["winner"] == power else 0.0 + + total_score = center_score + survival_score + win_score + + logger.info( + f"{power} score: centers={center_score:.2f}, " + f"survival={survival_score:.2f}, win={win_score:.2f}, " + f"total={total_score:.2f}" + ) + + return total_score + + async def get_next_item(self) -> Item: + """Generate configuration for the next game.""" + return { + "game_id": f"game-{int(time.time())}-{random.randint(1000, 9999)}", + "seed": random.randint(0, 1_000_000), + } + + async def evaluate(self, *args, **kwargs): + """Run evaluation games.""" + logger.info(f"Starting evaluation with {self.config.eval_episodes} episodes") + + eval_scores = [] + wins = 0 + + for i in range(self.config.eval_episodes): + item = await self.get_next_item() + item["is_eval"] = True + + scored_data_group, _ = await self.collect_trajectories(item) + if scored_data_group and scored_data_group["scores"]: + avg_score = sum(scored_data_group["scores"]) / len( + scored_data_group["scores"] + ) + eval_scores.append(avg_score) + + if ( + self.game_outcomes_buffer + and self.game_outcomes_buffer[-1]["winner"] + == self.config.training_power + ): + wins += 1 + + if eval_scores: + avg_score = sum(eval_scores) / len(eval_scores) + win_rate = wins / self.config.eval_episodes + + self.eval_metrics_custom = [ + (f"{self.name}_eval/avg_score", avg_score), + (f"{self.name}_eval/win_rate", win_rate), + (f"{self.name}_eval/num_completed", len(eval_scores)), + ] + + logger.info( + f"Evaluation completed: avg_score={avg_score:.2f}, win_rate={win_rate:.2%}" + ) + + async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None): + """Log metrics to wandb.""" + if wandb_metrics is None: + wandb_metrics = {} + + if self.game_outcomes_buffer: + total_games = len(self.game_outcomes_buffer) + avg_score = sum(g["score"] for g in 
self.game_outcomes_buffer) / total_games + wins = sum( + 1 + for g in self.game_outcomes_buffer + if g["winner"] == self.config.training_power + ) + win_rate = wins / total_games + avg_turns = sum(g["turns"] for g in self.game_outcomes_buffer) / total_games + + wandb_metrics.update( + { + f"{self.name}_train/avg_score": avg_score, + f"{self.name}_train/win_rate": win_rate, + f"{self.name}_train/avg_turns": avg_turns, + f"{self.name}_train/num_games": total_games, + } + ) + + self.game_outcomes_buffer = [] + + for key, value in self.eval_metrics_custom: + wandb_metrics[key] = value + self.eval_metrics_custom = [] + + await super().wandb_log(wandb_metrics) + + def __del__(self): + """Clean up server process on exit.""" + if self.game_server_process: + self.game_server_process.terminate() + self.game_server_process.wait() + + +if __name__ == "__main__": + DiplomacyEnvMinimal.cli() diff --git a/environments/game_environments/diplomacy_environment/diplomacy_local_server.py b/environments/game_environments/diplomacy_environment/diplomacy_local_server.py new file mode 100755 index 00000000..c1e40a9c --- /dev/null +++ b/environments/game_environments/diplomacy_environment/diplomacy_local_server.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +""" +Local test server for the minimal Diplomacy environment. + +This script runs the full AI_Diplomacy game with real OpenAI models +to test the AtroposClient proxy integration. 
+""" + +import asyncio +import logging +import os + +from dotenv import load_dotenv + +from atroposlib.envs.base import APIServerConfig, EvalHandlingEnum +from environments.game_environments.diplomacy_environment.diplomacy_env_minimal import ( + DiplomacyEnvMinimal, + DiplomacyEnvMinimalConfig, +) + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main(): + """Run Diplomacy games for testing the minimal environment.""" + logger.info("Starting Diplomacy minimal environment local test runner") + + # Check for OpenRouter API key + if not os.getenv("OPENROUTER_API_KEY"): + logger.error( + "OPENROUTER_API_KEY not found. Please set it in your environment or .env file" + ) + return + + # Configure environment - using OpenRouter model + openrouter_api_key = os.getenv("OPENROUTER_API_KEY") + openrouter_model = ( + f"openai:openai/gpt-oss-120b@https://openrouter.ai/api/v1#{openrouter_api_key}" + ) + + # Create list of opponent models (6 powers besides training power) + opponent_models = [openrouter_model] * 6 + + env_config = DiplomacyEnvMinimalConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=2, # Run 2 parallel games + use_wandb=False, + wandb_name="diplomacy_minimal_local_test", + max_num_workers=1, + rollout_server_url="http://localhost:8000", + total_steps=1, + batch_size=1, + steps_per_eval=0, + max_token_length=4096, + inference_weight=1.0, + data_path_to_save_groups=None, + eval_handling=EvalHandlingEnum.NONE, + eval_limit_ratio=0.0, + max_game_turns=5, # Short games for testing + training_power="FRANCE", # Which power we're training + include_messages=True, # Include messages for debugging + eval_episodes=0, + start_diplomacy_server=True, # Let the env start the server + save_game_logs=True, + game_logs_dir="./test_game_logs", + opponent_models=opponent_models, # Use OpenRouter for all opponents + ) + + # Configure server - using 4 servers to match SLURM setup + # For 
local testing, we'll simulate this with the same OpenRouter endpoint + server_configs = [ + APIServerConfig( + model_name="openai/gpt-oss-120b", # Using the OpenRouter model + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + num_requests_for_eval=0, + ), + APIServerConfig( + model_name="openai/gpt-oss-120b", + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + num_requests_for_eval=0, + ), + APIServerConfig( + model_name="openai/gpt-oss-120b", + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + num_requests_for_eval=0, + ), + APIServerConfig( + model_name="openai/gpt-oss-120b", + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + num_requests_for_eval=0, + ), + ] + + logger.info("Using OpenRouter openai/gpt-oss-120b for Diplomacy test") + logger.debug(f"Env Config: {env_config}") + logger.debug(f"Server Configs: {server_configs}") + + try: + env = DiplomacyEnvMinimal( + config=env_config, + server_configs=server_configs, + slurm=False, + testing=False, + ) + except Exception as e: + logger.exception(f"Failed to initialize DiplomacyEnvMinimal: {e}") + return + + logger.info("Running test games") + try: + await env.setup() + + # Get number of episodes from command line or default + import sys + + num_episodes = int(sys.argv[1]) if len(sys.argv) > 1 else 3 + + # Track statistics + episode_results = [] + + for episode_num in range(num_episodes): + logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====") + + item = await env.get_next_item() + logger.info(f"Game ID: {item['game_id']}, Seed: {item['seed']}") + + # Collect trajectories (will run group_size parallel games) + scored_data_group, _ = await env.collect_trajectories(item) + + if scored_data_group and scored_data_group["scores"]: + avg_score = sum(scored_data_group["scores"]) / len( + scored_data_group["scores"] + ) + logger.info( + f"Collected 
{len(scored_data_group['scores'])} trajectories with average score: {avg_score:.2f}" + ) + + # Get game outcomes from buffer + if env.game_outcomes_buffer: + latest_outcomes = env.game_outcomes_buffer[-env.config.group_size :] + for i, outcome in enumerate(latest_outcomes): + logger.info( + f" Game {i}: Score={outcome['score']:.2f}, " + f"Winner={outcome['winner']}, " + f"Turns={outcome['turns']}, " + f"Centers={outcome['final_centers'].get(env.config.training_power, 0)}" + ) + + episode_results.append( + { + "episode": episode_num + 1, + "score": avg_score, + "outcomes": latest_outcomes if env.game_outcomes_buffer else [], + } + ) + else: + logger.error("Failed to collect trajectory") + episode_results.append( + { + "episode": episode_num + 1, + "score": 0.0, + "outcomes": [], + } + ) + + # Print overall statistics + logger.info("\n" + "=" * 60) + logger.info("OVERALL RESULTS SUMMARY") + logger.info("=" * 60) + logger.info(f"Total episodes: {num_episodes}") + logger.info(f"Group size: {env.config.group_size} games per episode") + logger.info(f"Training power: {env.config.training_power}") + + # Calculate statistics + if episode_results: + avg_score = sum(ep["score"] for ep in episode_results) / len( + episode_results + ) + logger.info(f"\nAverage trajectory score: {avg_score:.2f}") + + # Count wins + total_games = 0 + wins = 0 + for ep in episode_results: + for outcome in ep["outcomes"]: + total_games += 1 + if outcome["winner"] == env.config.training_power: + wins += 1 + + if total_games > 0: + logger.info( + f"Win rate: {wins}/{total_games} ({100*wins/total_games:.1f}%)" + ) + + # Average supply centers + total_centers = sum( + outcome["final_centers"].get(env.config.training_power, 0) + for ep in episode_results + for outcome in ep["outcomes"] + ) + avg_centers = total_centers / total_games + logger.info(f"Average final supply centers: {avg_centers:.1f}") + + logger.info("\nTest completed successfully!") + + except Exception as e: + logger.exception(f"An error 
"""
Queue Manager for Diplomacy Environment

Manages request/response queues between AtroposClient proxies and the environment.
Each game gets its own queue pair for isolation.
"""

import asyncio
import logging
from dataclasses import dataclass
from typing import Dict, Optional

logger = logging.getLogger(__name__)


@dataclass
class QueuePair:
    """A pair of queues for a single game."""

    game_id: str
    # Filled by QueueManager.put_request.
    request_queue: asyncio.Queue
    # Filled by QueueManager.put_response, drained by QueueManager.get_response.
    response_queue: asyncio.Queue


@dataclass
class PolicyRequest:
    """Request from proxy to environment for policy sampling."""

    request_id: str
    game_id: str
    power: str
    phase: str
    prompt: str
    temperature: float
    trajectory: list


@dataclass
class PolicyResponse:
    """Response from environment back to proxy."""

    request_id: str
    response: str
    metadata: dict


class QueueManager:
    """Manages queues for all parallel games.

    Queue pairs are stored in ``queue_pairs`` keyed by game id. Creation and
    removal are serialised through an ``asyncio.Lock`` that is created lazily
    inside the running event loop, so constructing a manager (for example the
    module-level singleton at import time) never touches an event loop.
    """

    def __init__(self):
        self.queue_pairs: Dict[str, QueuePair] = {}
        # Created on first use; see _get_lock.
        self._lock: Optional[asyncio.Lock] = None
        self._lock_loop: Optional[asyncio.AbstractEventLoop] = None

    def _get_lock(self) -> asyncio.Lock:
        """Return a lock that belongs to the currently running event loop."""
        loop = asyncio.get_running_loop()
        if self._lock is None or self._lock_loop is not loop:
            # A lock bound to another (possibly closed) loop cannot be
            # awaited from this one, so start over with a fresh lock.
            self._lock = asyncio.Lock()
            self._lock_loop = loop
        return self._lock

    def _require_queue_pair(self, game_id: str) -> QueuePair:
        """Return the queue pair for ``game_id`` or raise ``ValueError``."""
        queue_pair = self.queue_pairs.get(game_id)
        if queue_pair is None:
            raise ValueError(f"No queue pair found for game {game_id}")
        return queue_pair

    async def create_game_queues(self, game_id: str) -> QueuePair:
        """Create a new queue pair for a game.

        If a pair already exists for ``game_id`` a warning is logged and the
        existing pair is returned unchanged.
        """
        async with self._get_lock():
            if game_id in self.queue_pairs:
                logger.warning(f"Queue pair already exists for game {game_id}")
                return self.queue_pairs[game_id]

            queue_pair = QueuePair(
                game_id=game_id,
                request_queue=asyncio.Queue(),
                response_queue=asyncio.Queue(),
            )
            self.queue_pairs[game_id] = queue_pair
            logger.info(f"Created queue pair for game {game_id}")
            return queue_pair

    def get_queue_pair(self, game_id: str) -> Optional[QueuePair]:
        """Get queue pair for a game, or ``None`` if the game is unknown."""
        return self.queue_pairs.get(game_id)

    async def remove_game_queues(self, game_id: str):
        """Remove queues for a completed game; unknown ids are ignored."""
        async with self._get_lock():
            if game_id in self.queue_pairs:
                del self.queue_pairs[game_id]
                logger.info(f"Removed queue pair for game {game_id}")

    def get_all_request_queues(self) -> Dict[str, asyncio.Queue]:
        """Get all request queues for polling, as a new dict keyed by game id."""
        return {
            game_id: pair.request_queue for game_id, pair in self.queue_pairs.items()
        }

    async def put_request(self, game_id: str, request: PolicyRequest):
        """Put a request on the game's request queue.

        Raises:
            ValueError: If no queue pair exists for ``game_id``.
        """
        await self._require_queue_pair(game_id).request_queue.put(request)

    async def get_response(self, game_id: str) -> PolicyResponse:
        """Wait for and return the next item on the game's response queue.

        Raises:
            ValueError: If no queue pair exists for ``game_id``.
        """
        return await self._require_queue_pair(game_id).response_queue.get()

    async def put_response(self, game_id: str, response: PolicyResponse):
        """Put a response on the game's response queue.

        Raises:
            ValueError: If no queue pair exists for ``game_id``.
        """
        await self._require_queue_pair(game_id).response_queue.put(response)


_queue_manager = QueueManager()


def get_queue_manager() -> QueueManager:
    """Get the global queue manager instance."""
    return _queue_manager
AI_Diplomacy specific dependencies +tornado>=6.0 +ujson>=5.0 +websockets>=10.0 +coloredlogs>=15.0.1 diff --git a/environments/game_environments/textworld_env/textworld_local_server.py b/environments/game_environments/textworld_env/textworld_local_server.py index 21d336b8..0cf509ea 100644 --- a/environments/game_environments/textworld_env/textworld_local_server.py +++ b/environments/game_environments/textworld_env/textworld_local_server.py @@ -49,7 +49,7 @@ async def main(): # Configure server - using same model as blackjack example server_configs = [ APIServerConfig( - model_name="gpt-4.1-nano", + model_name="gpt-4.1", base_url="https://api.openai.com/v1", api_key=os.getenv("OPENAI_API_KEY"), num_requests_for_eval=0, @@ -99,7 +99,6 @@ async def main(): item = await env.get_next_item() challenge_name = item["challenge_name"] challenge_counts[challenge_name] += 1 - logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====") logger.info(f"Using game: {item}") diff --git a/environments/game_environments/textworld_env/textworld_registry.py b/environments/game_environments/textworld_env/textworld_registry.py index 4eae38cf..6e43aebe 100644 --- a/environments/game_environments/textworld_env/textworld_registry.py +++ b/environments/game_environments/textworld_env/textworld_registry.py @@ -44,7 +44,6 @@ class TextWorldChallengeRegistry: def __init__(self, seed: Optional[int] = None): self._challenges = self.CHALLENGES.copy() self.rng = random.Random(seed) - # Cache for all possible combinations self._all_combinations = None self._combination_index = 0 diff --git a/environments/reasoning_gym_environment/reasoning-gym b/environments/reasoning_gym_environment/reasoning-gym index 0699e2f5..9e79fc84 160000 --- a/environments/reasoning_gym_environment/reasoning-gym +++ b/environments/reasoning_gym_environment/reasoning-gym @@ -1 +1 @@ -Subproject commit 0699e2f5074379bb19924b7f64221ccfc8cef6ef +Subproject commit 9e79fc84b614b69c04325854335e308cd61d731d