diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index ff332580..71f0e2b8 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -155,10 +155,6 @@ class BaseEnvConfig(BaseModel):
         default=None,
         description="Path to save the groups, if set, will write groups to this jsonl",
     )
-    data_dir_to_save_evals: Optional[str] = Field(
-        default=None,
-        description="Directory to save evaluation results",
-    )
     min_items_sent_before_logging: int = Field(
         default=2,
         description="Minimum number of items sent before logging, if 0 or less, logs every time",
@@ -649,6 +645,8 @@ class BaseEnv(ABC):
         start_time: Optional[float] = None,
         end_time: Optional[float] = None,
         generation_parameters: Optional[Dict] = None,
+        samples: Optional[List[Dict]] = None,
+        verbose: bool = True,
     ):
         """
         Log evaluation results to a JSON file in the format expected by nous-evals.
@@ -660,6 +658,8 @@
             start_time: Start time of evaluation (unix timestamp)
             end_time: End time of evaluation (unix timestamp)
             generation_parameters: Dictionary of generation parameters used
+            samples: List of sample dictionaries to save to samples.jsonl
+            verbose: If True, print a markdown table of the metrics
         """
         if self.config.data_dir_to_save_evals is None:
             logger.warning("data_dir_to_save_evals is not set, skipping evaluation logging")
@@ -667,14 +667,14 @@
         import os
         import json
+        import jsonlines
         from datetime import datetime

         # Create directory if it doesn't exist
         os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True)

-        # Generate filename with timestamp
-        timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")[:-3] + "-00-00"
-        filename = f"{timestamp}.json"
+        # Generate filename
+        filename = "metrics.json"
         filepath = os.path.join(self.config.data_dir_to_save_evals, filename)

         # Default values
@@ -700,6 +700,24 @@
         # Merge config params with passed params (passed params take precedence)
         merged_gen_params = {**config_gen_params, **generation_parameters}

+        # Print metrics table if verbose
+        if verbose:
+            print("\n" + "="*60)
+            print(f"Evaluation Results: {task_name}")
+            print("="*60)
+            print(f"|{'Groups':<20}|{'Version':<7}|{'Filter':<6}|{'n-shot':<6}|{'Metric':<10}|{' ':<3}|{'Value':<10}|{' ':<3}|{'Stderr':<10}|")
+            print(f"|{'-'*20}|{'-'*7}:|{'-'*6}|{'-'*6}|{'-'*10}|{'-'*3}|{'-'*10}:|{'-'*3}|{'-'*10}:|")
+
+            # Main task row
+            for metric_name, metric_value in metrics.items():
+                clean_metric_name = metric_name.replace("eval/", "").replace("_", " ")
+                direction = "↑" if "correct" in metric_name or "acc" in metric_name else " "
+                print(f"|{task_name:<20}|{1:<7}|{'none':<6}|{'':<6}|{clean_metric_name:<10}|{direction:<3}|{metric_value:<10.4f}|{'±':<3}|{'0.0000':<10}|")
+
+            print("="*60)
+            print(f"Evaluation completed in {end_time - start_time:.2f} seconds")
+            print("="*60 + "\n")
+
         # Build the evaluation result structure
         task_key = f"atropos|{task_name}|0"

@@ -802,11 +820,19 @@
             }
         }

-        # Write to file
+        # Write main results to JSON file
         with open(filepath, 'w') as f:
             json.dump(eval_result, f, indent=2)

         print(f"Evaluation results saved to {filepath}")
+
+        # Write samples to JSONL file if provided
+        if samples:
+            samples_filepath = os.path.join(self.config.data_dir_to_save_evals, "samples.jsonl")
+            with jsonlines.open(samples_filepath, 'w') as writer:
+                for sample in samples:
+                    writer.write(sample)
+            print(f"Evaluation samples saved to {samples_filepath}")

     @retry(
         stop=stop_after_attempt(3),
diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py
index 4fd37707..d273cc62 100644
--- a/environments/gsm8k_server.py
+++ b/environments/gsm8k_server.py
@@ -159,13 +159,39 @@ class GSM8kEnv(BaseEnv):
         return score

     async def evaluate(self, *args, **kwargs):
+        import time
+        start_time = time.time()
+
         eval_tasks = []
         for item in self.test:
             eval_tasks.append(
                 self.rollout_and_score_eval(item["question"], item["gold_answer"])
             )
         scores = await tqdm_asyncio.gather(*eval_tasks)
-        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
+        percent_correct = sum(scores) / len(scores)
+
+        end_time = time.time()
+
+        # Add to existing metrics for wandb
+        self.eval_metrics.append(("eval/percent_correct", percent_correct))
+
+        # Log evaluation results
+        eval_metrics = {
+            "eval/percent_correct": percent_correct,
+            "eval/total_samples": len(scores),
+            "eval/correct_samples": sum(scores),
+        }
+
+        await self.evaluate_log(
+            metrics=eval_metrics,
+            start_time=start_time,
+            end_time=end_time,
+            generation_parameters={
+                "temperature": 0.0,
+                "max_tokens": self.config.max_token_length,
+                "split": "eval"
+            }
+        )

     async def collect_trajectories(
         self, item: GSM8kRow