mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
rename method
This commit is contained in:
parent
468b599ddb
commit
3e1eba6e92
1 changed file with 3 additions and 43 deletions
|
|
@@ -12,7 +12,7 @@ from atroposlib.envs.base import (
|
||||||
BaseEnvConfig,
|
BaseEnvConfig,
|
||||||
ScoredDataGroup,
|
ScoredDataGroup,
|
||||||
)
|
)
|
||||||
from atroposlib.type_definitions import Item, number
|
from atroposlib.type_definitions import Item
|
||||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||||
|
|
||||||
system_prompt = (
|
system_prompt = (
|
||||||
|
|
@@ -120,45 +120,7 @@ class GSM8kEnv(BaseEnv):
|
||||||
data["iter"] = self.iter
|
data["iter"] = self.iter
|
||||||
super().save_checkpoint(step, data)
|
super().save_checkpoint(step, data)
|
||||||
|
|
||||||
async def rollout_and_score_eval(self, question: str, answer: str) -> number:
|
async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
|
||||||
completion = await self.server.chat_completion(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": system_prompt},
|
|
||||||
{"role": "user", "content": question},
|
|
||||||
],
|
|
||||||
n=1,
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
temperature=0.0,
|
|
||||||
split="eval",
|
|
||||||
)
|
|
||||||
gold_parsed = parse(
|
|
||||||
"\\boxed{" + answer + "}",
|
|
||||||
extraction_mode="first_match",
|
|
||||||
extraction_config=[LatexExtractionConfig()],
|
|
||||||
)
|
|
||||||
answer_parsed = parse(
|
|
||||||
completion.choices[0].message.content.split("</think>")[-1],
|
|
||||||
extraction_config=[
|
|
||||||
LatexExtractionConfig(
|
|
||||||
normalization_config=NormalizationConfig(
|
|
||||||
nits=False,
|
|
||||||
malformed_operators=False,
|
|
||||||
basic_latex=True,
|
|
||||||
equations=True,
|
|
||||||
boxed="all",
|
|
||||||
units=True,
|
|
||||||
),
|
|
||||||
# Ensures that boxed is tried first
|
|
||||||
boxed_match_priority=0,
|
|
||||||
try_extract_without_anchor=False,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
extraction_mode="first_match",
|
|
||||||
)
|
|
||||||
score = 1 if verify(answer_parsed, gold_parsed) else 0
|
|
||||||
return score
|
|
||||||
|
|
||||||
async def rollout_and_score_eval_detailed(self, question: str, answer: str) -> dict:
|
|
||||||
"""Rollout and score evaluation with detailed sample data collection."""
|
"""Rollout and score evaluation with detailed sample data collection."""
|
||||||
completion = await self.server.chat_completion(
|
completion = await self.server.chat_completion(
|
||||||
messages=[
|
messages=[
|
||||||
|
|
@@ -234,9 +196,7 @@ class GSM8kEnv(BaseEnv):
|
||||||
eval_tasks = []
|
eval_tasks = []
|
||||||
for item in self.test:
|
for item in self.test:
|
||||||
eval_tasks.append(
|
eval_tasks.append(
|
||||||
self.rollout_and_score_eval_detailed(
|
self.rollout_and_score_eval(item["question"], item["gold_answer"])
|
||||||
item["question"], item["gold_answer"]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
results = await tqdm_asyncio.gather(*eval_tasks)
|
results = await tqdm_asyncio.gather(*eval_tasks)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue