mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
* minimal implementation, simplified challenge registry * need game save logic * fixed challenge gen, works with local test * updated challenge gen with wider ranges, working with local script * runs working correctly, wandb stats look ok * linting * Add diplomacy environment with AI_Diplomacy submodule - Add diplomacy_env_minimal.py for diplomacy game environment - Add atropos_client_minimal.py for client interface - Add diplomacy_local_server.py for local game server - Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy - Fix import ordering and remove unused imports * test file working, moving to cluster to test training * updated gitignore * removed logs * minor fixes, training running now * readded proxy reg and queue system * linting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * queue gameid bug, refactored * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleaned up configs & allowed for openrouter models to be easily used * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * linting * Remove duplicate dependencies from diplomacy requirements.txt Only keep AI_Diplomacy-specific dependencies that aren't already in the main project --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
117 lines
3.5 KiB
Python
117 lines
3.5 KiB
Python
"""
|
|
Queue Manager for Diplomacy Environment
|
|
|
|
Manages request/response queues between AtroposClient proxies and the environment.
|
|
Each game gets its own queue pair for isolation.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Optional
|
|
|
|
# Module-level logger; QueueManager uses it for info messages on queue-pair
# creation/removal and a warning when a pair is created twice for one game.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass()
class QueuePair:
    """Request/response queue pair owned by exactly one game."""

    # Identifier of the game these queues belong to.
    game_id: str
    # Traffic toward the environment (presumably PolicyRequest objects).
    request_queue: asyncio.Queue
    # Traffic back toward the proxy (presumably PolicyResponse objects).
    response_queue: asyncio.Queue
|
|
|
|
|
|
@dataclass()
class PolicyRequest:
    """A policy-sampling request sent by a proxy to the environment."""

    # Unique id used to pair this request with its PolicyResponse.
    request_id: str
    # Game the request originates from.
    game_id: str
    # Player/power issuing the request (assumed to be a Diplomacy power name).
    power: str
    # Game phase the request refers to.
    phase: str
    # Prompt text to sample a policy output for.
    prompt: str
    # Sampling temperature requested by the caller.
    temperature: float
    # Opaque list carried along with the request; element type is not
    # defined in this module.
    trajectory: list
|
|
|
|
|
|
@dataclass()
class PolicyResponse:
    """The environment's answer to a PolicyRequest, routed back to the proxy."""

    # Matches the request_id of the PolicyRequest being answered.
    request_id: str
    # Response text handed back to the proxy.
    response: str
    # Free-form extra information; its schema is not defined in this module.
    metadata: dict
|
|
|
|
|
|
class QueueManager:
    """Manages queues for all parallel games.

    Holds one ``QueuePair`` per game id in ``queue_pairs``. Requests
    (``PolicyRequest``) are put on a game's request queue and responses
    (``PolicyResponse``) on its response queue, keeping games isolated
    from each other.
    """

    def __init__(self):
        # game_id -> QueuePair for every game that currently has queues.
        self.queue_pairs: Dict[str, QueuePair] = {}
        # Serializes creation and removal of queue pairs.
        # NOTE(review): before Python 3.10, asyncio.Lock() binds to the loop
        # returned by get_event_loop() at construction time. A module-level
        # instance of this class is created at import, so confirm the minimum
        # supported Python version if this runs under asyncio.run().
        self._lock = asyncio.Lock()

    def _require_queue_pair(self, game_id: str) -> QueuePair:
        """Return the queue pair for ``game_id``.

        Raises:
            ValueError: if no queue pair is registered for ``game_id``.
        """
        queue_pair = self.queue_pairs.get(game_id)
        if queue_pair is None:
            raise ValueError(f"No queue pair found for game {game_id}")
        return queue_pair

    async def create_game_queues(self, game_id: str) -> QueuePair:
        """Create a new queue pair for a game.

        Args:
            game_id: Identifier of the game.

        Returns:
            The queue pair registered for ``game_id``. If one already exists,
            that existing pair is returned unchanged and a warning is logged.
        """
        async with self._lock:
            existing = self.queue_pairs.get(game_id)
            if existing is not None:
                logger.warning("Queue pair already exists for game %s", game_id)
                return existing

            queue_pair = QueuePair(
                game_id=game_id,
                request_queue=asyncio.Queue(),
                response_queue=asyncio.Queue(),
            )
            self.queue_pairs[game_id] = queue_pair
            logger.info("Created queue pair for game %s", game_id)
            return queue_pair

    def get_queue_pair(self, game_id: str) -> Optional[QueuePair]:
        """Return the queue pair for ``game_id``, or ``None`` if there is none."""
        return self.queue_pairs.get(game_id)

    async def remove_game_queues(self, game_id: str):
        """Remove queues for a completed game.

        Unknown ``game_id`` values are ignored. Only the registry entry is
        dropped: the queues are not drained, and coroutines already awaiting
        one of them are not woken by this call.
        """
        async with self._lock:
            if self.queue_pairs.pop(game_id, None) is not None:
                logger.info("Removed queue pair for game %s", game_id)

    def get_all_request_queues(self) -> Dict[str, asyncio.Queue]:
        """Return a snapshot mapping each game id to its request queue.

        The returned dict is newly built, so later additions or removals of
        games are not reflected in it.
        """
        return {
            game_id: pair.request_queue for game_id, pair in self.queue_pairs.items()
        }

    async def put_request(self, game_id: str, request: PolicyRequest):
        """Put ``request`` on the request queue of ``game_id``.

        Raises:
            ValueError: if no queue pair is registered for ``game_id``.
        """
        await self._require_queue_pair(game_id).request_queue.put(request)

    async def get_response(self, game_id: str) -> PolicyResponse:
        """Wait for and return the next response queued for ``game_id``.

        Raises:
            ValueError: if no queue pair is registered for ``game_id``.
        """
        return await self._require_queue_pair(game_id).response_queue.get()

    async def put_response(self, game_id: str, response: PolicyResponse):
        """Put ``response`` on the response queue of ``game_id``.

        Raises:
            ValueError: if no queue pair is registered for ``game_id``.
        """
        await self._require_queue_pair(game_id).response_queue.put(response)
|
|
|
|
|
|
# Single shared QueueManager, constructed once when this module is imported.
_queue_manager = QueueManager()


def get_queue_manager() -> QueueManager:
    """Return the process-wide QueueManager created at module import."""
    shared_manager = _queue_manager
    return shared_manager
|