add printing

This commit is contained in:
hjc-puro 2025-07-09 23:35:26 +00:00
parent a11af27298
commit f4de3ad6f5
2 changed files with 61 additions and 9 deletions

View file

@ -159,13 +159,39 @@ class GSM8kEnv(BaseEnv):
return score
async def evaluate(self, *args, **kwargs):
    """Score every item in ``self.test`` and record the aggregate accuracy.

    Each test item is dispatched to ``self.rollout_and_score_eval`` with its
    ``"question"`` and ``"gold_answer"`` fields; the resulting scores are
    gathered concurrently behind a tqdm progress bar.  The fraction of
    correct answers is appended exactly once to ``self.eval_metrics`` and a
    summary (accuracy, sample count, correct count, wall-clock start/end
    times and generation parameters) is forwarded to ``self.evaluate_log``.

    Args:
        *args: Accepted for interface compatibility; unused here.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        None.
    """
    # Local import kept from the original: the module-level imports are not
    # visible from this block, so ``time`` may not be available globally.
    import time

    start_time = time.time()

    eval_tasks = [
        self.rollout_and_score_eval(item["question"], item["gold_answer"])
        for item in self.test
    ]
    scores = await tqdm_asyncio.gather(*eval_tasks)

    # Compute the aggregates once and reuse them everywhere below.
    correct_samples = sum(scores)
    total_samples = len(scores)
    # An empty test split would otherwise raise ZeroDivisionError; report 0.0
    # and let ``eval/total_samples == 0`` make the empty split visible.
    percent_correct = correct_samples / total_samples if total_samples else 0.0

    end_time = time.time()

    # Add to existing metrics for wandb.  This must happen only once per
    # evaluation run, otherwise the same metric is recorded twice.
    self.eval_metrics.append(("eval/percent_correct", percent_correct))

    # Log evaluation results
    eval_metrics = {
        "eval/percent_correct": percent_correct,
        "eval/total_samples": total_samples,
        "eval/correct_samples": correct_samples,
    }

    await self.evaluate_log(
        metrics=eval_metrics,
        start_time=start_time,
        end_time=end_time,
        # NOTE(review): these values are reported as-is; confirm they match
        # what rollout_and_score_eval actually uses when sampling.
        generation_parameters={
            "temperature": 0.0,
            "max_tokens": self.config.max_token_length,
            "split": "eval",
        },
    )
async def collect_trajectories(
self, item: GSM8kRow