rename method

This commit is contained in:
hjc-puro 2025-07-11 00:39:28 +00:00
parent 468b599ddb
commit 3e1eba6e92

View file

@ -12,7 +12,7 @@ from atroposlib.envs.base import (
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = (
@ -120,45 +120,7 @@ class GSM8kEnv(BaseEnv):
data["iter"] = self.iter
super().save_checkpoint(step, data)
async def rollout_and_score_eval(self, question: str, answer: str) -> number:
    """Request a single eval completion for ``question`` and grade it.

    The request is sent with ``n=1``, ``temperature=0.0`` and ``split="eval"``,
    capped at ``self.config.max_token_length`` tokens. Only the text after the
    last ``</think>`` marker in the reply is parsed for an answer.

    Args:
        question: Problem text sent as the user message.
        answer: Gold answer; it is wrapped in ``\\boxed{...}`` before parsing.

    Returns:
        1 when ``verify`` accepts the parsed reply against the parsed gold
        answer, otherwise 0.
    """
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    completion = await self.server.chat_completion(
        messages=chat_messages,
        n=1,
        max_tokens=self.config.max_token_length,
        temperature=0.0,
        split="eval",
    )

    # Parse the reference answer in boxed form.
    boxed_gold = "\\boxed{" + answer + "}"
    gold_parsed = parse(
        boxed_gold,
        extraction_mode="first_match",
        extraction_config=[LatexExtractionConfig()],
    )

    # Keep only what follows the final "</think>" (the whole reply if absent).
    reply_text = completion.choices[0].message.content.rpartition("</think>")[2]

    normalization = NormalizationConfig(
        nits=False,
        malformed_operators=False,
        basic_latex=True,
        equations=True,
        boxed="all",
        units=True,
    )
    reply_extraction = LatexExtractionConfig(
        normalization_config=normalization,
        # Ensures that boxed is tried first
        boxed_match_priority=0,
        try_extract_without_anchor=False,
    )
    answer_parsed = parse(
        reply_text,
        extraction_config=[reply_extraction],
        extraction_mode="first_match",
    )

    if verify(answer_parsed, gold_parsed):
        return 1
    return 0
async def rollout_and_score_eval_detailed(self, question: str, answer: str) -> dict:
async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
"""Rollout and score evaluation with detailed sample data collection."""
completion = await self.server.chat_completion(
messages=[
@ -234,9 +196,7 @@ class GSM8kEnv(BaseEnv):
eval_tasks = []
for item in self.test:
eval_tasks.append(
self.rollout_and_score_eval_detailed(
item["question"], item["gold_answer"]
)
self.rollout_and_score_eval(item["question"], item["gold_answer"])
)
results = await tqdm_asyncio.gather(*eval_tasks)