atropos/environments/game_environments/diplomacy_environment/queue_manager.py
shannonsands 46f0602227
Diplomacy trainer env (#227)
* minimal implementation, simplified challenge registry

* need game save logic

* fixed challenge gen, works with local test

* updated challenge gen with wider ranges, working with local script

* runs working correctly, wandb stats look ok

* linting

* Add diplomacy environment with AI_Diplomacy submodule

- Add diplomacy_env_minimal.py for diplomacy game environment
- Add atropos_client_minimal.py for client interface
- Add diplomacy_local_server.py for local game server
- Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy
- Fix import ordering and remove unused imports

* test file working, moving to cluster to test training

* updated gitignore

* removed logs

* minor fixes, training running now

* re-added proxy reg and queue system

* linting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* queue gameid bug, refactored

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleaned up configs & allowed for openrouter models to be easily used

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* linting

* Remove duplicate dependencies from diplomacy requirements.txt

Only keep AI_Diplomacy-specific dependencies that aren't already in the main project

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-08-12 09:02:16 +10:00

117 lines
3.5 KiB
Python

"""
Queue Manager for Diplomacy Environment
Manages request/response queues between AtroposClient proxies and the environment.
Each game gets its own queue pair for isolation.
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import Dict, Optional
# Module-scoped logger; records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)
@dataclass
class QueuePair:
    """Request/response queues dedicated to a single game.

    Giving every game its own pair keeps traffic from parallel games isolated.
    """

    # Identifier of the game that owns both queues.
    game_id: str
    # Queue that QueueManager.put_request writes PolicyRequest objects onto.
    request_queue: asyncio.Queue
    # Queue that QueueManager.put_response writes PolicyResponse objects onto.
    response_queue: asyncio.Queue
@dataclass
class PolicyRequest:
    """Policy-sampling request travelling from a proxy to the environment."""

    # Unique id for this request; presumably matched against
    # PolicyResponse.request_id by the consumer — confirm with callers.
    request_id: str
    # Game whose queue pair carries this request.
    game_id: str
    # Power the sample is requested for (assumed to be a Diplomacy power name).
    power: str
    # Game phase the request belongs to (format not defined in this module).
    phase: str
    # Prompt text to sample a completion for.
    prompt: str
    # Sampling temperature requested by the proxy.
    temperature: float
    # Accompanying trajectory data; element type is not specified here.
    trajectory: list
@dataclass
class PolicyResponse:
    """Answer travelling from the environment back to the requesting proxy."""

    # Id of the request being answered (presumably a PolicyRequest.request_id).
    request_id: str
    # Generated response text.
    response: str
    # Free-form extra information; its schema is not defined in this module.
    metadata: dict
class QueueManager:
    """Manages request/response queues for all parallel games.

    Each game is registered under its ``game_id`` and owns exactly one
    ``QueuePair``. ``PolicyRequest`` objects go onto a game's request queue and
    ``PolicyResponse`` objects onto its response queue.
    """

    def __init__(self):
        # game_id -> QueuePair for every game that currently has queues.
        self.queue_pairs: Dict[str, QueuePair] = {}
        # Guards registration/removal. The guarded sections contain no
        # ``await``, so on a single event loop they are already atomic; the
        # lock is a safeguard rather than a strict necessity today.
        self._lock = asyncio.Lock()

    async def create_game_queues(self, game_id: str) -> QueuePair:
        """Create a queue pair for ``game_id``, or return the existing one.

        Registering the same game twice is not an error: a warning is logged
        and the already-registered pair is returned unchanged, so anything
        already holding those queues keeps working.
        """
        async with self._lock:
            if game_id in self.queue_pairs:
                logger.warning("Queue pair already exists for game %s", game_id)
                return self.queue_pairs[game_id]
            queue_pair = QueuePair(
                game_id=game_id,
                request_queue=asyncio.Queue(),
                response_queue=asyncio.Queue(),
            )
            self.queue_pairs[game_id] = queue_pair
            logger.info("Created queue pair for game %s", game_id)
            return queue_pair

    def get_queue_pair(self, game_id: str) -> Optional[QueuePair]:
        """Return the queue pair for ``game_id``, or ``None`` if unregistered."""
        return self.queue_pairs.get(game_id)

    def _require_queue_pair(self, game_id: str) -> QueuePair:
        """Return the queue pair for ``game_id`` or raise ``ValueError``.

        Shared by the put/get helpers so they agree on lookup and error text.
        """
        queue_pair = self.get_queue_pair(game_id)
        if queue_pair is None:
            raise ValueError(f"No queue pair found for game {game_id}")
        return queue_pair

    async def remove_game_queues(self, game_id: str):
        """Unregister the queues for a completed game.

        An unknown ``game_id`` is silently ignored. The manager only forgets
        the pair; nothing is put on its queues, so a coroutine already
        awaiting ``get_response`` for this game is not woken by the removal.
        """
        async with self._lock:
            if self.queue_pairs.pop(game_id, None) is not None:
                logger.info("Removed queue pair for game %s", game_id)

    def get_all_request_queues(self) -> Dict[str, asyncio.Queue]:
        """Return a new dict mapping game_id -> request queue, for polling.

        The dict is a snapshot: games registered or removed afterwards are not
        reflected in it (the queue objects themselves are shared, not copied).
        """
        return {gid: pair.request_queue for gid, pair in self.queue_pairs.items()}

    async def put_request(self, game_id: str, request: PolicyRequest):
        """Enqueue ``request`` on the request queue of ``game_id``.

        Raises:
            ValueError: if no queue pair is registered for ``game_id``.
        """
        await self._require_queue_pair(game_id).request_queue.put(request)

    async def get_response(self, game_id: str) -> PolicyResponse:
        """Wait for and return the next item on the response queue of ``game_id``.

        Raises:
            ValueError: if no queue pair is registered for ``game_id``.
        """
        return await self._require_queue_pair(game_id).response_queue.get()

    async def put_response(self, game_id: str, response: PolicyResponse):
        """Enqueue ``response`` on the response queue of ``game_id``.

        Raises:
            ValueError: if no queue pair is registered for ``game_id``.
        """
        await self._require_queue_pair(game_id).response_queue.put(response)
# Module-level singleton, built once at import time; every caller of
# get_queue_manager() receives this same instance.
_queue_manager = QueueManager()


def get_queue_manager() -> QueueManager:
    """Return the shared, module-level ``QueueManager`` instance.

    No new manager is created here; repeated calls return the identical object.
    """
    shared = _queue_manager
    return shared