mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
rename method
This commit is contained in:
parent
468b599ddb
commit
3e1eba6e92
1 changed files with 3 additions and 43 deletions
|
|
@ -12,7 +12,7 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import Item, number
|
||||
from atroposlib.type_definitions import Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
system_prompt = (
|
||||
|
|
@ -120,45 +120,7 @@ class GSM8kEnv(BaseEnv):
|
|||
data["iter"] = self.iter
|
||||
super().save_checkpoint(step, data)
|
||||
|
||||
async def rollout_and_score_eval(self, question: str, answer: str) -> number:
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": question},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
gold_parsed = parse(
|
||||
"\\boxed{" + answer + "}",
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
answer_parsed = parse(
|
||||
completion.choices[0].message.content.split("</think>")[-1],
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
equations=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
# Ensures that boxed is tried first
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
score = 1 if verify(answer_parsed, gold_parsed) else 0
|
||||
return score
|
||||
|
||||
async def rollout_and_score_eval_detailed(self, question: str, answer: str) -> dict:
|
||||
async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
|
||||
"""Rollout and score evaluation with detailed sample data collection."""
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
|
|
@ -234,9 +196,7 @@ class GSM8kEnv(BaseEnv):
|
|||
eval_tasks = []
|
||||
for item in self.test:
|
||||
eval_tasks.append(
|
||||
self.rollout_and_score_eval_detailed(
|
||||
item["question"], item["gold_answer"]
|
||||
)
|
||||
self.rollout_and_score_eval(item["question"], item["gold_answer"])
|
||||
)
|
||||
results = await tqdm_asyncio.gather(*eval_tasks)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue