mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
* minimal implementation, simplified challenge registry * need game save logic * fixed challenge gen, works with local test * updated challenge gen with wider ranges, working with local script * runs working correctly, wandb stats look ok * linting * Add diplomacy environment with AI_Diplomacy submodule - Add diplomacy_env_minimal.py for diplomacy game environment - Add atropos_client_minimal.py for client interface - Add diplomacy_local_server.py for local game server - Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy - Fix import ordering and remove unused imports * test file working, moving to cluster to test training * updated gitignore * removed logs * minor fixes, training running now * readded proxy reg and queue system * linting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * queue gameid bug, refactored * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleaned up configs & allowed for openrouter models to be easily used * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * linting * Remove duplicate dependencies from diplomacy requirements.txt Only keep AI_Diplomacy-specific dependencies that aren't already in the main project --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
268 lines
8.7 KiB
Python
268 lines
8.7 KiB
Python
"""
|
|
|
|
This is a queue-based proxy that:
- Intercepts LLM requests from AI_Diplomacy
- Puts them on a queue for the environment to process
- Waits for responses from the environment
- Returns responses to AI_Diplomacy
|
|
"""
|
|
|
|
import asyncio
|
|
import contextvars
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import uuid
|
|
from typing import Dict, List, Optional
|
|
|
|
from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy import (
|
|
clients,
|
|
)
|
|
from environments.game_environments.diplomacy_environment.queue_manager import (
|
|
PolicyRequest,
|
|
QueueManager,
|
|
get_queue_manager,
|
|
)
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))
|
|
|
|
from environments.game_environments.diplomacy_environment.AI_Diplomacy.ai_diplomacy.clients import ( # noqa: E402
|
|
BaseModelClient,
|
|
)
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Game id for the current execution context. AtroposClientMinimal.__init__
# reads it and raises ValueError when it is unset (default None).
current_game_context = contextvars.ContextVar("current_game_id", default=None)

