atropos/environments/game_environments/diplomacy_environment/diplomacy_env_minimal.py
shannonsands 46f0602227
Diplomacy trainer env (#227)
* minimal implementation, simplified challenge registry

* need game save logic

* fixed challenge gen, works with local test

* updated challenge gen with wider ranges, working with local script

* runs working correctly, wandb stats look ok

* linting

* Add diplomacy environment with AI_Diplomacy submodule

- Add diplomacy_env_minimal.py for diplomacy game environment
- Add atropos_client_minimal.py for client interface
- Add diplomacy_local_server.py for local game server
- Add AI_Diplomacy submodule from GoodStartLabs/AI_Diplomacy
- Fix import ordering and remove unused imports

* test file working, moving to cluster to test training

* updated gitignore

* removed logs

* minor fixes, training running now

* re-added proxy registration and queue system

* linting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* queue gameid bug, refactored

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleaned up configs & allowed for openrouter models to be easily used

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* linting

* Remove duplicate dependencies from diplomacy requirements.txt

Only keep AI_Diplomacy-specific dependencies that aren't already in the main project

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-08-12 09:02:16 +10:00

636 lines
22 KiB
Python

"""
Minimal Diplomacy Environment for Atropos
A simplified implementation focusing on:
- Basic game integration with AI_Diplomacy
- Parallel rollouts with group_size
- LLM proxy interception via AtroposClient
- Simple supply center based scoring
"""
import asyncio
import json
import logging
import os
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
ScoredDataItem,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from environments.game_environments.diplomacy_environment.atropos_client_minimal import (
clear_game_interactions,
current_game_context,
get_game_interactions,
register_atropos_models_globally,
)
from environments.game_environments.diplomacy_environment.queue_manager import (
PolicyRequest,
PolicyResponse,
get_queue_manager,
)
sys.path.append(os.path.join(os.path.dirname(__file__), "AI_Diplomacy"))
import lm_game # noqa: E402
logger = logging.getLogger(__name__)
# The seven Great Powers of standard Diplomacy, in alphabetical order.
POWERS = ["AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"]
# Supply centers each power controls at game start; used by _calculate_score
# as the baseline when scoring center gains/losses (Russia starts with 4).
STARTING_SUPPLY_CENTERS = {
    "AUSTRIA": 3,
    "ENGLAND": 3,
    "FRANCE": 3,
    "GERMANY": 3,
    "ITALY": 3,
    "RUSSIA": 4,
    "TURKEY": 3,
}
class DiplomacyEnvMinimalConfig(BaseEnvConfig):
    """Configuration for the minimal Diplomacy environment.

    Extends BaseEnvConfig with game, scoring, process-management,
    logging and evaluation options for the AI_Diplomacy integration.
    """

    env_name: str = "diplomacy_minimal"
    # Game settings (minimal implementation, no randomisation)
    max_game_turns: int = 10  # mapped to a max in-game year (1900 + turns)
    training_power: str = "FRANCE"  # the power controlled by the trained policy
    total_steps: int = 10
    # Scoring: total = center delta * weight + bonus * turns + win bonus
    supply_center_weight: float = 1.0  # points per supply center gained/lost
    survival_bonus: float = 0.1  # points per turn played
    win_bonus: float = 5.0  # flat bonus when the training power wins
    # Process management
    diplomacy_server_port: int = 8432
    start_diplomacy_server: bool = True  # spawn a local game server in setup()
    # Logging
    save_game_logs: bool = True
    game_logs_dir: str = "./game_logs"  # also used as lm_game output root
    # Evaluation
    eval_episodes: int = 10
    # Opponent models (None = use SGLang endpoints from server_configs)
    opponent_models: Optional[List[str]] = None
class DiplomacyEnvMinimal(BaseEnv):
    """Minimal Atropos environment that trains one Diplomacy power.

    Runs AI_Diplomacy games in parallel (one per group member), routes the
    training power's LLM calls through a queue-based proxy, and scores each
    trajectory from supply-center gains, survival turns and wins.
    """

    # Registry name; also used as the wandb metric prefix.
    name = "diplomacy_minimal"
    env_config_cls = DiplomacyEnvMinimalConfig
def __init__(
    self,
    config: DiplomacyEnvMinimalConfig,
    server_configs: List[APIServerConfig],
    slurm: bool = False,
    testing: bool = False,
):
    """Initialize the environment.

    Args:
        config: Environment configuration.
        server_configs: Inference endpoints; when ``config.opponent_models``
            is unset these are also converted into opponent model specs.
        slurm: Passed through to BaseEnv.
        testing: Passed through to BaseEnv.
    """
    super().__init__(config, server_configs, slurm, testing)
    self.config: DiplomacyEnvMinimalConfig = config
    self.game_server_process: Optional[subprocess.Popen] = None
    self.game_outcomes_buffer: List[Dict] = []
    self.eval_metrics_custom: List[Tuple[str, float]] = []
    self.queue_manager = get_queue_manager()
    self.active_games: Dict[str, Dict] = {}
    if config.save_game_logs:
        # parents=True so a nested game_logs_dir (e.g. "./logs/diplomacy")
        # works too; bare mkdir raises FileNotFoundError on missing parents.
        Path(config.game_logs_dir).mkdir(parents=True, exist_ok=True)
    self.system_prompt = (
        f"You are playing Diplomacy as {config.training_power}. "
        "Analyze the game state and respond with your strategy and orders."
    )
    if config.opponent_models:
        self.opponent_models = config.opponent_models
    else:
        # Build "openai:<model>@<url>#<key>" specs for AI_Diplomacy opponents.
        self.opponent_models = []
        for server_config in server_configs:
            model_spec = f"openai:{server_config.model_name}@{server_config.base_url}#{server_config.api_key}"
            self.opponent_models.append(model_spec)
@classmethod
def config_init(cls) -> Tuple[DiplomacyEnvMinimalConfig, List[APIServerConfig]]:
    """Return the default env config and four identical local endpoints."""
    model = "NousResearch/Hermes-4-Qwen3-14B-1-e3"
    env_config = DiplomacyEnvMinimalConfig(
        tokenizer_name=model,
        group_size=4,
        use_wandb=True,
        rollout_server_url="http://localhost:8000",
        max_token_length=4096,
        wandb_name=cls.name,
        steps_per_eval=20,
    )
    # One endpoint per local inference worker on ports 9004-9007; all
    # entries are identical apart from the port.
    server_configs = [
        APIServerConfig(
            model_name=model,
            base_url=f"http://localhost:{port}/v1",
            api_key="x",
            num_requests_for_eval=128,
        )
        for port in (9004, 9005, 9006, 9007)
    ]
    return env_config, server_configs
async def setup(self):
    """Set up the environment.

    Order matters: proxy model registration happens before any game can
    start, the (optional) local game server is started next, and only then
    is the background request-polling task launched.
    """
    logger.info(f"Setting up {self.name} environment")
    # Presumably registers the proxy model names (e.g. the
    # "atropos-training-policy" spec used in _run_diplomacy_game) with the
    # queue manager so AI_Diplomacy can resolve them — TODO confirm in
    # atropos_client_minimal.
    register_atropos_models_globally(self.queue_manager)
    if self.config.start_diplomacy_server:
        await self._start_diplomacy_server()
    # Fire-and-forget loop servicing policy requests for all active games.
    asyncio.create_task(self._poll_request_queues())
async def _poll_request_queues(self):
    """Forever drain pending policy requests from every active game's queue.

    Runs as a background task started in setup(); each drained request is
    handled in its own task so a slow completion never stalls polling of
    the other games.
    """
    while True:
        try:
            for gid in list(self.active_games.keys()):
                pair = self.queue_manager.get_queue_pair(gid)
                if not pair:
                    continue
                try:
                    pending = pair.request_queue.get_nowait()
                except asyncio.QueueEmpty:
                    continue
                asyncio.create_task(self._handle_policy_request(pending))
            await asyncio.sleep(0.01)
        except Exception as e:
            # Keep the poller alive on unexpected errors; back off briefly.
            logger.error(f"Error in queue polling: {e}")
            await asyncio.sleep(1)
async def _handle_policy_request(self, request: PolicyRequest):
    """Answer one policy request by sampling a completion from the server.

    The chat history is rebuilt from the request's stored trajectory
    (prompt/response pairs) on top of the system prompt, a single
    completion is sampled, and the result is pushed onto the game's
    response queue.  On any failure an error response is sent instead so
    the waiting game never blocks forever.
    """
    try:
        logger.info(
            f"Handling request {request.request_id} for {request.power} in game {request.game_id}"
        )
        # System prompt + prior turns + the new prompt.
        messages = [{"role": "system", "content": self.system_prompt}]
        for turn in request.trajectory:
            messages += [
                {"role": "user", "content": turn["prompt"]},
                {"role": "assistant", "content": turn["response"]},
            ]
        messages.append({"role": "user", "content": request.prompt})
        async with self.server.dedicated_server() as server:
            completion = await server.chat_completion(
                messages=messages,
                n=1,
                temperature=request.temperature,
                max_tokens=2000,
            )
        reply = completion.choices[0].message.content.strip()
        await self.queue_manager.put_response(
            request.game_id,
            PolicyResponse(
                request_id=request.request_id,
                response=reply,
                metadata={
                    "power": request.power,
                    "phase": request.phase,
                },
            ),
        )
        logger.debug(f"Sent response for request {request.request_id}")
    except Exception as e:
        logger.error(f"Error handling policy request: {e}")
        await self.queue_manager.put_response(
            request.game_id,
            PolicyResponse(
                request_id=request.request_id,
                response="Error: Failed to generate response",
                metadata={"error": str(e)},
            ),
        )
async def _start_diplomacy_server(self):
    """Start the AI_Diplomacy game server as a child process.

    Launches ``python -m diplomacy.server.run`` from the bundled
    AI_Diplomacy checkout and sleeps a few seconds for it to come up.
    Failures are logged rather than raised so the environment can still be
    used against an externally managed server.
    """
    try:
        logger.info(
            f"Starting Diplomacy server on port {self.config.diplomacy_server_port}"
        )
        self.game_server_process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "diplomacy.server.run",
                "--port",
                str(self.config.diplomacy_server_port),
            ],
            cwd=os.path.join(os.path.dirname(__file__), "AI_Diplomacy"),
            # Discard output instead of PIPE: nothing ever reads these
            # pipes, so a chatty server would eventually fill the OS pipe
            # buffer and block the child process.
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        # Give server time to start
        await asyncio.sleep(3)
        logger.info(
            f"Diplomacy server started at http://localhost:{self.config.diplomacy_server_port}"
        )
    except Exception as e:
        logger.error(f"Failed to start Diplomacy server: {e}")
