mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
narrow down scope further
This commit is contained in:
parent
f343b24a6a
commit
836c346406
3 changed files with 22 additions and 20 deletions
|
|
@@ -11,7 +11,6 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from math_verify.errors import TimeoutException
|
||||
|
|
@@ -124,6 +123,8 @@ class MathEnv(BaseEnv):
|
|||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
print("Initializing MathEnv")
|
||||
print(f"Slurm: {slurm}, Testing: {testing}")
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.percent_correct_buffer = list()
|
||||
self.eval_metrics = list()
|
||||
|
|
@@ -396,6 +397,7 @@ class MathEnv(BaseEnv):
|
|||
)
|
||||
if len(self.normal_rollouts) > self.config.num_rollouts_to_keep:
|
||||
self.normal_rollouts.pop(0)
|
||||
print(f"Collected {len(to_postprocess['scores'])} trajectories")
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
|
||||
|
|
@@ -480,7 +482,6 @@ class MathEnv(BaseEnv):
|
|||
and (not scores["overrides"][i].get("set_advantage_to_zero", False))
|
||||
]
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
async def get_next_item(self):
|
||||
|
|
@@ -496,7 +497,10 @@ class MathEnv(BaseEnv):
|
|||
)
|
||||
break
|
||||
except TypeError:
|
||||
continue
|
||||
print(
|
||||
f"Error in getting next item, trying again, "
|
||||
f"data: {next_item['question']} -> {next_item['final_answer']}"
|
||||
)
|
||||
return (prompt, answer, "normal")
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue