diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 7406ef2f..11499cad 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -798,6 +798,23 @@ class BaseEnv(ABC): except Exception as e: logger.warning("Failed to generate eval HTML viewer: %s", e) + def log_eval_sample(self, sample): + """Stream-write a single eval sample to samples.jsonl. + + Lazy-initializes the writer on first call. Use this inside evaluate() + to write samples as they complete rather than batching at the end. + If using this, omit the samples= parameter from evaluate_log(). + """ + if self._eval_sample_writer is None: + if self.config.data_dir_to_save_evals is None: + return + os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True) + self._eval_samples_path = os.path.join( + self.config.data_dir_to_save_evals, "samples.jsonl" + ) + self._eval_sample_writer = jsonlines.open(self._eval_samples_path, "w") + self._eval_sample_writer.write(sample) + @retry( stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10), @@ -1343,10 +1360,23 @@ class BaseEnv(ABC): """ Internal method to run evaluation with proper setup. """ + self._eval_sample_writer = None + self._eval_samples_path = None await self.setup() try: await self.evaluate() finally: + # Close streaming eval sample writer if it was used + if self._eval_sample_writer is not None: + self._eval_sample_writer.close() + if self._eval_samples_path: + try: + from atroposlib.frontend.jsonl2html import generate_eval_html + + generate_eval_html(self._eval_samples_path) + except Exception as e: + logger.warning("Failed to generate eval HTML: %s", e) + # Close JSONL trajectory writer if it was used if self.jsonl_writer is not None: self.jsonl_writer.close() if self.config.data_path_to_save_groups: diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 87823526..09eca7c5 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -189,17 +189,18 @@ class GSM8kEnv(BaseEnv): async def evaluate(self, *args, **kwargs): start_time = time.time() - eval_tasks = [] - for item in self.test: - eval_tasks.append( - self.rollout_and_score_eval(item["question"], item["gold_answer"]) + async def rollout_and_log(item): + result = await self.rollout_and_score_eval( + item["question"], item["gold_answer"] ) + if result is not None: + self.log_eval_sample(result.get("sample", result)) + return result + + eval_tasks = [rollout_and_log(item) for item in self.test] results = await tqdm_asyncio.gather(*eval_tasks) - # Extract scores and samples scores = [result["score"] for result in results] - samples = [result["sample"] for result in results] - percent_correct = sum(scores) / len(scores) end_time = time.time() @@ -207,14 +208,8 @@ class GSM8kEnv(BaseEnv): # Add to existing metrics for wandb self.eval_metrics.append(("eval/percent_correct", percent_correct)) - # Log evaluation results - eval_metrics = { - "eval/percent_correct": percent_correct, - } - await self.evaluate_log( - metrics=eval_metrics, - samples=samples, + metrics={"eval/percent_correct": percent_correct}, start_time=start_time, end_time=end_time, generation_parameters={