async def collect_trajectories(
    self, item: Item
) -> Tuple[ScoredDataGroup, List[Item]]:
    """
    Run parallel Diplomacy games and collect all trajectories.

    This implements the key RL training pattern:
    1. Run group_size parallel games with the same seed
    2. Each game explores different action sequences
    3. Score each trajectory based on game outcome
    4. Return all trajectories as a ScoredDataGroup for training

    Per-game queues and bookkeeping are released unconditionally once the
    games finish (previously the "no valid trajectories" early return
    skipped cleanup, leaking queue pairs and active_games entries).
    """
    logger.warning(
        f"[DiplomacyEnvMinimal] collect_trajectories called with item: {item}"
    )
    base_game_id = item.get("game_id", f"game-{int(time.time())}")
    seed = item.get("seed", random.randint(0, 1_000_000))
    logger.info(
        f"Starting {self.config.group_size} parallel games with seed {seed}"
    )
    # Run parallel games w/ same seed; each game perturbs it by its index.
    game_tasks = []
    for i in range(self.config.group_size):
        game_id = f"{base_game_id}-{i}"
        task = self._run_single_game(game_id, seed, trajectory_id=i)
        game_tasks.append(task)
    results = await asyncio.gather(*game_tasks, return_exceptions=True)
    scored_items = []
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            logger.error(f"Game {i} failed: {result}")
            continue
        if result and result[0]:
            scored_items.append(result[0])
    logger.warning(
        f"[DiplomacyEnvMinimal] Collected {len(scored_items)} scored items"
    )
    # All games have finished: release their queues/bookkeeping now, before
    # any return path (bug fix: cleanup used to run only on success).
    for i in range(self.config.group_size):
        game_id = f"{base_game_id}-{i}"
        self.active_games.pop(game_id, None)
        try:
            await self.queue_manager.remove_game_queues(game_id)
        except Exception as e:
            logger.debug(f"Error cleaning up queues for {game_id}: {e}")
    if not scored_items:
        logger.error("No valid trajectories collected")
        return (
            ScoredDataGroup(
                tokens=[],
                masks=[],
                scores=[],
                messages=[],
                advantages=None,
                ref_logprobs=None,
                group_overrides={},
                overrides=None,
                images=None,
            ),
            [],
        )
    sdg = ScoredDataGroup(
        tokens=[],
        masks=[],
        scores=[],
        messages=[],
        advantages=None,
        ref_logprobs=None,
        group_overrides={},
        overrides=None,
        images=None,
    )
    for scored_item in scored_items:
        sdg["tokens"].append(scored_item["tokens"])
        sdg["masks"].append(scored_item["masks"])
        sdg["scores"].append(scored_item["scores"])
        if self.config.include_messages and scored_item.get("messages"):
            sdg["messages"].append(scored_item["messages"])
    logger.info(f"Collected {len(scored_items)} trajectories")
    logger.warning(
        f"[DiplomacyEnvMinimal] Returning ScoredDataGroup with {len(sdg['tokens'])} "
        f"tokens, {len(sdg['scores'])} scores"
    )
    logger.warning(
        f"[DiplomacyEnvMinimal] First few scores: {sdg['scores'][:5] if sdg['scores'] else 'None'}"
    )
    return sdg, []
