diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index 6abf13d8..ff332580 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -155,6 +155,10 @@ class BaseEnvConfig(BaseModel):
         default=None,
         description="Path to save the groups, if set, will write groups to this jsonl",
     )
+    data_dir_to_save_evals: Optional[str] = Field(
+        default=None,
+        description="Directory to save evaluation results",
+    )
     min_items_sent_before_logging: int = Field(
         default=2,
         description="Minimum number of items sent before logging, if 0 or less, logs every time",
@@ -637,6 +641,173 @@ class BaseEnv(ABC):
         wandb_metrics.update(server_wandb_metrics)
         wandb.log(wandb_metrics, step=self.curr_step)
 
+    async def evaluate_log(
+        self,
+        metrics: Dict,
+        task_name: Optional[str] = None,
+        model_name: Optional[str] = None,
+        start_time: Optional[float] = None,
+        end_time: Optional[float] = None,
+        generation_parameters: Optional[Dict] = None,
+    ):
+        """
+        Log evaluation results to a JSON file in the format expected by nous-evals.
+
+        Args:
+            metrics: Dictionary of metrics to log (same format as wandb_log)
+            task_name: Name of the evaluation task (defaults to env name)
+            model_name: Name of the model being evaluated
+            start_time: Start time of evaluation (unix timestamp)
+            end_time: End time of evaluation (unix timestamp)
+            generation_parameters: Dictionary of generation parameters used
+        """
+        if self.config.data_dir_to_save_evals is None:
+            logger.warning(
+                "data_dir_to_save_evals is not set, skipping evaluation logging"
+            )
+            return
+
+        import json
+        import os
+        from datetime import datetime
+
+        # Create directory if it doesn't exist
+        os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True)
+
+        # Generate filename with timestamp
+        timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")[:-3] + "-00-00"
+        filename = f"{timestamp}.json"
+        filepath = os.path.join(self.config.data_dir_to_save_evals, filename)
+
+        # Default values
+        if task_name is None:
+            if self.name:
+                task_name = f"{self.name}_eval"
+            else:
+                task_name = f"{self.__class__.__name__}_eval"
+        if model_name is None:
+            model_name = getattr(self.config, "model_name", None)
+        if start_time is None:
+            start_time = time.time()
+        if end_time is None:
+            end_time = time.time()
+        if generation_parameters is None:
+            generation_parameters = {}
+
+        # Try to get generation parameters from config if not provided
+        config_gen_params = {}
+        if hasattr(self.config, "max_token_length"):
+            config_gen_params["max_new_tokens"] = self.config.max_token_length
+
+        # Merge config params with passed params (passed params take precedence)
+        merged_gen_params = {**config_gen_params, **generation_parameters}
+
+        # Build the evaluation result structure
+        task_key = f"atropos|{task_name}|0"
+
+        eval_result = {
+            "config_general": {
+                "lighteval_sha": "atropos_framework",
+                "num_fewshot_seeds": 1,
+                "max_samples": None,
+                "job_id": "0",
+                "start_time": start_time,
+                "end_time": end_time,
+                "total_evaluation_time_secondes": str(end_time - start_time),
+                "model_name": model_name,
+                "model_sha": "",
+                "model_dtype": None,
+                "model_size": -1,
+                "generation_parameters": {
+                    "early_stopping": None,
+                    "repetition_penalty": None,
+                    "frequency_penalty": None,
+                    "length_penalty": None,
+                    "presence_penalty": None,
+                    "max_new_tokens": merged_gen_params.get("max_new_tokens", None),
+                    "min_new_tokens": merged_gen_params.get("min_new_tokens", None),
+                    "seed": merged_gen_params.get("seed", None),
+                    "stop_tokens": merged_gen_params.get("stop_tokens", None),
+                    "temperature": merged_gen_params.get("temperature", None),
+                    "top_k": merged_gen_params.get("top_k", None),
+                    "min_p": merged_gen_params.get("min_p", None),
+                    "top_p": merged_gen_params.get("top_p", None),
+                    "truncate_prompt": None,
+                    "request_timeout": None,
+                    "response_format": None,
+                    # Include any other custom parameters
+                    **{k: v for k, v in merged_gen_params.items() if k not in [
+                        "max_new_tokens", "min_new_tokens", "seed", "stop_tokens",
+                        "temperature", "top_k", "min_p", "top_p",
+                    ]},
+                },
+            },
+            "results": {
+                task_key: metrics,
+                "all": metrics,  # For single task, "all" is the same as task-specific
+            },
+            "versions": {},
+            "config_tasks": {
+                task_key: {
+                    "name": task_name,
+                    "prompt_function": task_name,
+                    "hf_repo": None,
+                    "hf_subset": None,
+                    "metrics": [],  # Could be populated with metric definitions
+                    "hf_revision": None,
+                    "hf_filter": None,
+                    "hf_avail_splits": [],
+                    "trust_dataset": False,
+                    "evaluation_splits": ["test"],
+                    "few_shots_split": None,
+                    "few_shots_select": None,
+                    "generation_size": self.config.max_token_length,
+                    "generation_grammar": None,
+                    "stop_sequence": [],
+                    "num_samples": None,
+                    "suite": ["atropos"],
+                    "original_num_docs": -1,
+                    "effective_num_docs": -1,
+                    "must_remove_duplicate_docs": False,
+                    "num_fewshots": 0,
+                    "truncate_fewshots": False,
+                    "version": 1,
+                }
+            },
+            "summary_tasks": {
+                task_key: {
+                    "hashes": {
+                        "hash_examples": "unknown",
+                        "hash_full_prompts": "unknown",
+                        "hash_input_tokens": "unknown",
+                        "hash_cont_tokens": "unknown",
+                    },
+                    "truncated": 0,
+                    "non_truncated": 0,
+                    "padded": 0,
+                    "non_padded": 0,
+                    "effective_few_shots": 0,
+                    "num_truncated_few_shots": 0,
+                }
+            },
+            "summary_general": {
+                "hashes": {
+                    "hash_examples": "unknown",
+                    "hash_full_prompts": "unknown",
+                    "hash_input_tokens": "unknown",
+                    "hash_cont_tokens": "unknown",
+                },
+                "truncated": 0,
+                "non_truncated": 0,
+                "padded": 0,
+                "non_padded": 0,
+                "num_truncated_few_shots": 0,
+            },
+        }
+
+        # Write to file
+        with open(filepath, "w") as f:
+            json.dump(eval_result, f, indent=2)
+
+        print(f"Evaluation results saved to {filepath}")
+
     @retry(
         stop=stop_after_attempt(3),
         wait=wait_random_exponential(multiplier=1, max=10),
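For context, a minimal sketch of how an environment subclass might feed its eval metrics into the new `evaluate_log` hook; the `GSM8kEnv` class, the `run_eval_rollouts` helper, and the metric names below are illustrative placeholders, not part of this diff:

```python
import time

from atroposlib.envs.base import BaseEnv


class GSM8kEnv(BaseEnv):
    """Illustrative subclass; other required methods omitted for brevity."""

    async def evaluate(self, *args, **kwargs):
        start = time.time()
        # Hypothetical helper that scores held-out prompts and returns counts.
        correct, total = await self.run_eval_rollouts()
        metrics = {"eval/percent_correct": correct / max(total, 1)}
        # Writes a lighteval-style JSON file under config.data_dir_to_save_evals
        # (a warning and no-op when that directory is unset).
        await self.evaluate_log(
            metrics=metrics,
            task_name="gsm8k_eval",
            model_name="example-model",
            start_time=start,
            end_time=time.time(),
            generation_parameters={"temperature": 0.0, "max_new_tokens": 2048},
        )
```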