atropos/environments/game_environments/diplomacy_environment/atropos_client_minimal.py
shannonsands 46f0602227
Diplomacy trainer env (#227)
* minimal implementation, simplified challenge registry

* need game save logic

* fixed challenge gen, works with local test

* updated challenge gen with wider ranges, working with local script

* runs working correctly, wandb stats look ok

* linting

* Add diplomacy environment with AI_Diplomacy submodule

- Add diplomacy_env_minimal.py for diplomacy game environment
- Add atropos_client_minimal.py for client interface
- Add diplomacy_local_server.py for local game server
- Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy
- Fix import ordering and remove unused imports

* test file working, moving to cluster to test training

* updated gitignore

* removed logs

* minor fixes, training running now

* readded proxy reg and queue system

* linting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* queue gameid bug, refactored

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleaned up configs & allowed for openrouter models to be easily used

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* linting

* Remove duplicate dependencies from diplomacy requirements.txt

Only keep AI_Diplomacy-specific dependencies that aren't already in the main project

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-08-12 09:02:16 +10:00

268 lines
8.7 KiB
Python

"""
This is a queue-based proxy that:
- Intercepts LLM requests from AI_Diplomacy
- Puts them on a queue for the environment to process
- Waits for responses from the environment
- Returns responses to AI_Diplomacy
"""
import asyncio
import contextvars
import json
import logging
import os
import re
import sys
import uuid
from typing import Dict, List, Optional

from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy import (
    clients,
)
from environments.game_environments.diplomacy_environment.queue_manager import (
    PolicyRequest,
    QueueManager,
    get_queue_manager,
)

# Make the vendored AI_Diplomacy checkout importable before pulling from it.
sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))
from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy.clients import (  # noqa: E402
    BaseModelClient,
)
logger = logging.getLogger(__name__)

# Game id for the game currently being constructed; the environment sets this
# before instantiating clients so each AtroposClientMinimal binds to a game.
current_game_context = contextvars.ContextVar("current_game_id", default=None)