async def _run_single_game(
    self, game_id: str, seed: int, trajectory_id: int
) -> Tuple[Optional[ScoredDataItem], None]:
    """
    Run a single Diplomacy game and return scored trajectory.

    Registers per-game queues, runs the game inside the game-id context
    var (presumably so the proxy client attributes each LLM interaction to
    this game — confirm in atropos_client_minimal), then scores the result
    and tokenizes the training power's dialogue.  Returns
    ``(scored_item, None)`` on success and ``(None, None)`` on any failure.
    """
    try:
        # Queues must exist before the game starts issuing policy requests.
        queue_pair = await self.queue_manager.create_game_queues(game_id)
        self.active_games[game_id] = {
            "queue_pair": queue_pair,
            "start_time": time.time(),
            "interactions": [],
        }
        token = current_game_context.set(game_id)
        try:
            game_result = await self._run_diplomacy_game(
                game_id, seed, trajectory_id
            )
        finally:
            # Always restore the context var, even if the game raised.
            current_game_context.reset(token)
        if not game_result:
            logger.error(f"Game {game_id} failed to complete")
            return None, None
        score = self._calculate_score(game_result, self.config.training_power)
        interactions = get_game_interactions(game_id)
        # Keep only the turns played by the power being trained.
        training_interactions = [
            i for i in interactions if i.get("power") == self.config.training_power
        ]
        if training_interactions:
            messages = [{"role": "system", "content": self.system_prompt}]
            for interaction in training_interactions:
                messages.append({"role": "user", "content": interaction["prompt"]})
                messages.append(
                    {"role": "assistant", "content": interaction["response"]}
                )
            logger.info(
                f"Collected {len(training_interactions)} interactions for {self.config.training_power}"
            )
        else:
            # Fallback placeholder conversation so tokenization still works
            # when no interactions were recorded for the training power.
            logger.warning(
                f"No interactions found for {self.config.training_power} in game {game_id}"
            )
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": f"Playing Diplomacy game {game_id}"},
                {
                    "role": "assistant",
                    "content": f"Game completed with score {score:.2f}",
                },
            ]
        clear_game_interactions(game_id)
        tokenization_result = tokenize_for_trainer(
            tokenizer=self.tokenizer,
            chat=messages,
            train_on_all_assistant_turns=True,
        )
        scored_data_item = ScoredDataItem(
            messages=messages if self.config.include_messages else None,
            tokens=tokenization_result["tokens"],
            masks=tokenization_result["masks"],
            scores=score,
        )
        # Buffered outcome feeds wandb_log() and evaluate()'s win tracking.
        self.game_outcomes_buffer.append(
            {
                "game_id": game_id,
                "score": score,
                "winner": game_result.get("winner"),
                "turns": game_result.get("turns_played", 0),
                "final_centers": game_result.get("final_centers", {}),
            }
        )
        return scored_data_item, None
    except Exception as e:
        logger.error(f"Error in game {game_id}: {e}", exc_info=True)
        return None, None
