mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add printing
This commit is contained in:
parent
a11af27298
commit
f4de3ad6f5
2 changed files with 61 additions and 9 deletions
|
|
@ -159,13 +159,39 @@ class GSM8kEnv(BaseEnv):
|
|||
return score
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
    """Score every item in ``self.test`` and log the aggregate results.

    One ``rollout_and_score_eval`` coroutine is created per test item and
    all of them are awaited together.  The fraction of correct answers is
    appended once to ``self.eval_metrics`` and the full metric set, along
    with wall-clock start/end times, is passed to ``self.evaluate_log``.

    Positional and keyword arguments are accepted for interface
    compatibility and are not used.

    Returns:
        None.  If there is nothing to evaluate, no metrics are recorded.
    """
    import time

    start_time = time.time()

    eval_tasks = [
        self.rollout_and_score_eval(item["question"], item["gold_answer"])
        for item in self.test
    ]
    scores = await tqdm_asyncio.gather(*eval_tasks)

    # An empty test split would otherwise divide by zero below; there is
    # nothing meaningful to report in that case.
    if not scores:
        return

    total_samples = len(scores)
    correct_samples = sum(scores)
    percent_correct = correct_samples / total_samples

    end_time = time.time()

    # Add to existing metrics for wandb (recorded exactly once per run).
    self.eval_metrics.append(("eval/percent_correct", percent_correct))

    # Log evaluation results
    eval_metrics = {
        "eval/percent_correct": percent_correct,
        "eval/total_samples": total_samples,
        "eval/correct_samples": correct_samples,
    }

    # NOTE(review): these generation parameters are only reported here, not
    # applied; confirm they match what rollout_and_score_eval actually uses.
    await self.evaluate_log(
        metrics=eval_metrics,
        start_time=start_time,
        end_time=end_time,
        generation_parameters={
            "temperature": 0.0,
            "max_tokens": self.config.max_token_length,
            "split": "eval",
        },
    )
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: GSM8kRow
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue