Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-28 17:29:30 +00:00)
Introduces a `log_eval_sample()` method for stream-writing individual evaluation samples to `samples.jsonl` during evaluation, with lazy writer initialization and automatic HTML generation on completion. Updates the GSM8k environment to use this streaming approach instead of batching samples.
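The base-class side of the change is not shown on this page. As a rough, non-authoritative sketch only, a streaming `log_eval_sample()` with lazy writer initialization and a completion hook might look like the following; the attribute names, output-directory handling, and `_generate_eval_html()` helper are assumptions for illustration, not the actual atroposlib implementation:

import json
import os


class EvalSampleStreamingSketch:
    """Illustrative only: append one JSON object per evaluation sample to samples.jsonl."""

    def log_eval_sample(self, sample: dict) -> None:
        # Lazily open the writer the first time a sample arrives (output directory is assumed).
        if getattr(self, "_eval_sample_writer", None) is None:
            os.makedirs(self.eval_output_dir, exist_ok=True)
            self._eval_sample_writer = open(
                os.path.join(self.eval_output_dir, "samples.jsonl"),
                "a",
                encoding="utf-8",
            )
        self._eval_sample_writer.write(json.dumps(sample, ensure_ascii=False) + "\n")
        self._eval_sample_writer.flush()

    def finish_eval_logging(self) -> None:
        # Hypothetical completion hook: close the writer and render an HTML report.
        if getattr(self, "_eval_sample_writer", None) is not None:
            self._eval_sample_writer.close()
            self._eval_sample_writer = None
            self._generate_eval_html()  # assumed helper for the HTML report, not shown here

Because every sample is a single JSON line, the file can be inspected while an evaluation is still running, e.g. `rows = [json.loads(line) for line in open("samples.jsonl", encoding="utf-8")]`.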
363 lines | 13 KiB | Python
import random
import time
from typing import Dict, List, Optional, Tuple, TypedDict, Union

from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
    ServerBaseline,
)
from atroposlib.type_definitions import Item

system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
)

system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.

You will then provide your answer like this: \\boxed{your answer here}
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \\boxed{your answer here}"""


class GSM8kRow(TypedDict):
    question: str
    answer: str


class GSM8kEnv(BaseEnv):

    name = "gsm8k"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, ServerBaseline]:
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="gsm8k",
        )
        server_config = APIServerConfig(
            model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            base_url="http://localhost:9001/v1",
            api_key="x",
            num_requests_for_eval=256,
        )

        return env_config, server_config

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}

        # Try to calculate percent_correct, pass if there's a division by zero
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            # Skip if buffer is empty
            pass

        self.percent_correct_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
        test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
        self.test = list()
        for item in test_data:
            # GSM8k answers end with "#### <number>"; keep only the final number and drop thousands separators
            self.test.append(
                {
                    "question": item["question"],
                    "gold_answer": item["answer"]
                    .split("#")[-1]
                    .strip()
                    .replace(",", ""),
                }
            )
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
        """Rollout and score evaluation with detailed sample data collection."""

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            completion = await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": question},
                ],
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.6,
            )

            response_content = completion.choices[0].message.content

            # Parse gold answer
            gold_parsed = parse(
                "\\boxed{" + answer + "}",
                extraction_mode="first_match",
                extraction_config=[LatexExtractionConfig()],
            )

            # Parse model answer
            answer_parsed = parse(
                response_content.split("</think>")[-1],
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            score = 1 if verify(answer_parsed, gold_parsed) else 0

            sample = {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": response_content},
                ],
                "question": question,
                "gold_answer": answer,
                "gold_parsed": str(gold_parsed) if gold_parsed else None,
                "model_parsed": str(answer_parsed) if answer_parsed else None,
                "score": int(score),
                "correct": bool(score),
                "finish_reason": completion.choices[0].finish_reason,
                "response_after_think": (
                    response_content.split("</think>")[-1]
                    if "</think>" in response_content
                    else response_content
                ),
            }

            return {"score": score, "sample": sample}

    async def evaluate(self, *args, **kwargs):
        start_time = time.time()

        async def rollout_and_log(item):
            result = await self.rollout_and_score_eval(
                item["question"], item["gold_answer"]
            )
            if result is not None:
                self.log_eval_sample(result.get("sample", result))
            return result

        eval_tasks = [rollout_and_log(item) for item in self.test]
        results = await tqdm_asyncio.gather(*eval_tasks)

        scores = [result["score"] for result in results]
        percent_correct = sum(scores) / len(scores)

        end_time = time.time()

        # Add to existing metrics for wandb
        self.eval_metrics.append(("eval/percent_correct", percent_correct))

        await self.evaluate_log(
            metrics={"eval/percent_correct": percent_correct},
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                # Report the sampling settings actually used in rollout_and_score_eval
                "temperature": 0.6,
                "max_tokens": self.config.max_token_length,
            },
        )

    async def collect_trajectories(
        self, item: GSM8kRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        user_message = {"role": "user", "content": item["question"]}
        gold_answer = (
            "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
        )

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:

            chat_completions = await managed.chat_completion(
                messages=[{"role": "system", "content": system_prompt}, user_message],
                n=self.config.group_size,
                max_tokens=self.config.max_token_length,
                temperature=1.0,
            )

            state = managed.get_state()
            nodes = state["nodes"]

            to_score = list()
            to_backlog = list()
            for i, chat_completion in enumerate(chat_completions.choices):
                messages = (
                    {"role": "system", "content": system_prompt},
                    user_message,
                    {"role": "assistant", "content": chat_completion.message.content},
                )
                to_score.append(
                    {
                        "messages": messages,
                        "gold_answer": gold_answer,
                        "finish_reason": chat_completion.finish_reason,
                        "tokens": nodes[i].tokens,
                        "masks": nodes[i].masked_tokens,
                        "logprobs": nodes[i].logprobs,
                    }
                )
            to_postprocess = await self.score(to_score)
            return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["inference_logprobs"] = list()
        gold_parsed = parse(
            rollout_group_data[0]["gold_answer"],
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            random.shuffle(rollout_group_data)
            for item in rollout_group_data:
                # print(item[0][-1]["content"])
                answer_parsed = parse(
                    item["messages"][-1]["content"].split("</think>")[-1],
                    extraction_config=[
                        LatexExtractionConfig(
                            normalization_config=NormalizationConfig(
                                nits=False,
                                malformed_operators=False,
                                basic_latex=True,
                                equations=True,
                                boxed="all",
                                units=True,
                            ),
                            # Ensures that boxed is tried first
                            boxed_match_priority=0,
                            try_extract_without_anchor=False,
                        )
                    ],
                    extraction_mode="first_match",
                )
                # Reward 1 if the content is the same as the ground truth, 0 otherwise
                reward = verify(answer_parsed, gold_parsed)

                tokens = item["tokens"]
                masks = item["masks"]
                logprobs = item["logprobs"]

                # remove obviously bad examples
                if len([1 for i in masks if i != -100]) < 10:
                    continue
                scores["tokens"].append(tokens)
                scores["masks"].append(masks)
                scores["inference_logprobs"].append(logprobs)
                scores["scores"].append(1.0 if reward else -1.0)

                if len(scores["tokens"]) >= self.config.group_size:
                    break

            for score in scores["scores"]:
                self.percent_correct_buffer.append(max(score, 0))

            # check if all the same
            # print(scores['scores'])
            if all([score == 1 for score in scores["scores"]]):
                # Do length penalty :)
                token_lengths = [len(token) for token in scores["tokens"]]
                if max(token_lengths) == 0:
                    # What? But don't want to crash a run so just in case...
                    return None

                # Get max allowed token length from config
                max_allowed_length = self.config.max_token_length
                # Set threshold at 50% of max_token_length - no penalty below this
                length_threshold = max_allowed_length * 0.5

                # Apply modified length penalty with threshold
                scores["scores"] = []
                for length in token_lengths:
                    if length <= length_threshold:
                        # No penalty for responses under threshold
                        scores["scores"].append(1.0)
                    else:
                        # Calculate how far we are between threshold and max as a percentage
                        percentage_of_range = (length - length_threshold) / (
                            max_allowed_length - length_threshold
                        )
                        # Cap at 1.0 in case length exceeds max_allowed_length
                        percentage_of_range = min(percentage_of_range, 1.0)
                        # Apply linear penalty scaling from 1.0 down to 0.0
                        scores["scores"].append(1.0 - percentage_of_range)
            if all([scores["scores"][0] == score for score in scores["scores"]]):
                return None  # If all the same, we return None
            return scores
        else:
            # If the gold solution is not parseable, we return None
            return None

    async def get_next_item(self) -> GSM8kRow:
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item


if __name__ == "__main__":
    GSM8kEnv.cli()
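
A note on the length shaping in `score()` above: when every rollout in a group is correct, rewards are re-scored by response length, flat at 1.0 up to half of `max_token_length` and then decaying linearly to 0.0 at the full budget. A minimal standalone sketch of that calculation (the function name and example numbers below are illustrative, not part of the environment):

def length_shaped_score(length: int, max_allowed_length: int = 2048) -> float:
    """1.0 up to half the token budget, then a linear decay to 0.0 at the full budget."""
    length_threshold = max_allowed_length * 0.5
    if length <= length_threshold:
        return 1.0
    percentage_of_range = (length - length_threshold) / (max_allowed_length - length_threshold)
    return 1.0 - min(percentage_of_range, 1.0)


# With max_token_length=2048: 900 tokens -> 1.0, 1536 tokens -> 0.5, 2048 tokens or more -> 0.0.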