async def _run_diplomacy_game(
    self, game_id: str, seed: int, trajectory_id: int = 0
) -> Optional[Dict]:
    """Run a Diplomacy game using AI_Diplomacy's lm_game module.

    Assigns the training power to the "atropos-training-policy" proxy and
    round-robins the remaining powers over ``self.opponent_models``, then
    invokes ``lm_game.main()`` by patching ``sys.argv`` to emulate a CLI
    call.  On success returns a dict with ``winner``, ``turns_played`` and
    per-power ``final_centers``; returns None if no output file appears.

    NOTE(review): ``sys.argv`` is process-global and games run concurrently
    via asyncio.gather; if lm_game.main() re-reads argv after awaiting,
    parallel games could see each other's arguments — confirm lm_game
    parses argv synchronously on entry.

    NOTE(review): ``--output`` is given a path ending in ``.json`` but the
    result is read from ``<that path>/lmvsgame.json`` below — presumably
    lm_game treats the output argument as a directory; verify.
    """
    game_output_path = os.path.join(self.config.game_logs_dir, f"{game_id}.json")
    # One model spec per power, in POWERS order.
    models = []
    opponent_idx = 0
    for power in POWERS:
        if power == self.config.training_power:
            models.append("atropos-training-policy")
        else:
            models.append(
                self.opponent_models[opponent_idx % len(self.opponent_models)]
            )
            opponent_idx += 1
    original_argv = sys.argv
    try:
        sys.argv = [
            "lm_game.py",
            "--models",
            ",".join(models),
            "--max_year",
            str(1900 + self.config.max_game_turns),
            "--output",
            game_output_path,
            "--seed",
            str(seed + trajectory_id),
            "--num_negotiation_rounds",
            "0",
        ]
        await lm_game.main()
        actual_game_file = os.path.join(game_output_path, "lmvsgame.json")
        if os.path.exists(actual_game_file):
            with open(actual_game_file, "r") as f:
                saved_game = json.load(f)
            phases = saved_game.get("phases", [])
            last_phase = phases[-1] if phases else {}
            result = {
                "winner": saved_game.get("winner"),
                "turns_played": len(phases),
                "final_centers": {},
            }
            # Count each power's supply centers in the final phase state.
            for power in POWERS:
                centers = (
                    last_phase.get("state", {}).get("centers", {}).get(power, [])
                )
                result["final_centers"][power] = len(centers)
            return result
        else:
            logger.error(f"Game output not found: {actual_game_file}")
            return None
    finally:
        # Always restore the real argv, even if lm_game raised.
        sys.argv = original_argv
