Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-25 17:10:42 +00:00)
first commit
Commit 621d00dd80
89 changed files with 15315 additions and 0 deletions
447  environments/math_server_zero.py  Normal file

@@ -0,0 +1,447 @@
"""
|
||||
This file contains code inspired by and adapted from the Open-Reasoner-Zero project.
|
||||
Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from math_verify.errors import TimeoutException
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
OpenaiConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
|
||||
prompt_format = (
    "A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant "
    "first thinks about the reasoning process in the mind and then provides the User with the answer. The reasoning "
    "process is enclosed within <think> </think> and answer is enclosed within <answer> </answer> tags, respectively, "
    "i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {prompt}\nAssistant: <think>"
)

problem_format = """You must put your answer inside <answer> </answer> tags, i.e., <answer> answer here </answer>. And your final answer will be extracted automatically by the \\boxed{{}} tag.
This is the problem:
{problem}
"""  # noqa: E501
stop_list = ["User:", "Human:", "Assistant:", "</answer>"]


class RSConfig(BaseEnvConfig):
    run_evaluation: bool = Field(True, description="If this should run evaluation")
    mask_too_long_completions: bool = Field(
        True, description="If this should mask too long completions"
    )
    percent_length_penalty: float = Field(
        0.0, description="The percentage of items to have length penalty"
    )
def score_answer(gold, resp) -> Optional[bool]:
    pattern = re.compile(r"<answer>.*?(\\boxed{.*}).*?</answer>", re.DOTALL)
    matches = pattern.findall(resp)
    resp = matches[-1] if matches else None
    if resp is None:
        return False
    try:
        gold_parsed = parse(
            gold,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
    except (Exception, TimeoutException, KeyError, TypeError, NotImplementedError):
        return None
    if len(gold_parsed) != 0:
        try:
            answer_parsed = parse(
                resp,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
        except (
            Exception,
            TimeoutException,
            KeyError,
            TypeError,
            NotImplementedError,
        ):
            # Can't parse, so we skip
            return None
        # Reward 1 if the content is the same as the ground truth, 0 otherwise
        try:
            return verify(answer_parsed, gold_parsed)
        except (
            Exception,
            TimeoutException,
            KeyError,
            TypeError,
            NotImplementedError,
        ):
            return None
    return None
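# Illustrative calls (equivalence checking itself is delegated to math_verify):
#   score_answer("\\boxed{42}", "<think>...</think> <answer>\\boxed{42}</answer>")  # -> True
#   score_answer("\\boxed{42}", "<answer>\\boxed{41}</answer>")                     # -> False
#   score_answer("\\boxed{42}", "the answer is 42")                                 # -> False (no <answer>...\boxed{...}...</answer> span)
# A None return signals a parse/verify failure; callers treat the result as unusable.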
class MathEnv(BaseEnv):

    name = "math"
    env_config_cls = RSConfig

    def __init__(
        self,
        config: RSConfig,
        server_configs: List[OpenaiConfig],
        slurm=True,
        testing=False,
    ):
        print("Initializing MathEnv")
        print(f"Slurm: {slurm}, Testing: {testing}")
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        self.mp_executor = ProcessPoolExecutor(64)
        self.percent_overanswer = list()
        self.correct_answer_len = list()
        self.incorrect_answer_len = list()
        self.normal_rollouts = list()
        self.pass_at_groupsize = list()
        self.iter = 0
    @classmethod
    def config_init(cls) -> Tuple[RSConfig, List[OpenaiConfig]]:
        env_config = RSConfig(
            tokenizer_name="Qwen/Qwen2.5-7B",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1024,
            steps_per_eval=25,
            max_token_length=31000,  # 22000 // (2 ** i),
            wandb_name="math",
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            eval_limit_ratio=0.1,
        )
        server_configs = [
            OpenaiConfig(
                model_name="default",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_requests_for_eval=256,  # since evaling only on one...
            ),
        ]

        return env_config, server_configs
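    # The defaults above appear to assume a local OpenAI-compatible completion endpoint
    # at http://localhost:9004/v1 serving the policy (model_name "default") and an
    # Atropos rollout API at http://localhost:8000; adjust base_url, model_name, and
    # api_key for your own deployment.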
    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = dict()
        if len(self.pass_at_groupsize) > 0:
            wandb_metrics["train/pass_at_groupsize"] = sum(
                self.pass_at_groupsize
            ) / len(self.pass_at_groupsize)
            self.pass_at_groupsize = list()  # reset the per-interval buffer after logging
        if len(self.percent_correct_buffer) > 0:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
            wandb_metrics["train/percent_overanswer"] = sum(
                self.percent_overanswer
            ) / len(self.percent_overanswer)
            self.percent_overthink = list()
            self.percent_overanswer = list()
            self.percent_correct_buffer = list()
        if len(self.correct_answer_len) > 0:
            wandb_metrics["train/avg_correct_answer_len"] = sum(
                self.correct_answer_len
            ) / len(self.correct_answer_len)
            self.correct_answer_len = list()
        if len(self.incorrect_answer_len) > 0:
            wandb_metrics["train/avg_incorrect_answer_len"] = sum(
                self.incorrect_answer_len
            ) / len(self.incorrect_answer_len)
            self.incorrect_answer_len = list()
        # create tables
        if len(self.normal_rollouts) > 0:
            table = wandb.Table(columns=["problem", "solution", "answer", "score"])
            for group in self.normal_rollouts:
                table.add_data(group[0], group[1], group[2], group[3])
            wandb_metrics["train/normal_rollouts"] = table
        wandb_metrics["train/iter"] = self.iter
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        await super().wandb_log(wandb_metrics)
    async def setup(self):
        self.train = load_dataset("zwhe99/DeepMath-103K", split="train").shuffle(
            seed=42
        )
        aime_test_data = load_dataset("HuggingFaceH4/aime_2024", split="train")
        math500_test_data = load_dataset("HuggingFaceH4/math-500", split="test")
        amc_test_data = load_dataset("math-ai/amc23", split="test")
        minerva_test_data = load_dataset("math-ai/minervamath", split="test")
        olympiad_test_data = load_dataset("math-ai/olympiadbench", split="test")
        self.test = list()
        for name, t_dataset in zip(
            ["aime24", "math500"], [aime_test_data, math500_test_data]
        ):
            for item in t_dataset:
                self.test.append(
                    (
                        prompt_format.format(
                            prompt=problem_format.format(problem=item["problem"])
                        ),
                        item["answer"],
                        name,
                    )
                )
        for name, t_dataset in zip(
            ["amc23", "minerva", "olympiad"],
            [amc_test_data, minerva_test_data, olympiad_test_data],
        ):
            for item in t_dataset:
                self.test.append(
                    (
                        prompt_format.format(
                            prompt=problem_format.format(problem=item["question"])
                        ),
                        item["answer"],
                        name,
                    )
                )
        return
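        # After setup, self.test is a flat list of (formatted_prompt, gold_answer, subset_name)
        # tuples; aime24/math500 keep the statement under "problem" while amc23/minerva/olympiad
        # use "question", hence the two separate loops above.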
    async def rollout_and_score_eval(self, question, answer, subset):
        completion = await self.server.completion(
            prompt=question,
            n=1,
            max_tokens=32765,
            temperature=0.0,
            split="eval",
            stop=stop_list,
        )
        loop = asyncio.get_event_loop()
        gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer
        resp = completion.choices[0].text
        if completion.choices[0].finish_reason == "stop":
            if ("</answer>" not in completion.choices[0].text) and (
                "<answer>" in completion.choices[0].text
            ):
                # assume it stopped on </answer>
                resp = resp + " </answer>"
        task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
        reward = await task
        if reward is None:
            return 0, subset
        score = 1 if reward else 0
        return score, subset
    async def evaluate(self, *args, **kwargs):
        if not self.config.run_evaluation:
            return
        eval_tasks = []
        for item in self.test:
            eval_tasks.append(self.rollout_and_score_eval(item[0], item[1], item[2]))
        parsing_data = await tqdm_asyncio.gather(*eval_tasks)
        task_lists = dict()
        for score, subset in parsing_data:
            if subset not in task_lists:
                task_lists[subset] = list()
            task_lists[subset].append(score)
        # Now get the average
        for subset, scores in task_lists.items():
            self.eval_metrics.append(
                (f"eval/{subset}_percent_correct", sum(scores) / len(scores))
            )
        # overall score
        scores = []
        for subset, score in task_lists.items():
            scores.extend(score)
        self.eval_metrics.append(
            ("eval/overall_percent_correct", sum(scores) / len(scores))
        )
    async def collect_trajectories(self, item) -> Tuple[List, List]:
        thinking_len = self.config.max_token_length
        user_prompt = prompt_format.format(
            prompt=problem_format.format(problem=item[0])
        )
        thinking_len = thinking_len - len(self.tokenizer.encode(user_prompt))
        completions = await self.server.completion(
            prompt=user_prompt,
            n=self.config.group_size,
            max_tokens=thinking_len,
            temperature=1.0,
            top_p=0.95,
            stop=stop_list,
        )
        to_score = list()
        to_backlog = list()
        for i, completion in enumerate(completions.choices):
            message = user_prompt + completion.text
            if completion.finish_reason == "stop":
                if ("</answer>" not in completion.text) and (
                    "<answer>" in completion.text
                ):
                    # assume it stopped on </answer>
                    message = message + " </answer>"
            to_score.append(
                (
                    message,
                    item[1],
                    completion.finish_reason,
                    user_prompt,
                )
            )
        to_postprocess = await self.score(to_score)
        if to_postprocess is None:
            return None, to_backlog
        if all(
            [to_postprocess["scores"][0] == score for score in to_postprocess["scores"]]
        ):
            return None, to_backlog
        self.normal_rollouts.append(
            (
                prompt_format.format(prompt=problem_format.format(problem=item[0])),
                to_postprocess["messages"][0][-1]["content"],
                item[1],
                to_postprocess["scores"][0],
            )
        )
        if len(self.normal_rollouts) > self.config.num_rollouts_to_keep:
            self.normal_rollouts.pop(0)
        print(f"Collected {len(to_postprocess['scores'])} trajectories")
        return to_postprocess, to_backlog
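        # Groups where every completion got the same score (all correct or all incorrect)
        # are discarded above, since a group-relative advantage would presumably be zero
        # everywhere and contribute no training signal.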
    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["overrides"] = list()
        scores["messages"] = list()
        gold = rollout_group_data[0][1]
        loop = asyncio.get_event_loop()
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            resp = item[0]
            scores["overrides"].append(dict())
            if item[2] == "length":
                reward = False
                if self.config.mask_too_long_completions:
                    scores["overrides"][-1]["set_advantage_to_zero"] = True
            else:
                task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
                reward = await task
            if reward is None:
                return None
            tokens = self.tokenizer.encode(resp)
            user_prompt_tokens = self.tokenizer.encode(item[3])
            if user_prompt_tokens[-1] == self.tokenizer.eos_token_id:
                user_prompt_tokens = user_prompt_tokens[:-1]
            assert all(
                [
                    i == j
                    for i, j in zip(
                        user_prompt_tokens, tokens[: len(user_prompt_tokens)]
                    )
                ]
            )
            masks = [-100 for _ in range(len(user_prompt_tokens))]
            masks = masks + tokens[len(user_prompt_tokens) :]
            messages = [
                {"role": "user", "content": item[3]},
                {"role": "assistant", "content": resp[len(item[3]) :]},
            ]
            # remove obviously bad examples
            if len([1 for i in masks if i != -100]) < 10:
                continue
            if (item[2] == "length") and (not self.config.mask_too_long_completions):
                scores["overrides"][-1]["set_advantage_to_zero"] = True
            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)
            scores["messages"].append(messages)
            if len(scores["tokens"]) >= self.config.group_size:
                break
        if any([score == 1.0 for score in scores["scores"]]):
            self.pass_at_groupsize.append(1.0)
        else:
            self.pass_at_groupsize.append(0.0)
        if len(scores["tokens"]) < self.config.group_size:
            # We don't have enough data to score
            return None
        for score in scores["scores"]:
            self.percent_correct_buffer.append(max(score, 0))
        self.percent_overanswer.extend(
            [item[2] == "length" for item in rollout_group_data]
        )
        # check if all the same
        # print(scores['scores'])
        # Fill in the correct/incorrect lens after so we're only looking at actual training data
        self.correct_answer_len.extend(
            [
                len(scores["tokens"][i])
                for i in range(len(scores["scores"]))
                if scores["scores"][i] == 1.0
            ]
        )
        self.incorrect_answer_len.extend(
            [
                len(scores["tokens"][i])
                for i in range(len(scores["scores"]))
                if (scores["scores"][i] == -1.0)
                and (not scores["overrides"][i].get("set_advantage_to_zero", False))
            ]
        )
        return scores
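        # Returned ScoredDataGroup layout, per kept completion:
        #   tokens: prompt + completion token ids
        #   masks:  -100 for prompt tokens, the token id itself for completion tokens
        #   scores: +1.0 if verified correct, -1.0 otherwise
        #   overrides: may carry {"set_advantage_to_zero": True} for length-truncated rollouts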
    async def get_next_item(self):
        while True:
            next_item = self.train[self.iter % len(self.train)]
            self.iter += 1
            prompt = next_item["question"]
            try:
                answer = (
                    ("\\boxed{" + next_item["final_answer"] + "}")
                    if "\\boxed" not in next_item["final_answer"]
                    else next_item["final_answer"]
                )
                break
            except TypeError:
                print(
                    f"Error in getting next item, trying again, "
                    f"data: {next_item['question']} -> {next_item['final_answer']}"
                )
        return (prompt, answer, "normal")


if __name__ == "__main__":
    MathEnv.cli()