atropos/environments/cat_server.py
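
"""Cat/caretaker communication environment for atropos.

A "cat" model is shown a scenario describing a need and may respond only with
cat sounds (meowing, hissing, purring, a hairball, or silence); a "caretaker"
model then proposes interventions. A final judging turn asks the model to say
"purr" (good) or "meow" (bad), which purrfect_eval maps to a 1.0/0.0 reward.
Adapted from the GSM8k environment, so the class and wandb names still carry
the gsm8k naming. Scenarios are read from environments/cat_scenarios.json,
which is expected to be a JSON list of objects with a "scenario" field.
"""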

import random
import json
from typing import Dict, List, Optional, Tuple, TypedDict, Union
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
)
cat_system_prompt = (
    "You are a cat. The only way you can communicate is by meowing, hissing, "
    "purring, making a hairball, or staying silent. "
    "You will be given a collection of scenarios which describe various needs "
    "you want to be met by your caretaker. "
    "Please try to communicate with your caretaker through the modes outlined above."
)
cat_system_prompt += " You are allocated a maximum of 2048 tokens, please strive to use less."
caretaker_system_prompt = (
    "You are the caretaker of this cat. It is trying to communicate its various "
    "needs to you via cat language. "
    "Provide a written string which provides a set of interventions. "
    "You will only have 5 opportunities to interact with the cat. "
    "Choose what you say wisely."
)
system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.
You will then provide your answer like this: \\boxed{your answer here}
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \\boxed{your answer here}"""
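

# A dataset row is just the scenario text shown to the cat.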
class CatRow(TypedDict):
    scenario: str


class GSM8kEnv(BaseEnv):
    name = "gsm8k"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=61,
            batch_size=1,
            steps_per_eval=60,
            max_token_length=2048,
            wandb_name="gsm8k",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}
        # Try to calculate percent_correct, pass if there's a division by zero
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            # Skip if buffer is empty
            pass
        self.percent_correct_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        # self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
        # test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
        with open("environments/cat_scenarios.json", "r", encoding="utf-8") as f:
            test_data = json.load(f)
        self.test = list()
        self.train = list()
        for item in test_data:
            self.test.append(
                {
                    "scenario": item["scenario"],
                    # "gold_answer": item["answer"]
                    # .split("#")[-1]
                    # .strip()
                    # .replace(",", ""),
                }
            )
            self.train.append(
                {"scenario": item["scenario"]}
            )
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)
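
    # NOTE: evaluation is currently a stub. The commented-out block below is the
    # GSM8k-style rollout/verification this environment was adapted from; every
    # eval item simply scores 1 for now.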
    async def rollout_and_score_eval(
        self, scenario: str, answer: Optional[str] = None
    ) -> number:
        # completion = await self.server.chat_completion(
        #     messages=[
        #         {"role": "system", "content": system_prompt},
        #         {"role": "user", "content": scenario},
        #     ],
        #     n=1,
        #     max_tokens=self.config.max_token_length,
        #     temperature=0.0,
        #     split="eval",
        # )
        # gold_parsed = parse(
        #     "\\boxed{" + answer + "}",
        #     extraction_mode="first_match",
        #     extraction_config=[LatexExtractionConfig()],
        # )
        # answer_parsed = parse(
        #     completion.choices[0].message.content.split("</think>")[-1],
        #     extraction_config=[
        #         LatexExtractionConfig(
        #             normalization_config=NormalizationConfig(
        #                 nits=False,
        #                 malformed_operators=False,
        #                 basic_latex=True,
        #                 equations=True,
        #                 boxed="all",
        #                 units=True,
        #             ),
        #             # Ensures that boxed is tried first
        #             boxed_match_priority=0,
        #             try_extract_without_anchor=False,
        #         )
        #     ],
        #     extraction_mode="first_match",
        # )
        # score = 1 if verify(answer_parsed, gold_parsed) else 0
        # return score
        return 1

    async def evaluate(self, *args, **kwargs):
        eval_tasks = []
        for item in self.test:
            eval_tasks.append(self.rollout_and_score_eval(item["scenario"]))
        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))

    async def collect_trajectories(
        self, item: CatRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        user_message = {"role": "user", "content": item["scenario"]}
        # gold_answer = (
        #     "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
        # )
        cat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": cat_system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )
        # Only the first cat completion is carried forward to the caretaker.
        cat_message = cat_completions.choices[0].message.content
        caretaker_message = {"role": "user", "content": cat_message}
        caretaker_completions = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": caretaker_system_prompt},
                caretaker_message,
            ],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )
        to_score = list()
        to_backlog = list()
        for caretaker_completion in caretaker_completions.choices:
            messages = (
                {"role": "system", "content": cat_system_prompt},
                user_message,
                {"role": "system", "content": cat_message},
                {"role": "assistant", "content": caretaker_completion.message.content},
            )
            to_score.append({"messages": messages})
        to_postprocess = await self.score(to_score)
        # import pdb; pdb.set_trace()
        return to_postprocess, to_backlog
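
    # Scoring: the full cat/caretaker exchange is replayed with a final system turn
    # asking the model to say "purr" (caretaker did well) or "meow" (did badly);
    # purrfect_eval maps that verdict to a 1.0 / 0.0 reward for each rollout.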
    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        # # random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            final_question = list(item["messages"]) + [
                {
                    "role": "system",
                    "content": (
                        "The conversation is over. Say meow if the caretaker did a "
                        "bad job, or purr if the caretaker did a good job."
                    ),
                }
            ]
            caretaker_completions = await self.server.chat_completion(
                messages=final_question,
                n=1,
                max_tokens=self.config.max_token_length,
            )
            final_out = {
                "role": "system",
                "content": caretaker_completions.choices[0].message.content,
            }
            final_score = purrfect_eval(final_out["content"])
            out_dict = tokenize_for_trainer(
                self.tokenizer, list(item["messages"]) + [final_out]
            )
            scores["tokens"].append(out_dict["tokens"])
            scores["masks"].append(out_dict["masks"])
            scores["scores"].append(final_score)
            # tokens = out_dict["tokens"]
            # masks = out_dict["masks"]
            # # remove obviously bad examples
            # if len([1 for i in masks if i != -100]) < 10:
            #     continue
            # scores["tokens"].append(tokens)
            # scores["masks"].append(masks)
            # scores["scores"].append(1.0)
            # if len(scores["tokens"]) >= self.config.group_size:
            #     break
        # for score in scores["scores"]:
        #     self.percent_correct_buffer.append(max(score, 0))
        # # check if all the same
        # # print(scores['scores'])
        # if all([score == 1 for score in scores["scores"]]):
        #     # Do length penalty :)
        #     token_lengths = [len(token) for token in scores["tokens"]]
        #     if max(token_lengths) == 0:
        #         # What? But don't want to crash a run so just in case...
        #         return None
        #     # Get max allowed token length from config
        #     max_allowed_length = self.config.max_token_length
        #     # Set threshold at 50% of max_token_length - no penalty below this
        #     length_threshold = max_allowed_length * 0.5
        #     # Apply modified length penalty with threshold
        #     scores["scores"] = []
        #     for length in token_lengths:
        #         if length <= length_threshold:
        #             # No penalty for responses under threshold
        #             scores["scores"].append(1.0)
        #         else:
        #             # Calculate how far we are between threshold and max as a percentage
        #             percentage_of_range = (length - length_threshold) / (
        #                 max_allowed_length - length_threshold
        #             )
        #             # Cap at 1.0 in case length exceeds max_allowed_length
        #             percentage_of_range = min(percentage_of_range, 1.0)
        #             # Apply linear penalty scaling from 1.0 down to 0.0
        #             scores["scores"].append(1.0 - percentage_of_range)
        return scores
        # gold_parsed = parse(
        #     rollout_group_data[0]["gold_answer"],
        #     extraction_mode="first_match",
        #     extraction_config=[LatexExtractionConfig()],
        # )
        # if len(gold_parsed) != 0:
        #     # We require the answer to be provided in correct latex (no malformed operators)
        #     random.shuffle(rollout_group_data)
        #     for item in rollout_group_data:
        #         # print(item[0][-1]["content"])
        #         answer_parsed = parse(
        #             item["messages"][-1]["content"].split("</think>")[-1],
        #             extraction_config=[
        #                 LatexExtractionConfig(
        #                     normalization_config=NormalizationConfig(
        #                         nits=False,
        #                         malformed_operators=False,
        #                         basic_latex=True,
        #                         equations=True,
        #                         boxed="all",
        #                         units=True,
        #                     ),
        #                     # Ensures that boxed is tried first
        #                     boxed_match_priority=0,
        #                     try_extract_without_anchor=False,
        #                 )
        #             ],
        #             extraction_mode="first_match",
        #         )
        #         # Reward 1 if the content is the same as the ground truth, 0 otherwise
        #         reward = verify(answer_parsed, gold_parsed)
        #         # print(
        #         #     f"message: {item[0][-1]['content']}, ground_truth: {item[1]}, reward: {reward}"
        #         # )
        #         out_dict = tokenize_for_trainer(
        #             self.tokenizer, item["messages"], item["finish_reason"]
        #         )
        #         tokens = out_dict["tokens"]
        #         masks = out_dict["masks"]
        #         # remove obviously bad examples
        #         if len([1 for i in masks if i != -100]) < 10:
        #             continue
        #         scores["tokens"].append(tokens)
        #         scores["masks"].append(masks)
        #         scores["scores"].append(1.0 if reward else -1.0)
        #         if len(scores["tokens"]) >= self.config.group_size:
        #             break
        #     for score in scores["scores"]:
        #         self.percent_correct_buffer.append(max(score, 0))
        #     # check if all the same
        #     # print(scores['scores'])
        #     if all([score == 1 for score in scores["scores"]]):
        #         # Do length penalty :)
        #         token_lengths = [len(token) for token in scores["tokens"]]
        #         if max(token_lengths) == 0:
        #             # What? But don't want to crash a run so just in case...
        #             return None
        #         # Get max allowed token length from config
        #         max_allowed_length = self.config.max_token_length
        #         # Set threshold at 50% of max_token_length - no penalty below this
        #         length_threshold = max_allowed_length * 0.5
        #         # Apply modified length penalty with threshold
        #         scores["scores"] = []
        #         for length in token_lengths:
        #             if length <= length_threshold:
        #                 # No penalty for responses under threshold
        #                 scores["scores"].append(1.0)
        #             else:
        #                 # Calculate how far we are between threshold and max as a percentage
        #                 percentage_of_range = (length - length_threshold) / (
        #                     max_allowed_length - length_threshold
        #                 )
        #                 # Cap at 1.0 in case length exceeds max_allowed_length
        #                 percentage_of_range = min(percentage_of_range, 1.0)
        #                 # Apply linear penalty scaling from 1.0 down to 0.0
        #                 scores["scores"].append(1.0 - percentage_of_range)
        #     if all([scores["scores"][0] == score for score in scores["scores"]]):
        #         return None  # If all the same, we return None
        #     return scores
        # else:
        #     # If the gold solution is not parseable, we return None
        #     return None

    async def get_next_item(self) -> CatRow:
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        print(f"iteration: {self.iter}")
        return next_item
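

# Reward helper: any judging response containing "purr" counts as a positive verdict.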
def purrfect_eval(st: str) -> float:
    if "purr" in st.lower():
        return 1.0
    return 0.0


if __name__ == "__main__":
    GSM8kEnv.cli()