Fix linting errors

This commit is contained in:
hjc-puro 2025-07-11 00:29:57 +00:00
parent e601251893
commit da0d64ae89

View file

@ -654,7 +654,7 @@ class BaseEnv(ABC):
):
"""
Log evaluation results to a JSON file in the format expected by nous-evals.
Args:
metrics: Dictionary of metrics to log (same format as wandb_log)
task_name: Name of the evaluation task (defaults to env name)
@ -666,21 +666,23 @@ class BaseEnv(ABC):
verbose: If True, print a markdown table of the metrics
"""
if self.config.data_dir_to_save_evals is None:
logger.warning("data_dir_to_save_evals is not set, skipping evaluation logging")
logger.warning(
"data_dir_to_save_evals is not set, skipping evaluation logging"
)
return
import os
import json
import os
import jsonlines
from datetime import datetime
# Create directory if it doesn't exist
os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True)
# Generate filename
filename = "metrics.json"
filepath = os.path.join(self.config.data_dir_to_save_evals, filename)
# Default values
if task_name is None:
if self.name:
@ -689,11 +691,13 @@ class BaseEnv(ABC):
task_name = f"{self.__class__.__name__}_eval"
if model_name is None:
# Try to get model name from config first, then from server configs
model_name = getattr(self.config, 'model_name', None)
if model_name is None and hasattr(self, 'server') and self.server.servers:
model_name = getattr(self.config, "model_name", None)
if model_name is None and hasattr(self, "server") and self.server.servers:
# Get model name from first server config
first_server = self.server.servers[0]
if hasattr(first_server, 'config') and hasattr(first_server.config, 'model_name'):
if hasattr(first_server, "config") and hasattr(
first_server.config, "model_name"
):
model_name = first_server.config.model_name
if start_time is None:
start_time = time.time()
@ -701,36 +705,48 @@ class BaseEnv(ABC):
end_time = time.time()
if generation_parameters is None:
generation_parameters = {}
# Try to get generation parameters from config if not provided
config_gen_params = {}
if hasattr(self.config, 'max_token_length'):
config_gen_params['max_new_tokens'] = self.config.max_token_length
if hasattr(self.config, "max_token_length"):
config_gen_params["max_new_tokens"] = self.config.max_token_length
# Merge config params with passed params (passed params take precedence)
merged_gen_params = {**config_gen_params, **generation_parameters}
# Print metrics table if verbose
if verbose:
print("\n" + "="*60)
print("\n" + "=" * 60)
print(f"Evaluation Results: {task_name}")
print("="*60)
print(f"|{'Groups':<20}|{'Version':<7}|{'Filter':<6}|{'n-shot':<6}|{'Metric':<10}|{' ':<3}|{'Value':<10}|{' ':<3}|{'Stderr':<10}|")
print(f"|{'-'*20}|{'-'*7}:{'-'*6}|{'-'*6}|{'-'*10}|{'-'*3}|{'-'*10}:{'-'*3}|{'-'*10}:|")
print("=" * 60)
header = (
f"|{'Groups':<20}|{'Version':<7}|{'Filter':<6}|{'n-shot':<6}|"
f"{'Metric':<10}|{' ':<3}|{'Value':<10}|{' ':<3}|{'Stderr':<10}|"
)
print(header)
print(
f"|{'-'*20}|{'-'*7}:{'-'*6}|{'-'*6}|{'-'*10}|{'-'*3}|{'-'*10}:{'-'*3}|{'-'*10}:|"
)
# Main task row
for metric_name, metric_value in metrics.items():
clean_metric_name = metric_name.replace("eval/", "").replace("_", " ")
direction = "" if "correct" in metric_name or "acc" in metric_name else " "
print(f"|{task_name:<20}|{1:<7}|{'none':<6}|{'':<6}|{clean_metric_name:<10}|{direction:<3}|{metric_value:<10.4f}|{'±':<3}|{'0.0000':<10}|")
print("="*60)
direction = (
"" if "correct" in metric_name or "acc" in metric_name else " "
)
row = (
f"|{task_name:<20}|{1:<7}|{'none':<6}|{'':<6}|{clean_metric_name:<10}|"
f"{direction:<3}|{metric_value:<10.4f}|{'±':<3}|{'0.0000':<10}|"
)
print(row)
print("=" * 60)
print(f"Evaluation completed in {end_time - start_time:.2f} seconds")
print("="*60 + "\n")
print("=" * 60 + "\n")
# Build the evaluation result structure
task_key = f"atropos|{task_name}|0"
eval_result = {
"config_general": {
"lighteval_sha": "atropos_framework",
@ -761,15 +777,26 @@ class BaseEnv(ABC):
"truncate_prompt": None,
"request_timeout": None,
"response_format": None,
**{k: v for k, v in merged_gen_params.items() if k not in [
'max_new_tokens', 'min_new_tokens', 'seed', 'stop_tokens',
'temperature', 'top_k', 'min_p', 'top_p'
]} # Include any other custom parameters
}
**{
k: v
for k, v in merged_gen_params.items()
if k
not in [
"max_new_tokens",
"min_new_tokens",
"seed",
"stop_tokens",
"temperature",
"top_k",
"min_p",
"top_p",
]
}, # Include any other custom parameters
},
},
"results": {
task_key: metrics,
"all": metrics # For single task, "all" is the same as task-specific
"all": metrics, # For single task, "all" is the same as task-specific
},
"versions": {},
"config_tasks": {
@ -796,7 +823,7 @@ class BaseEnv(ABC):
"must_remove_duplicate_docs": False,
"num_fewshots": 0,
"truncate_fewshots": False,
"version": 1
"version": 1,
}
},
"summary_tasks": {
@ -805,14 +832,14 @@ class BaseEnv(ABC):
"hash_examples": "unknown",
"hash_full_prompts": "unknown",
"hash_input_tokens": "unknown",
"hash_cont_tokens": "unknown"
"hash_cont_tokens": "unknown",
},
"truncated": 0,
"non_truncated": 0,
"padded": 0,
"non_padded": 0,
"effective_few_shots": 0,
"num_truncated_few_shots": 0
"num_truncated_few_shots": 0,
}
},
"summary_general": {
@ -820,26 +847,28 @@ class BaseEnv(ABC):
"hash_examples": "unknown",
"hash_full_prompts": "unknown",
"hash_input_tokens": "unknown",
"hash_cont_tokens": "unknown"
"hash_cont_tokens": "unknown",
},
"truncated": 0,
"non_truncated": 0,
"padded": 0,
"non_padded": 0,
"num_truncated_few_shots": 0
}
"num_truncated_few_shots": 0,
},
}
# Write main results to JSON file
with open(filepath, 'w') as f:
with open(filepath, "w") as f:
json.dump(eval_result, f, indent=2)
print(f"Evaluation results saved to {filepath}")
# Write samples to JSONL file if provided
if samples:
samples_filepath = os.path.join(self.config.data_dir_to_save_evals, "samples.jsonl")
with jsonlines.open(samples_filepath, 'w') as writer:
samples_filepath = os.path.join(
self.config.data_dir_to_save_evals, "samples.jsonl"
)
with jsonlines.open(samples_filepath, "w") as writer:
for sample in samples:
writer.write(sample)
print(f"Evaluation samples saved to {samples_filepath}")