rename method

This commit is contained in:
hjc-puro 2025-07-11 00:39:28 +00:00
parent 468b599ddb
commit 3e1eba6e92

View file

@ -12,7 +12,7 @@ from atroposlib.envs.base import (
BaseEnvConfig, BaseEnvConfig,
ScoredDataGroup, ScoredDataGroup,
) )
from atroposlib.type_definitions import Item, number from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = ( system_prompt = (
@ -120,45 +120,7 @@ class GSM8kEnv(BaseEnv):
data["iter"] = self.iter data["iter"] = self.iter
super().save_checkpoint(step, data) super().save_checkpoint(step, data)
async def rollout_and_score_eval(self, question: str, answer: str) -> number: async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": question},
],
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
split="eval",
)
gold_parsed = parse(
"\\boxed{" + answer + "}",
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
answer_parsed = parse(
completion.choices[0].message.content.split("</think>")[-1],
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
score = 1 if verify(answer_parsed, gold_parsed) else 0
return score
async def rollout_and_score_eval_detailed(self, question: str, answer: str) -> dict:
"""Rollout and score evaluation with detailed sample data collection.""" """Rollout and score evaluation with detailed sample data collection."""
completion = await self.server.chat_completion( completion = await self.server.chat_completion(
messages=[ messages=[
@ -234,9 +196,7 @@ class GSM8kEnv(BaseEnv):
eval_tasks = [] eval_tasks = []
for item in self.test: for item in self.test:
eval_tasks.append( eval_tasks.append(
self.rollout_and_score_eval_detailed( self.rollout_and_score_eval(item["question"], item["gold_answer"])
item["question"], item["gold_answer"]
)
) )
results = await tqdm_asyncio.gather(*eval_tasks) results = await tqdm_asyncio.gather(*eval_tasks)