def _calculate_score(self, game_result: Dict, power: str) -> float:
    """Score a finished game for ``power``.

    The score is the sum of three terms: the supply-center delta versus
    the starting count (weighted by ``supply_center_weight``), a per-turn
    survival bonus, and a flat ``win_bonus`` when ``power`` is the
    recorded winner.
    """
    gained = game_result["final_centers"].get(power, 0) - STARTING_SUPPLY_CENTERS[power]
    centers_term = gained * self.config.supply_center_weight
    survival_term = game_result["turns_played"] * self.config.survival_bonus
    won = game_result["winner"] == power
    win_term = self.config.win_bonus if won else 0.0
    total = centers_term + survival_term + win_term
    logger.info(
        f"{power} score: centers={centers_term:.2f}, "
        f"survival={survival_term:.2f}, win={win_term:.2f}, "
        f"total={total:.2f}"
    )
    return total
async def get_next_item(self) -> Item:
    """Produce the spec for the next game: a unique id plus a random seed."""
    stamp = int(time.time())
    # Random suffix keeps ids unique when items are produced within the
    # same second.
    suffix = random.randint(1000, 9999)
    return {
        "game_id": f"game-{stamp}-{suffix}",
        "seed": random.randint(0, 1_000_000),
    }
async def evaluate(self, *args, **kwargs):
    """Run evaluation games and stash summary metrics for wandb_log().

    Plays ``eval_episodes`` episodes through collect_trajectories (each
    episode itself runs ``group_size`` games) and records the mean score,
    a win rate and the number of completed episodes in
    ``self.eval_metrics_custom``.

    NOTE(review): the win check only inspects the most recent entry in
    ``game_outcomes_buffer`` — one game out of the whole group — and
    ``win_rate`` divides by ``eval_episodes`` even when some episodes
    produced no scores; confirm this metric definition is intended.
    """
    logger.info(f"Starting evaluation with {self.config.eval_episodes} episodes")
    eval_scores = []
    wins = 0
    for i in range(self.config.eval_episodes):
        item = await self.get_next_item()
        # Flag only; collect_trajectories reads just game_id and seed.
        item["is_eval"] = True
        scored_data_group, _ = await self.collect_trajectories(item)
        if scored_data_group and scored_data_group["scores"]:
            avg_score = sum(scored_data_group["scores"]) / len(
                scored_data_group["scores"]
            )
            eval_scores.append(avg_score)
            if (
                self.game_outcomes_buffer
                and self.game_outcomes_buffer[-1]["winner"]
                == self.config.training_power
            ):
                wins += 1
    if eval_scores:
        avg_score = sum(eval_scores) / len(eval_scores)
        win_rate = wins / self.config.eval_episodes
        # Consumed (and cleared) by wandb_log().
        self.eval_metrics_custom = [
            (f"{self.name}_eval/avg_score", avg_score),
            (f"{self.name}_eval/win_rate", win_rate),
            (f"{self.name}_eval/num_completed", len(eval_scores)),
        ]
        logger.info(
            f"Evaluation completed: avg_score={avg_score:.2f}, win_rate={win_rate:.2%}"
        )
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
    """Flush buffered game outcomes and eval metrics into wandb.

    Aggregates everything accumulated in ``game_outcomes_buffer`` since
    the last call into train metrics, appends any pending eval metrics,
    clears both buffers, and delegates to the base-class logger.
    """
    metrics = {} if wandb_metrics is None else wandb_metrics
    outcomes = self.game_outcomes_buffer
    if outcomes:
        n = len(outcomes)
        win_count = sum(
            g["winner"] == self.config.training_power for g in outcomes
        )
        metrics.update(
            {
                f"{self.name}_train/avg_score": sum(g["score"] for g in outcomes) / n,
                f"{self.name}_train/win_rate": win_count / n,
                f"{self.name}_train/avg_turns": sum(g["turns"] for g in outcomes) / n,
                f"{self.name}_train/num_games": n,
            }
        )
        self.game_outcomes_buffer = []
    # Eval metrics are one-shot: merge then clear.
    metrics.update(dict(self.eval_metrics_custom))
    self.eval_metrics_custom = []
    await super().wandb_log(metrics)
def __del__(self):
    """Best-effort cleanup of the spawned game-server process on exit.

    Uses getattr because __del__ can run even if __init__ raised before
    the attribute was assigned, skips already-exited children, and bounds
    wait() so interpreter shutdown cannot hang on a child that ignores
    SIGTERM (escalating to kill()).
    """
    process = getattr(self, "game_server_process", None)
    if process is not None and process.poll() is None:
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Child ignored SIGTERM; force-kill and reap it.
            process.kill()
            process.wait()
if __name__ == "__main__":
    # Entry point: cli() is presumably provided by BaseEnv and parses
    # command-line arguments to run the environment — confirm in atroposlib.
    DiplomacyEnvMinimal.cli()