# Interactions recorded per game id, shared across all client instances of a
# game so the environment can collect a full trajectory in one place.
_game_interactions: Dict[str, List[Dict]] = {}
class AtroposClientMinimal(BaseModelClient):
    """
    Queue-based proxy client that forwards LLM requests through queues.

    Each instance binds to the game id held in ``current_game_context`` at
    construction time. AI_Diplomacy calls :meth:`generate_response`; the
    prompt is placed on that game's request queue and the call blocks until
    the training environment pushes a response back. Every exchange is
    recorded both on the instance and in the module-level per-game store.
    """

    # The seven Diplomacy great powers, scanned for in prompts.
    _POWERS = (
        "AUSTRIA",
        "ENGLAND",
        "FRANCE",
        "GERMANY",
        "ITALY",
        "RUSSIA",
        "TURKEY",
    )
    # Compiled once instead of importing/compiling inside _extract_phase:
    # short phase codes like "S1901M" and verbose forms like "Spring 1901".
    _PHASE_RE = re.compile(r"[SF]\d{4}[MRB]")
    _VERBOSE_PHASE_RE = re.compile(r"(Spring|Fall) \d{4}")

    def __init__(
        self,
        model_name: str,
        queue_manager: Optional[QueueManager] = None,
    ):
        """
        Args:
            model_name: Model identifier passed through to BaseModelClient.
            queue_manager: Queue manager to route requests through; falls back
                to the global one when omitted.

        Raises:
            ValueError: If ``current_game_context`` has not been set by the
                environment before construction.
        """
        super().__init__(model_name)
        self.game_id = current_game_context.get()
        if not self.game_id:
            raise ValueError("AtroposClientMinimal created without game context set")
        self.queue_manager = queue_manager or get_queue_manager()
        # Per-instance interaction log, copied into each outgoing request.
        self.interactions: List[Dict] = []
        # Last power/phase seen in a prompt; reused when a prompt omits them.
        self.current_power: Optional[str] = None
        self.current_phase: Optional[str] = None
        logger.info(
            "Initialized AtroposClientMinimal for %s in game %s",
            model_name,
            self.game_id,
        )

    async def generate_response(self, prompt: str, temperature: float = 0.0) -> str:
        """
        Put request on queue and wait for response from environment.

        This is the main method AI_Diplomacy calls for all LLM interactions.
        On queue timeout or any other failure a deterministic fallback
        response is returned instead of raising, so the game can proceed.

        Args:
            prompt: Full prompt text from AI_Diplomacy.
            temperature: Sampling temperature forwarded to the environment.

        Returns:
            The environment's response text, or a fallback on error.
        """
        task_type = self._infer_task_type(prompt)
        power = self._extract_power(prompt)
        phase = self._extract_phase(prompt)
        if power:
            self.current_power = power
        if phase:
            self.current_phase = phase

        logger.debug("Generating response for %s: %s", self.current_power, task_type)
        try:
            request_id = str(uuid.uuid4())
            request = PolicyRequest(
                request_id=request_id,
                game_id=self.game_id,
                power=self.current_power or "UNKNOWN",
                phase=self.current_phase or "UNKNOWN",
                prompt=prompt,
                temperature=temperature,
                # Shallow copy so later appends don't mutate the sent snapshot.
                trajectory=self.interactions.copy(),
            )
            await self.queue_manager.put_request(self.game_id, request)
            logger.debug(
                "Put request %s on queue for game %s", request_id, self.game_id
            )

            response = await self.queue_manager.get_response(self.game_id)
            # Responses are expected in order; a mismatch is logged but the
            # response is still used (assumes one in-flight request per game
            # — TODO confirm against queue_manager semantics).
            if response.request_id != request_id:
                logger.warning(
                    "Response ID mismatch: expected %s, got %s",
                    request_id,
                    response.request_id,
                )
            response_text = response.response

            # Track interaction
            interaction = {
                "power": self.current_power,
                "phase": self.current_phase,
                "task_type": task_type,
                "prompt": prompt,
                "response": response_text,
                "metadata": response.metadata,  # Store any additional info from environment
            }
            self.interactions.append(interaction)
            _game_interactions.setdefault(self.game_id, []).append(interaction)
            return response_text
        except asyncio.TimeoutError:
            logger.error("Timeout waiting for response from environment")
            return self._generate_fallback_response(prompt)
        except Exception:
            # Broad on purpose: any failure degrades to a fallback so the
            # game loop keeps running; the traceback is preserved in the log.
            logger.exception("Error generating response")
            return self._generate_fallback_response(prompt)

    def _infer_task_type(self, prompt: str) -> str:
        """Infer the type of task from the prompt (keyword heuristic)."""
        prompt_lower = prompt.lower()
        if "orders" in prompt_lower or "submit" in prompt_lower:
            return "orders"
        elif "message" in prompt_lower or "negotiate" in prompt_lower:
            return "negotiation"
        elif "plan" in prompt_lower or "strategy" in prompt_lower:
            return "planning"
        else:
            return "general"

    def _extract_power(self, prompt: str) -> Optional[str]:
        """Extract power name from prompt if mentioned (first match wins)."""
        upper = prompt.upper()  # hoisted: one upper-casing, not one per power
        for power in self._POWERS:
            if power in upper:
                return power
        return None

    def _extract_phase(self, prompt: str) -> Optional[str]:
        """Extract game phase from prompt if mentioned."""
        phase_match = self._PHASE_RE.search(prompt)
        if phase_match:
            return phase_match.group()
        verbose_match = self._VERBOSE_PHASE_RE.search(prompt)
        if verbose_match:
            return verbose_match.group()
        return None

    def _generate_fallback_response(self, prompt: str) -> str:
        """Generate a simple fallback response if there's an issue."""
        task_type = self._infer_task_type(prompt)
        if task_type == "orders":
            return json.dumps(
                {
                    "orders": {},
                    "explanations": {"general": "Fallback - no server connected"},
                }
            )
        elif task_type == "negotiation":
            return json.dumps(
                {
                    "messages": [],
                    "explanations": {"general": "Fallback - no server connected"},
                }
            )
        else:
            return "Fallback response - server not available"

    def get_interactions(self) -> List[Dict]:
        """Get all tracked interactions for trajectory collection."""
        return self.interactions

    def clear_interactions(self):
        """Clear tracked interactions for a new game."""
        self.interactions = []
        self.current_power = None
        self.current_phase = None
