mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-26 17:13:09 +00:00
- Updated point system guidelines to restrict values for Piety, Stability, Power, and Wealth to a range of -20 to 20. - Introduced a category field for scenario classification. - Added examples for scenario generation to improve clarity. - Implemented choice history tracking to maintain a record of player decisions and their effects on the kingdom's state. - Enhanced prompt formatting to include current kingdom state and recent choices for better context in scenario generation.
455 lines
17 KiB
Python
Executable file
455 lines
17 KiB
Python
Executable file
import random
|
|
import re
|
|
import json
|
|
import os
|
|
from typing import Dict, List, Optional, Tuple, TypedDict, Union
|
|
|
|
from datasets import load_dataset
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.type_definitions import Item, number
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
# Generic deep-reasoning preamble: asks the model to deliberate inside
# <think> </think> tags before emitting its final answer.
system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
)

# DynastAI task description appended to the preamble: defines the JSON schema
# the model must emit (character, prompt, two choices, eight resource deltas
# in [-20, 20], and a category), plus three few-shot examples.  The schema
# here must stay in sync with the checks in
# DynastAIEnv.validate_json_structure.
system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.

You are playing a game called DynastAI where you generate scenarios for a kingdom management game.
Each scenario should include a character presenting a dilemma to the ruler, with two choices that affect
the four key resources of the kingdom: Piety, Stability, Power, and Wealth.

**Point System Guidelines:**
- The point values for Piety, Stability, Power, and Wealth for each choice should be integers ranging from -20 to 20.
- These values should be logically consistent with the scenario and the choice described. A choice that is clearly beneficial should have a net positive sum of points, while a detrimental choice should have a net negative sum.
- Strive for a variety of point distributions; not all resources need to be affected by every choice.

Your response must be a valid JSON object with the following structure:
{
"Character": "Name/Title of the character",
"Prompt": "The scenario description",
"Left_Choice": "The first choice option",
"Left_Piety": integer value between -20 and 20,
"Left_Stability": integer value between -20 and 20,
"Left_Power": integer value between -20 and 20,
"Left_Wealth": integer value between -20 and 20,
"Right_Choice": "The second choice option",
"Right_Piety": integer value between -20 and 20,
"Right_Stability": integer value between -20 and 20,
"Right_Power": integer value between -20 and 20,
"Right_Wealth": integer value between -20 and 20,
"category": "piety/stability/power/wealth"
}

Here are some examples:

Example 1:
{
"Character": "Diplomat",
"Prompt": "With a sly smile, the diplomat gestures broadly: \"Sire, the lords quarrel like children. Shall we mediate disputes between lords?\"",
"Left_Choice": "We cannot risk the kingdom's future; dismiss them with a royal wave.",
"Left_Piety": 10,
"Left_Stability": -10,
"Left_Power": 0,
"Left_Wealth": 0,
"Right_Choice": "Make it so; our enemies shall kneel in terror!",
"Right_Piety": -10,
"Right_Stability": 10,
"Right_Power": 0,
"Right_Wealth": 0,
"category": "stability"
}

Example 2:
{
"Character": "Merchant",
"Prompt": "The merchant nervously fidgets with coins: \"My king, the markets groan under heavy tariffs. Shall we reduce tariffs?\"",
"Left_Choice": "Absurd! Unthinkable; it's madness that courts disaster.",
"Left_Piety": 0,
"Left_Stability": -15,
"Left_Power": 0,
"Left_Wealth": 10,
"Right_Choice": "Brilliant! Most ingenious; begin before the sun sets!",
"Right_Piety": 0,
"Right_Stability": 15,
"Right_Power": 0,
"Right_Wealth": -10,
"category": "wealth"
}

Example 3:
{
"Character": "Farmer",
"Prompt": "Mud-stained and weary, the farmer removes his cap: \"Your Grace, our villages yearn for markets. Shall we hold local markets?\"",
"Left_Choice": "Silence! Such talk borders on treason; it whispers of rebellion and ruin.",
"Left_Piety": 0,
"Left_Stability": -15,
"Left_Power": 0,
"Left_Wealth": 10,
"Right_Choice": "Indeed! We shall usher wealth and fortune to the land!",
"Right_Piety": 0,
"Right_Stability": 15,
"Right_Power": 0,
"Right_Wealth": -10,
"category": "stability"
}

