add eval saving cli args

hjc-puro 2025-07-09 03:12:13 +00:00
parent 72e75c2b13
commit a11af27298


@@ -155,6 +155,10 @@ class BaseEnvConfig(BaseModel):
default=None,
description="Path to save the groups, if set, will write groups to this jsonl",
)
data_dir_to_save_evals: Optional[str] = Field(
default=None,
description="Directory to save evaluation results",
)
min_items_sent_before_logging: int = Field(
default=2,
description="Minimum number of items sent before logging, if 0 or less, logs every time",
@@ -637,6 +641,173 @@ class BaseEnv(ABC):
wandb_metrics.update(server_wandb_metrics)
wandb.log(wandb_metrics, step=self.curr_step)
async def evaluate_log(
self,
metrics: Dict,
task_name: Optional[str] = None,
model_name: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
generation_parameters: Optional[Dict] = None,
):
"""
Log evaluation results to a JSON file in the format expected by nous-evals.
Args:
metrics: Dictionary of metrics to log (same format as wandb_log)
task_name: Name of the evaluation task (defaults to env name)
model_name: Name of the model being evaluated
start_time: Start time of evaluation (unix timestamp)
end_time: End time of evaluation (unix timestamp)
generation_parameters: Dictionary of generation parameters used
"""
if self.config.data_dir_to_save_evals is None:
logger.warning("data_dir_to_save_evals is not set, skipping evaluation logging")
return
import os
import json
from datetime import datetime
# Create directory if it doesn't exist
os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True)
# Generate filename with timestamp
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")[:-3] + "-00-00"
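# e.g. "2025-07-09T03-12-13.123-00-00": millisecond precision plus a fixed "-00-00" suffix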
filename = f"{timestamp}.json"
filepath = os.path.join(self.config.data_dir_to_save_evals, filename)
# Default values
if task_name is None:
if self.name:
task_name = f"{self.name}_eval"
else:
task_name = f"{self.__class__.__name__}_eval"
if model_name is None:
model_name = getattr(self.config, 'model_name', None)
if start_time is None:
start_time = time.time()
if end_time is None:
end_time = time.time()
if generation_parameters is None:
generation_parameters = {}
# Try to get generation parameters from config if not provided
config_gen_params = {}
if hasattr(self.config, 'max_token_length'):
config_gen_params['max_new_tokens'] = self.config.max_token_length
# Merge config params with passed params (passed params take precedence)
merged_gen_params = {**config_gen_params, **generation_parameters}
# Build the evaluation result structure
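# Task key follows the suite|task|num_fewshots pattern of the target results format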
task_key = f"atropos|{task_name}|0"
eval_result = {
"config_general": {
"lighteval_sha": "atropos_framework",
"num_fewshot_seeds": 1,
"max_samples": None,
"job_id": "0",
"start_time": start_time,
"end_time": end_time,
"total_evaluation_time_secondes": str(end_time - start_time),
"model_name": model_name,
"model_sha": "",
"model_dtype": None,
"model_size": -1,
"generation_parameters": {
"early_stopping": None,
"repetition_penalty": None,
"frequency_penalty": None,
"length_penalty": None,
"presence_penalty": None,
"max_new_tokens": merged_gen_params.get("max_new_tokens", None),
"min_new_tokens": merged_gen_params.get("min_new_tokens", None),
"seed": merged_gen_params.get("seed", None),
"stop_tokens": merged_gen_params.get("stop_tokens", None),
"temperature": merged_gen_params.get("temperature", None),
"top_k": merged_gen_params.get("top_k", None),
"min_p": merged_gen_params.get("min_p", None),
"top_p": merged_gen_params.get("top_p", None),
"truncate_prompt": None,
"request_timeout": None,
"response_format": None,
**{k: v for k, v in merged_gen_params.items() if k not in [
'max_new_tokens', 'min_new_tokens', 'seed', 'stop_tokens',
'temperature', 'top_k', 'min_p', 'top_p'
]} # Include any other custom parameters
}
},
"results": {
task_key: metrics,
"all": metrics # For single task, "all" is the same as task-specific
},
"versions": {},
"config_tasks": {
task_key: {
"name": task_name,
"prompt_function": task_name,
"hf_repo": None,
"hf_subset": None,
"metrics": [], # Could be populated with metric definitions
"hf_revision": None,
"hf_filter": None,
"hf_avail_splits": [],
"trust_dataset": False,
"evaluation_splits": ["test"],
"few_shots_split": None,
"few_shots_select": None,
"generation_size": self.config.max_token_length,
"generation_grammar": None,
"stop_sequence": [],
"num_samples": None,
"suite": ["atropos"],
"original_num_docs": -1,
"effective_num_docs": -1,
"must_remove_duplicate_docs": False,
"num_fewshots": 0,
"truncate_fewshots": False,
"version": 1
}
},
"summary_tasks": {
task_key: {
"hashes": {
"hash_examples": "unknown",
"hash_full_prompts": "unknown",
"hash_input_tokens": "unknown",
"hash_cont_tokens": "unknown"
},
"truncated": 0,
"non_truncated": 0,
"padded": 0,
"non_padded": 0,
"effective_few_shots": 0,
"num_truncated_few_shots": 0
}
},
"summary_general": {
"hashes": {
"hash_examples": "unknown",
"hash_full_prompts": "unknown",
"hash_input_tokens": "unknown",
"hash_cont_tokens": "unknown"
},
"truncated": 0,
"non_truncated": 0,
"padded": 0,
"non_padded": 0,
"num_truncated_few_shots": 0
}
}
# Write to file
with open(filepath, 'w') as f:
json.dump(eval_result, f, indent=2)
print(f"Evaluation results saved to {filepath}")
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
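
A hedged usage sketch of the new method, based only on the signature above; env stands for any BaseEnv subclass instance, and the metric names, model name, and generation parameters are illustrative:

import time

start = time.time()
# ... run the evaluation and accumulate metrics ...
metrics = {"eval/percent_correct": 0.87}  # same dict shape as passed to wandb logging
await env.evaluate_log(
    metrics=metrics,
    model_name="example-model",  # hypothetical identifier
    start_time=start,
    end_time=time.time(),
    generation_parameters={"temperature": 0.7, "max_new_tokens": 1024},
)
# With data_dir_to_save_evals set, this writes <timestamp>.json into that directory.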