def get_game_interactions(game_id: str) -> List[Dict]:
    """Return every interaction recorded for *game_id* (empty list if none)."""
    recorded = _game_interactions.get(game_id)
    return recorded if recorded is not None else []
def clear_game_interactions(game_id: str):
    """Drop all recorded interactions for *game_id*; a no-op if none exist."""
    _game_interactions.pop(game_id, None)
def register_atropos_models_globally(queue_manager: Optional[QueueManager] = None):
    """
    Register AtroposClientMinimal with AI_Diplomacy's model loading system globally.

    Monkey-patches ``clients.load_model_client`` so that model ids prefixed
    with ``atropos-`` are served by the queue-backed proxy, while every other
    id falls through to the original loader. A marker attribute on the
    ``clients`` module makes repeated calls a no-op.

    This should be called ONCE during environment setup.

    Args:
        queue_manager: Optional queue manager (uses global if not provided)
    """
    if hasattr(clients, "_atropos_registered"):
        logger.info("AtroposClientMinimal already registered globally")
        return

    # Stash the original loader and the queue manager on the clients module
    # itself, so the wrapper reads live state rather than closing over it.
    clients._original_load_model_client = clients.load_model_client
    clients._atropos_queue_manager = queue_manager or get_queue_manager()

    def _patched_loader(
        model_id: str, prompts_dir: Optional[str] = None
    ) -> BaseModelClient:
        if not model_id.startswith("atropos-"):
            logger.info(f"Falling back to original loader for {model_id}")
            return clients._original_load_model_client(model_id, prompts_dir)
        logger.info(f"Creating context-aware AtroposClientMinimal for {model_id}")
        return AtroposClientMinimal(model_id, clients._atropos_queue_manager)

    clients.load_model_client = _patched_loader
    clients._atropos_registered = True
    logger.info("Registered AtroposClientMinimal globally with AI_Diplomacy")
if __name__ == "__main__":

    async def test_client():
        """Smoke-test the proxy client against a locally running queue manager."""
        # The client resolves its game id from the context variable during
        # __init__, so it MUST be set first — the old code skipped this and
        # the constructor raised ValueError immediately.
        current_game_context.set("local-test-game")
        # Bug fix: the second positional argument is an Optional[QueueManager];
        # the old code passed a plain settings dict, which is not a queue
        # manager and would break put_request. Use the global manager instead.
        client = AtroposClientMinimal("atropos-test")
        test_prompts = [
            "You are FRANCE. What are your orders for Spring 1901?",
            "Send a message to ENGLAND about cooperation.",
            "What is your strategic plan?",
        ]
        for prompt in test_prompts:
            print(f"\nPrompt: {prompt[:50]}...")
            # NOTE(review): with no environment servicing the queue, this may
            # block until the queue manager times out, after which the client
            # returns its fallback response.
            response = await client.generate_response(prompt)
            print(f"Response: {response[:100]}...")
        print(f"\nTracked {len(client.get_interactions())} interactions")
        # close() is presumably provided by BaseModelClient — guard so the
        # smoke test still completes if it is not.
        if hasattr(client, "close"):
            await client.close()

    asyncio.run(test_client())