Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-23 16:54:56 +00:00
linting errors

parent e601251893
commit da0d64ae89

1 changed file with 75 additions and 46 deletions
@@ -654,7 +654,7 @@ class BaseEnv(ABC):
     ):
         """
         Log evaluation results to a JSON file in the format expected by nous-evals.
-        
+
         Args:
            metrics: Dictionary of metrics to log (same format as wandb_log)
            task_name: Name of the evaluation task (defaults to env name)
@@ -666,21 +666,23 @@ class BaseEnv(ABC):
            verbose: If True, print a markdown table of the metrics
         """
         if self.config.data_dir_to_save_evals is None:
-            logger.warning("data_dir_to_save_evals is not set, skipping evaluation logging")
+            logger.warning(
+                "data_dir_to_save_evals is not set, skipping evaluation logging"
+            )
             return
 
-        import os
         import json
+        import os
+
         import jsonlines
-        from datetime import datetime
 
         # Create directory if it doesn't exist
         os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True)
 
         # Generate filename
         filename = "metrics.json"
         filepath = os.path.join(self.config.data_dir_to_save_evals, filename)
 
         # Default values
         if task_name is None:
             if self.name:
@@ -689,11 +691,13 @@ class BaseEnv(ABC):
                task_name = f"{self.__class__.__name__}_eval"
         if model_name is None:
             # Try to get model name from config first, then from server configs
-            model_name = getattr(self.config, 'model_name', None)
-            if model_name is None and hasattr(self, 'server') and self.server.servers:
+            model_name = getattr(self.config, "model_name", None)
+            if model_name is None and hasattr(self, "server") and self.server.servers:
                 # Get model name from first server config
                 first_server = self.server.servers[0]
-                if hasattr(first_server, 'config') and hasattr(first_server.config, 'model_name'):
+                if hasattr(first_server, "config") and hasattr(
+                    first_server.config, "model_name"
+                ):
                     model_name = first_server.config.model_name
         if start_time is None:
             start_time = time.time()
@@ -701,36 +705,48 @@ class BaseEnv(ABC):
             end_time = time.time()
         if generation_parameters is None:
             generation_parameters = {}
 
         # Try to get generation parameters from config if not provided
         config_gen_params = {}
-        if hasattr(self.config, 'max_token_length'):
-            config_gen_params['max_new_tokens'] = self.config.max_token_length
+        if hasattr(self.config, "max_token_length"):
+            config_gen_params["max_new_tokens"] = self.config.max_token_length
 
         # Merge config params with passed params (passed params take precedence)
         merged_gen_params = {**config_gen_params, **generation_parameters}
 
         # Print metrics table if verbose
         if verbose:
-            print("\n" + "="*60)
+            print("\n" + "=" * 60)
             print(f"Evaluation Results: {task_name}")
-            print("="*60)
-            print(f"|{'Groups':<20}|{'Version':<7}|{'Filter':<6}|{'n-shot':<6}|{'Metric':<10}|{' ':<3}|{'Value':<10}|{' ':<3}|{'Stderr':<10}|")
-            print(f"|{'-'*20}|{'-'*7}:{'-'*6}|{'-'*6}|{'-'*10}|{'-'*3}|{'-'*10}:{'-'*3}|{'-'*10}:|")
+            print("=" * 60)
+            header = (
+                f"|{'Groups':<20}|{'Version':<7}|{'Filter':<6}|{'n-shot':<6}|"
+                f"{'Metric':<10}|{' ':<3}|{'Value':<10}|{' ':<3}|{'Stderr':<10}|"
+            )
+            print(header)
+            print(
+                f"|{'-'*20}|{'-'*7}:{'-'*6}|{'-'*6}|{'-'*10}|{'-'*3}|{'-'*10}:{'-'*3}|{'-'*10}:|"
+            )
 
             # Main task row
             for metric_name, metric_value in metrics.items():
                 clean_metric_name = metric_name.replace("eval/", "").replace("_", " ")
-                direction = "↑" if "correct" in metric_name or "acc" in metric_name else " "
-                print(f"|{task_name:<20}|{1:<7}|{'none':<6}|{'':<6}|{clean_metric_name:<10}|{direction:<3}|{metric_value:<10.4f}|{'±':<3}|{'0.0000':<10}|")
+                direction = (
+                    "↑" if "correct" in metric_name or "acc" in metric_name else " "
+                )
+                row = (
+                    f"|{task_name:<20}|{1:<7}|{'none':<6}|{'':<6}|{clean_metric_name:<10}|"
+                    f"{direction:<3}|{metric_value:<10.4f}|{'±':<3}|{'0.0000':<10}|"
+                )
+                print(row)
 
-            print("="*60)
+            print("=" * 60)
             print(f"Evaluation completed in {end_time - start_time:.2f} seconds")
-            print("="*60 + "\n")
+            print("=" * 60 + "\n")
 
         # Build the evaluation result structure
         task_key = f"atropos|{task_name}|0"
 
         eval_result = {
             "config_general": {
                 "lighteval_sha": "atropos_framework",
@@ -761,15 +777,26 @@ class BaseEnv(ABC):
                     "truncate_prompt": None,
                     "request_timeout": None,
                     "response_format": None,
-                    **{k: v for k, v in merged_gen_params.items() if k not in [
-                        'max_new_tokens', 'min_new_tokens', 'seed', 'stop_tokens',
-                        'temperature', 'top_k', 'min_p', 'top_p'
-                    ]} # Include any other custom parameters
-                }
+                    **{
+                        k: v
+                        for k, v in merged_gen_params.items()
+                        if k
+                        not in [
+                            "max_new_tokens",
+                            "min_new_tokens",
+                            "seed",
+                            "stop_tokens",
+                            "temperature",
+                            "top_k",
+                            "min_p",
+                            "top_p",
+                        ]
+                    },  # Include any other custom parameters
+                },
             },
             "results": {
                 task_key: metrics,
-                "all": metrics # For single task, "all" is the same as task-specific
+                "all": metrics,  # For single task, "all" is the same as task-specific
             },
             "versions": {},
             "config_tasks": {
@@ -796,7 +823,7 @@ class BaseEnv(ABC):
                     "must_remove_duplicate_docs": False,
                     "num_fewshots": 0,
                     "truncate_fewshots": False,
-                    "version": 1
+                    "version": 1,
                 }
             },
             "summary_tasks": {
@@ -805,14 +832,14 @@ class BaseEnv(ABC):
                         "hash_examples": "unknown",
                         "hash_full_prompts": "unknown",
                         "hash_input_tokens": "unknown",
-                        "hash_cont_tokens": "unknown"
+                        "hash_cont_tokens": "unknown",
                     },
                     "truncated": 0,
                     "non_truncated": 0,
                     "padded": 0,
                     "non_padded": 0,
                     "effective_few_shots": 0,
-                    "num_truncated_few_shots": 0
+                    "num_truncated_few_shots": 0,
                 }
             },
             "summary_general": {
@@ -820,26 +847,28 @@ class BaseEnv(ABC):
                     "hash_examples": "unknown",
                     "hash_full_prompts": "unknown",
                     "hash_input_tokens": "unknown",
-                    "hash_cont_tokens": "unknown"
+                    "hash_cont_tokens": "unknown",
                 },
                 "truncated": 0,
                 "non_truncated": 0,
                 "padded": 0,
                 "non_padded": 0,
-                "num_truncated_few_shots": 0
-            }
+                "num_truncated_few_shots": 0,
+            },
         }
 
         # Write main results to JSON file
-        with open(filepath, 'w') as f:
+        with open(filepath, "w") as f:
             json.dump(eval_result, f, indent=2)
 
         print(f"Evaluation results saved to {filepath}")
 
         # Write samples to JSONL file if provided
         if samples:
-            samples_filepath = os.path.join(self.config.data_dir_to_save_evals, "samples.jsonl")
-            with jsonlines.open(samples_filepath, 'w') as writer:
+            samples_filepath = os.path.join(
+                self.config.data_dir_to_save_evals, "samples.jsonl"
+            )
+            with jsonlines.open(samples_filepath, "w") as writer:
                 for sample in samples:
                     writer.write(sample)
             print(f"Evaluation samples saved to {samples_filepath}")
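Note that the table-formatting changes in the -701,36 hunk rely on Python's implicit concatenation of adjacent f-string literals, so black's wrapping is purely cosmetic. A minimal standalone check, reusing the header literal from the diff (the comparison itself is illustrative, not part of the commit):

# Check that black's implicit f-string concatenation is behavior-preserving:
# the wrapped two-part header equals the original one-line literal.
single = f"|{'Groups':<20}|{'Version':<7}|{'Filter':<6}|{'n-shot':<6}|{'Metric':<10}|{' ':<3}|{'Value':<10}|{' ':<3}|{'Stderr':<10}|"
wrapped = (
    f"|{'Groups':<20}|{'Version':<7}|{'Filter':<6}|{'n-shot':<6}|"
    f"{'Metric':<10}|{' ':<3}|{'Value':<10}|{' ':<3}|{'Stderr':<10}|"
)
assert single == wrapped
print(wrapped)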
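Similarly, the **{...} comprehension that black expands in the -761,15 hunk merges config-derived defaults with caller-supplied generation parameters, then keeps any unrecognized keys as custom entries in the serialized config. A short sketch of that idiom with made-up values; only the merge-and-filter pattern comes from the diff:

# Illustrative values; only the merge-and-filter idiom is from the diff.
config_gen_params = {"max_new_tokens": 2048}
generation_parameters = {"temperature": 0.7, "repetition_penalty": 1.1}

# Passed params take precedence over config-derived defaults.
merged_gen_params = {**config_gen_params, **generation_parameters}

known_keys = [
    "max_new_tokens",
    "min_new_tokens",
    "seed",
    "stop_tokens",
    "temperature",
    "top_k",
    "min_p",
    "top_p",
]
# Keys outside the known sampling fields survive as custom parameters.
extras = {k: v for k, v in merged_gen_params.items() if k not in known_keys}
print(extras)  # {'repetition_penalty': 1.1}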