# Per-game record of interaction dicts, keyed by game id. Appended to by
# AtroposClientMinimal.generate_response and read/cleared by the
# get_game_interactions / clear_game_interactions helpers.
_game_interactions: Dict[str, List[Dict]] = {}
|
|
|
|
|
|
class AtroposClientMinimal(BaseModelClient):
    """
    Queue-based proxy client that forwards LLM requests through queues.

    Every call to ``generate_response`` wraps the prompt in a
    ``PolicyRequest``, puts it on the queue manager under this client's game
    id, awaits the response, and records the exchange both on the instance
    (``self.interactions``) and in the module-level ``_game_interactions``
    registry.
    """

    # Power names recognised by ``_extract_power``.
    _POWERS = (
        "AUSTRIA",
        "ENGLAND",
        "FRANCE",
        "GERMANY",
        "ITALY",
        "RUSSIA",
        "TURKEY",
    )

    def __init__(
        self,
        model_name: str,
        queue_manager: Optional[QueueManager] = None,
    ):
        """
        Bind the client to the game id found in ``current_game_context``.

        Args:
            model_name: Model identifier forwarded to ``BaseModelClient``.
            queue_manager: Queue manager to use; falls back to
                ``get_queue_manager()`` when omitted (or falsy).

        Raises:
            ValueError: If no game id is set in ``current_game_context``.
        """
        super().__init__(model_name)
        self.game_id = current_game_context.get()
        if not self.game_id:
            raise ValueError("AtroposClientMinimal created without game context set")

        self.queue_manager = queue_manager or get_queue_manager()

        self.interactions: List[Dict] = []
        # Last power / phase seen in a prompt; reused when a later prompt
        # does not mention one.
        self.current_power: Optional[str] = None
        self.current_phase: Optional[str] = None

        logger.info(
            f"Initialized AtroposClientMinimal for {model_name} in game {self.game_id}"
        )

    async def generate_response(self, prompt: str, temperature: float = 0.0) -> str:
        """
        Put request on queue and wait for response from environment.

        This is the main method AI_Diplomacy calls for all LLM interactions.

        Args:
            prompt: Prompt text to forward to the environment.
            temperature: Sampling temperature forwarded on the request.

        Returns:
            The ``response`` text of the object returned by the queue manager,
            or a fallback response (see ``_generate_fallback_response``) if a
            timeout or any other exception occurs.
        """
        task_type = self._infer_task_type(prompt)
        power = self._extract_power(prompt)
        phase = self._extract_phase(prompt)

        # Only overwrite the remembered power/phase when the prompt names one.
        if power:
            self.current_power = power
        if phase:
            self.current_phase = phase

        logger.debug(f"Generating response for {self.current_power}: {task_type}")

        try:
            request_id = str(uuid.uuid4())
            request = PolicyRequest(
                request_id=request_id,
                game_id=self.game_id,
                power=self.current_power or "UNKNOWN",
                phase=self.current_phase or "UNKNOWN",
                prompt=prompt,
                temperature=temperature,
                # Snapshot so later appends do not alter the queued request.
                trajectory=self.interactions.copy(),
            )

            await self.queue_manager.put_request(self.game_id, request)
            logger.debug(f"Put request {request_id} on queue for game {self.game_id}")

            response = await self.queue_manager.get_response(self.game_id)

            # A mismatched id is only reported; the response is still used.
            if response.request_id != request_id:
                logger.warning(
                    f"Response ID mismatch: expected {request_id}, got {response.request_id}"
                )

            response_text = response.response

            # Track interaction
            interaction = {
                "power": self.current_power,
                "phase": self.current_phase,
                "task_type": task_type,
                "prompt": prompt,
                "response": response_text,
                "metadata": response.metadata,  # Store any additional info from environment
            }
            self.interactions.append(interaction)
            _game_interactions.setdefault(self.game_id, []).append(interaction)

            return response_text

        except asyncio.TimeoutError:
            logger.error("Timeout waiting for response from environment")
            return self._generate_fallback_response(prompt)
        except Exception as e:
            # Best-effort boundary: keep the game running, but keep the
            # traceback in the log so the failure can be diagnosed.
            logger.exception(f"Error generating response: {e}")
            return self._generate_fallback_response(prompt)

    def _infer_task_type(self, prompt: str) -> str:
        """
        Infer the type of task from the prompt.

        Keyword checks are case-insensitive and applied in priority order:
        orders, then negotiation, then planning; anything else is "general".
        """
        prompt_lower = prompt.lower()

        if "orders" in prompt_lower or "submit" in prompt_lower:
            return "orders"
        if "message" in prompt_lower or "negotiate" in prompt_lower:
            return "negotiation"
        if "plan" in prompt_lower or "strategy" in prompt_lower:
            return "planning"
        return "general"

    def _extract_power(self, prompt: str) -> Optional[str]:
        """
        Extract power name from prompt if mentioned.

        When several powers appear, the one mentioned earliest in the prompt
        is returned (matching is case-insensitive). Returns ``None`` when no
        power is named.
        """
        prompt_upper = prompt.upper()

        earliest_pos: Optional[int] = None
        earliest_power: Optional[str] = None
        for power in self._POWERS:
            pos = prompt_upper.find(power)
            if pos == -1:
                continue
            if earliest_pos is None or pos < earliest_pos:
                earliest_pos, earliest_power = pos, power
        return earliest_power

    def _extract_phase(self, prompt: str) -> Optional[str]:
        """
        Extract game phase from prompt if mentioned.

        A short phase code (e.g. ``S1901M``, ``F1901R``, ``W1901A``) takes
        precedence over the verbose form (e.g. ``Spring 1901``). Returns
        ``None`` when neither form is present.
        """
        import re

        # Season letter (Spring/Fall/Winter), four-digit year, phase-type letter.
        phase_match = re.search(r"[SFW]\d{4}[MRAB]", prompt)
        if phase_match:
            return phase_match.group()

        verbose_match = re.search(r"(Spring|Fall|Winter) \d{4}", prompt)
        if verbose_match:
            return verbose_match.group()

        return None

    def _generate_fallback_response(self, prompt: str) -> str:
        """
        Generate a simple fallback response if there's an issue.

        Orders and negotiation prompts get an empty JSON payload; every other
        task type gets a plain-text message.
        """
        task_type = self._infer_task_type(prompt)

        if task_type == "orders":
            return json.dumps(
                {
                    "orders": {},
                    "explanations": {"general": "Fallback - no server connected"},
                }
            )
        if task_type == "negotiation":
            return json.dumps(
                {
                    "messages": [],
                    "explanations": {"general": "Fallback - no server connected"},
                }
            )
        return "Fallback response - server not available"

    def get_interactions(self) -> List[Dict]:
        """Get all tracked interactions for trajectory collection."""
        return self.interactions

    def clear_interactions(self):
        """Clear tracked interactions for a new game."""
        self.interactions = []
        self.current_power = None
        self.current_phase = None
