narrow down scope further

This commit is contained in:
Jai Suphavadeeprasit 2026-02-27 13:15:23 -05:00
parent f343b24a6a
commit 836c346406
3 changed files with 22 additions and 20 deletions

View file

@ -11,7 +11,6 @@ from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
@ -124,6 +123,8 @@ class MathEnv(BaseEnv):
slurm=True,
testing=False,
):
print("Initializing MathEnv")
print(f"Slurm: {slurm}, Testing: {testing}")
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
@ -396,6 +397,7 @@ class MathEnv(BaseEnv):
)
if len(self.normal_rollouts) > self.config.num_rollouts_to_keep:
self.normal_rollouts.pop(0)
print(f"Collected {len(to_postprocess['scores'])} trajectories")
return to_postprocess, to_backlog
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
@ -480,7 +482,6 @@ class MathEnv(BaseEnv):
and (not scores["overrides"][i].get("set_advantage_to_zero", False))
]
)
return scores
async def get_next_item(self):
@ -496,7 +497,10 @@ class MathEnv(BaseEnv):
)
break
except TypeError:
continue
print(
f"Error in getting next item, trying again, "
f"data: {next_item['question']} -> {next_item['final_answer']}"
)
return (prompt, answer, "normal")