mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
Diplomacy trainer env (#227)
* minimal implementation, simplified challenge registry
* need game save logic
* fixed challenge gen, works with local test
* updated challenge gen with wider ranges, working with local script
* runs working correctly, wandb stats look ok
* linting
* Add diplomacy environment with AI_Diplomacy submodule
  - Add diplomacy_env_minimal.py for diplomacy game environment
  - Add atropos_client_minimal.py for client interface
  - Add diplomacy_local_server.py for local game server
  - Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy
  - Fix import ordering and remove unused imports
* test file working, moving to cluster to test training
* updated gitignore
* removed logs
* minor fixes, training running now
* readded proxy reg and queue system
* linting
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* queue gameid bug, refactored
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* cleaned up configs & allowed for openrouter models to be easily used
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* linting
* Remove duplicate dependencies from diplomacy requirements.txt. Only keep AI_Diplomacy-specific dependencies that aren't already in the main project
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
4fe67e698d
commit
46f0602227
13 changed files with 1317 additions and 4 deletions
|
|
@ -0,0 +1,117 @@
|
|||
"""
|
||||
Queue Manager for Diplomacy Environment
|
||||
|
||||
Manages request/response queues between AtroposClient proxies and the environment.
|
||||
Each game gets its own queue pair for isolation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class QueuePair:
    """The request queue and response queue that belong to one game."""

    # Identifier of the game these two queues were created for.
    game_id: str
    # Carries requests travelling towards the environment.
    request_queue: asyncio.Queue
    # Carries responses travelling back from the environment.
    response_queue: asyncio.Queue
|
||||
|
||||
|
||||
@dataclass
class PolicyRequest:
    """A policy-sampling request sent by a proxy to the environment."""

    # Identifier of this individual request.
    request_id: str
    # Identifier of the game the request belongs to.
    game_id: str
    # Presumably the Diplomacy power being played - confirm with the caller.
    power: str
    # Presumably the game phase the request was issued in - confirm with the caller.
    phase: str
    # Prompt text to sample a completion for.
    prompt: str
    # Sampling temperature requested for this completion.
    temperature: float
    # Trajectory data accompanying the request; element type is not fixed here.
    trajectory: list
|
||||
|
||||
|
||||
@dataclass
class PolicyResponse:
    """The environment's reply to a proxy's policy request."""

    # Identifier of the request this reply answers.
    request_id: str
    # Response text produced for the request.
    response: str
    # Extra information attached to the reply; keys are not fixed here.
    metadata: dict
|
||||
|
||||
|
||||
class QueueManager:
    """Manages queues for all parallel games.

    Keeps one ``QueuePair`` per game id in ``queue_pairs`` and provides
    helpers to create, look up and remove pairs, and to move
    ``PolicyRequest`` / ``PolicyResponse`` objects through them.
    """

    def __init__(self):
        """Start with no registered games."""
        # Maps game_id -> the QueuePair created for that game.
        self.queue_pairs: Dict[str, QueuePair] = {}
        # Held while create_game_queues / remove_game_queues mutate the mapping.
        self._lock = asyncio.Lock()

    async def create_game_queues(self, game_id: str) -> QueuePair:
        """Create a new queue pair for a game.

        If a pair is already registered for ``game_id`` a warning is logged
        and the existing pair is returned unchanged.

        Args:
            game_id: Identifier of the game to create queues for.

        Returns:
            The queue pair registered for ``game_id``.
        """
        async with self._lock:
            if game_id in self.queue_pairs:
                logger.warning("Queue pair already exists for game %s", game_id)
                return self.queue_pairs[game_id]

            queue_pair = QueuePair(
                game_id=game_id,
                request_queue=asyncio.Queue(),
                response_queue=asyncio.Queue(),
            )
            self.queue_pairs[game_id] = queue_pair
            logger.info("Created queue pair for game %s", game_id)
            return queue_pair

    def get_queue_pair(self, game_id: str) -> Optional[QueuePair]:
        """Get queue pair for a game, or ``None`` if none is registered."""
        return self.queue_pairs.get(game_id)

    def _require_queue_pair(self, game_id: str) -> QueuePair:
        """Return the queue pair for ``game_id`` or raise ``ValueError``.

        Shared by the put/get helpers so the lookup and the error message
        live in one place.
        """
        queue_pair = self.get_queue_pair(game_id)
        if queue_pair is None:
            raise ValueError(f"No queue pair found for game {game_id}")
        return queue_pair

    async def remove_game_queues(self, game_id: str) -> None:
        """Remove queues for a completed game.

        Does nothing when no pair is registered for ``game_id``.
        """
        async with self._lock:
            if game_id in self.queue_pairs:
                del self.queue_pairs[game_id]
                logger.info("Removed queue pair for game %s", game_id)

    def get_all_request_queues(self) -> Dict[str, asyncio.Queue]:
        """Get all request queues for polling.

        Returns:
            A new dict mapping each registered game id to its request queue.
        """
        return {
            game_id: pair.request_queue for game_id, pair in self.queue_pairs.items()
        }

    async def put_request(self, game_id: str, request: PolicyRequest) -> None:
        """Put a request on the request queue of ``game_id``.

        Raises:
            ValueError: If no queue pair is registered for ``game_id``.
        """
        queue_pair = self._require_queue_pair(game_id)
        await queue_pair.request_queue.put(request)

    async def get_response(self, game_id: str) -> PolicyResponse:
        """Wait for and return the next response queued for ``game_id``.

        Responses are taken in queue order; this method does not match them
        against a particular ``request_id``.

        Raises:
            ValueError: If no queue pair is registered for ``game_id``.
        """
        queue_pair = self._require_queue_pair(game_id)
        return await queue_pair.response_queue.get()

    async def put_response(self, game_id: str, response: PolicyResponse) -> None:
        """Put a response on the response queue of ``game_id``.

        Raises:
            ValueError: If no queue pair is registered for ``game_id``.
        """
        queue_pair = self._require_queue_pair(game_id)
        await queue_pair.response_queue.put(response)
|
||||
|
||||
|
||||
# Module-level QueueManager instance, created once when this module is imported.
_queue_manager = QueueManager()
|
||||
|
||||
|
||||
def get_queue_manager() -> QueueManager:
    """Return the module-level ``QueueManager`` instance.

    Every call returns the same object, so all callers share its queues.
    """
    return _queue_manager
|
||||
Loading…
Add table
Add a link
Reference in a new issue