|
|
|
|
|
|
def get_game_interactions(game_id: str) -> List[Dict]:
    """Return the interactions recorded for ``game_id``, or an empty list if none exist."""
    try:
        return _game_interactions[game_id]
    except KeyError:
        # Unknown game: hand back a fresh list without registering it.
        return []
|
|
|
|
|
|
def clear_game_interactions(game_id: str):
    """Drop every interaction recorded for ``game_id``; unknown ids are ignored."""
    _game_interactions.pop(game_id, None)
|
|
|
|
|
|
def register_atropos_models_globally(queue_manager: Optional[QueueManager] = None):
    """
    Patch AI_Diplomacy's ``clients.load_model_client`` to serve Atropos clients.

    Intended to run ONCE during environment setup. Later calls detect the
    ``_atropos_registered`` marker on the ``clients`` module and return
    without changing anything.

    Args:
        queue_manager: Queue manager handed to every created client
            (the global one is used if not provided).
    """
    already_patched = hasattr(clients, "_atropos_registered")
    if already_patched:
        logger.info("AtroposClientMinimal already registered globally")
        return

    # Stash the stock loader and the queue manager on the module so the
    # replacement loader can look them up at call time.
    clients._original_load_model_client = clients.load_model_client
    clients._atropos_queue_manager = queue_manager or get_queue_manager()

    def load_model_client_with_atropos(
        model_id: str, prompts_dir: Optional[str] = None
    ) -> BaseModelClient:
        # Anything without the "atropos-" prefix is delegated to the stock loader.
        if not model_id.startswith("atropos-"):
            logger.info(f"Falling back to original loader for {model_id}")
            return clients._original_load_model_client(model_id, prompts_dir)

        logger.info(f"Creating context-aware AtroposClientMinimal for {model_id}")
        return AtroposClientMinimal(model_id, clients._atropos_queue_manager)

    clients.load_model_client = load_model_client_with_atropos
    clients._atropos_registered = True

    logger.info("Registered AtroposClientMinimal globally with AI_Diplomacy")
|
|
|
|
|
|
if __name__ == "__main__":

    async def test_client():
        """Manual smoke test: send a few sample prompts through the client."""
        # The constructor reads the game id from ``current_game_context`` and
        # raises ValueError when it is unset, so set one for this run.
        token = current_game_context.set("test-game")
        try:
            # No explicit queue manager: the constructor falls back to
            # get_queue_manager().
            client = AtroposClientMinimal("atropos-test")

            test_prompts = [
                "You are FRANCE. What are your orders for Spring 1901?",
                "Send a message to ENGLAND about cooperation.",
                "What is your strategic plan?",
            ]

            for prompt in test_prompts:
                print(f"\nPrompt: {prompt[:50]}...")
                try:
                    # NOTE(review): presumably nothing consumes the queue in a
                    # standalone run, so bound the wait - confirm against
                    # QueueManager.get_response.
                    response = await asyncio.wait_for(
                        client.generate_response(prompt), timeout=5.0
                    )
                except asyncio.TimeoutError:
                    response = client._generate_fallback_response(prompt)
                print(f"Response: {response[:100]}...")

            print(f"\nTracked {len(client.get_interactions())} interactions")

            # ``close`` is not defined on AtroposClientMinimal itself; only
            # await it if the base class provides one.
            close = getattr(client, "close", None)
            if close is not None:
                await close()
        finally:
            current_game_context.reset(token)

    asyncio.run(test_client())
|