Be creative and make each scenario interesting!"""
|
|
|
|
|
|
class DynastAIRow(TypedDict):
    """One work item handed to DynastAIEnv.collect_trajectories.

    Note: TypedDict fields must not carry right-hand-side default values
    (PEP 589) -- type checkers reject them and the runtime ignores them (newer
    Pythons raise) -- so optionality is expressed only through the Optional
    annotations.  get_next_item always populates all three keys.
    """

    # Fully formatted user prompt sent to the model.
    scenario_prompt: str
    # Snapshot of the four kingdom resources the prompt was built from.
    kingdom_current_state: Optional[Dict]
    # Recent decisions (and their effects) included for context.
    choice_history: Optional[List]
|
|
|
|
|
|
class DynastAIEnv(BaseEnv):
    """Atropos RL environment for the DynastAI scenario-generation game.

    The policy model is prompted with the kingdom's current resource levels
    (Piety, Stability, Power, Wealth) plus recent decisions, and must emit a
    new scenario as a strict JSON object.  A completion is rewarded 1.0 when
    the JSON parses and satisfies the schema enforced by
    validate_json_structure, and -1.0 otherwise.
    """

    name = "dynastai"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # 0/1 validity outcomes accumulated between wandb logs.
        self.percent_correct_buffer = list()
        # (metric_name, value) pairs produced by evaluate(), drained on log.
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and inference-server configs."""
        env_config = BaseEnvConfig(
            tokenizer_name="Qwen/Qwen3-1.7B",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="dynastai",
        )
        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen3-1.7B",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]

        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Flush accumulated training/eval metrics, then delegate to the base class."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Try to calculate percent_correct, pass if there's a division by zero
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            # Buffer is empty (no scored groups since the last log); skip.
            pass

        self.percent_correct_buffer = list()

        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load the card deck, split train/test, and reset the kingdom state."""
        # Load cards from the JSON file shipped next to this module.
        cards_file = os.path.join(os.path.dirname(__file__), "dynastai_cards.json")
        with open(cards_file, "r", encoding="utf-8") as f:
            cards = json.load(f)

        # Shuffle and split into train/test
        random.shuffle(cards)
        test_size = int(len(cards) * 0.1)  # 10% for test set

        self.train = cards[test_size:]
        self.test = cards[:test_size]
        self.iter = 0

        # Initialize default kingdom state: every resource starts mid-scale.
        self.current_kingdom_state = {
            "Piety": 50,
            "Stability": 50,
            "Power": 50,
            "Wealth": 50,
        }
        # Record of generated scenarios / player decisions (newest last).
        self.choice_history = []

    def save_checkpoint(self, step, data=None):
        """Persist iteration counter, kingdom state, and history with the checkpoint."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        data["current_kingdom_state"] = self.current_kingdom_state
        data["choice_history"] = self.choice_history
        super().save_checkpoint(step, data)

    async def evaluate(self, *args, **kwargs):
        """Greedily generate one scenario per test card and log the validity rate."""
        eval_tasks = []
        for card in self.test:
            input_data = card.get("input", {})
            kingdom_state = input_data.get(
                "kingdom_current_state", self.current_kingdom_state
            )
            choice_history = input_data.get("choice_history", [])
            prompt = self.format_prompt(kingdom_state, choice_history)
            eval_tasks.append(self.rollout_and_score_eval(prompt))

        # BUGFIX: guard against an empty test split (tiny card files) so the
        # average below cannot raise ZeroDivisionError.
        if not eval_tasks:
            return

        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))

    def format_prompt(self, kingdom_state, choice_history):
        """Build the user prompt from the kingdom state plus up to 3 recent choices.

        Missing resource keys default to 50 (the mid-scale starting value).
        """
        prompt = "Generate a new scenario for the kingdom with the following current state:\n"
        prompt += f"Piety: {kingdom_state.get('Piety', 50)}, "
        prompt += f"Stability: {kingdom_state.get('Stability', 50)}, "
        prompt += f"Power: {kingdom_state.get('Power', 50)}, "
        prompt += f"Wealth: {kingdom_state.get('Wealth', 50)}\n\n"

        if choice_history:
            prompt += "Previous choices made:\n"
            for i, choice in enumerate(choice_history[-3:]):  # Show last 3 choices at most
                prompt += f"{i+1}. {choice.get('Character', 'Unknown')} presented: \"{choice.get('Prompt', 'Unknown')}\"\n"
                prompt += f" Decision: {choice.get('choice_made', 'Unknown')}\n"
                prompt += f" Effects: Piety {choice.get('effects', {}).get('Piety', 0)}, "
                prompt += f"Stability {choice.get('effects', {}).get('Stability', 0)}, "
                prompt += f"Power {choice.get('effects', {}).get('Power', 0)}, "
                prompt += f"Wealth {choice.get('effects', {}).get('Wealth', 0)}\n\n"

        prompt += "Based on this context, generate a new challenging scenario for the ruler."
        return prompt

    async def rollout_and_score_eval(self, scenario_prompt: str) -> number:
        """Run one greedy completion and return 1 if it is valid scenario JSON, else 0."""
        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": scenario_prompt},
            ],
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,  # greedy decoding for a deterministic eval
            split="eval",
        )

        completion_content = completion.choices[0].message.content
        score = self.validate_json_structure(completion_content)
        return score

    def validate_json_structure(self, content: str) -> number:
        """Return 1 if `content` contains a valid scenario JSON object, else 0.

        Validity requires every schema field to be present, the eight resource
        deltas to be ints in [-20, 20], and `category` to be one of the four
        resource names.
        """
        # Extract content after </think> tag if present
        if "</think>" in content:
            content = content.split("</think>")[-1].strip()

        # Find the outermost {...} span (greedy, spanning newlines).
        json_match = re.search(r"\{.*\}", content, re.DOTALL)
        if not json_match:
            return 0

        json_str = json_match.group(0)

        try:
            # Attempt to parse as JSON
            data = json.loads(json_str)

            # Check for required fields
            required_fields = [
                "Character", "Prompt",
                "Left_Choice", "Left_Piety", "Left_Stability", "Left_Power", "Left_Wealth",
                "Right_Choice", "Right_Piety", "Right_Stability", "Right_Power", "Right_Wealth",
                "category",
            ]

            if not all(field in data for field in required_fields):
                return 0

            # Check numeric fields
            numeric_fields = [
                "Left_Piety", "Left_Stability", "Left_Power", "Left_Wealth",
                "Right_Piety", "Right_Stability", "Right_Power", "Right_Wealth",
            ]

            for field in numeric_fields:
                # BUGFIX: exclude bool -- JSON true/false parse as bool, which
                # is a subclass of int and previously slipped through.
                if not isinstance(field_value := data[field], int) or isinstance(
                    field_value, bool
                ):
                    return 0
                if field_value < -20 or field_value > 20:
                    return 0

            # Check category field
            if data["category"] not in ["piety", "stability", "power", "wealth"]:
                return 0

            # If we made it here, the JSON is valid
            return 1

        except json.JSONDecodeError:
            return 0

    async def collect_trajectories(
        self, item: DynastAIRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        """Sample a group of completions for `item`, score them, update history."""
        user_message = {"role": "user", "content": item["scenario_prompt"]}

        chat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )

        to_score = []
        to_backlog = []

        for chat_completion in chat_completions.choices:
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    "finish_reason": chat_completion.finish_reason,
                }
            )

        to_postprocess = await self.score(to_score)

        # Record one valid generated scenario in the choice history.
        #
        # BUGFIX: the previous implementation took the argmax over
        # to_postprocess["scores"] and indexed back into to_score, but score()
        # shuffles and filters its input, so those indices did not line up.
        # Scores are binary (valid JSON or not), so we re-validate the raw
        # completions directly and take the first valid one.
        if to_postprocess and to_postprocess["scores"]:
            best_content = None
            for candidate in to_score:
                content = candidate["messages"][-1]["content"]
                if self.validate_json_structure(content):
                    best_content = content
                    break

            if best_content is not None:
                try:
                    # Extract the JSON payload after any reasoning block.
                    if "</think>" in best_content:
                        best_content = best_content.split("</think>")[-1].strip()
                    json_match = re.search(r"\{.*\}", best_content, re.DOTALL)
                    if json_match:
                        data = json.loads(json_match.group(0))

                        # Store the generated scenario in choice history
                        self.choice_history.append(
                            {
                                "Character": data.get("Character", "Unknown"),
                                "Prompt": data.get("Prompt", "Unknown"),
                                # Filled in when the player makes a choice.
                                "choice_made": "Unknown",
                                "effects": {
                                    "Piety": 0,
                                    "Stability": 0,
                                    "Power": 0,
                                    "Wealth": 0,
                                },
                                "category": data.get("category", "unknown"),
                                # Store the full scenario data for later use
                                "scenario_data": data,
                            }
                        )
                except Exception as e:
                    print(f"Error processing scenario: {e}")

        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Tokenize and score a rollout group: +1.0 for valid JSON, -1.0 otherwise.

        Returns None when every kept score is identical, since a uniform group
        carries no training signal.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()

        # BUGFIX: shuffle a copy -- the previous in-place shuffle mutated the
        # caller's list, which collect_trajectories still reads afterwards.
        candidates = list(rollout_group_data)
        random.shuffle(candidates)
        for item in candidates:
            completion_content = item["messages"][-1]["content"]
            reward = self.validate_json_structure(completion_content)

            out_dict = tokenize_for_trainer(
                self.tokenizer, item["messages"], item["finish_reason"]
            )
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]

            # Remove obviously bad examples: require at least 10 unmasked tokens.
            if len([1 for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        for score in scores["scores"]:
            self.percent_correct_buffer.append(max(score, 0))

        # Check if all the same (also True for an empty group).
        if all([score == scores["scores"][0] for score in scores["scores"]]):
            return None  # If all the same, we return None

        return scores

    async def get_next_item(self) -> DynastAIRow:
        """Produce the next training item: a prompt plus the state it came from."""
        # Increment counter
        self.iter += 1

        # 30% of the time replay a stored training card; otherwise use the
        # live kingdom state and choice history.
        if self.train and random.random() < 0.3:
            card = random.choice(self.train)
            input_data = card.get("input", {})
            kingdom_state = input_data.get(
                "kingdom_current_state", self.current_kingdom_state
            )
            choice_history = input_data.get("choice_history", [])
        else:
            kingdom_state = self.current_kingdom_state
            choice_history = self.choice_history

        # Generate prompt based on kingdom state and choice history
        prompt = self.format_prompt(kingdom_state, choice_history)

        return {
            "scenario_prompt": prompt,
            "kingdom_current_state": kingdom_state,
            "choice_history": choice_history,
        }

    # Helper method to update kingdom state based on a choice
    def update_kingdom_state(self, choice, is_left_choice=True):
        """Apply the player's decision to the kingdom and record it in history.

        NOTE: resources are only updated when a scenario is already present in
        choice_history (preserved from the original behavior); each resource
        is clamped to [0, 100].
        """
        choice_prefix = "Left_" if is_left_choice else "Right_"

        # Update the most recent choice in the history with the player's decision
        if self.choice_history:
            most_recent = self.choice_history[-1]
            most_recent["choice_made"] = choice.get(f"{choice_prefix}Choice", "Unknown")

            # Update effects based on the choice
            effects = {}
            for resource in ["Piety", "Stability", "Power", "Wealth"]:
                value = choice.get(f"{choice_prefix}{resource}", 0)
                effects[resource] = value

                # Apply effect to current kingdom state, clamped to 0-100.
                current_value = self.current_kingdom_state.get(resource, 50)
                self.current_kingdom_state[resource] = max(
                    0, min(100, current_value + value)
                )

            most_recent["effects"] = effects
|
|
|
|
|
|
# Script entry point: delegate argument parsing and environment launch to the
# shared BaseEnv command-line interface.
if __name__ == "__main__":
    DynastAIEnv.cli()
|