Eval script consolidation (#238)

The script now supports:
   - YAML and JSON configurations
   - Dataset-specific parameters
   - Overriding configuration via command line
   - Detailed logging and error handling
This commit is contained in:
Andreas Köpf 2025-02-27 17:39:14 +01:00 committed by GitHub
parent 8a66d2a216
commit 850c1cf6f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 1111 additions and 670 deletions

View file

@@ -1,25 +1,139 @@
from dataclasses import dataclass
from typing import Union
"""Configuration classes for the evaluation script"""
import json
import re
from dataclasses import dataclass, field
from typing import Any, Optional
import yaml
from reasoning_gym.utils import SYSTEM_PROMPTS
def is_valid_unix_filename(filename: str) -> bool:
    """Check for shell-safe filenames.

    Only allows non-empty names composed entirely of ASCII letters,
    digits, hyphens, and underscores.

    Args:
        filename: Candidate filename to validate.

    Returns:
        True when the name is shell-safe, False otherwise.
    """
    if not filename:
        return False
    # fullmatch instead of match(r"^...$"): the `$` anchor also matches
    # just before a trailing newline, so "abc\n" used to pass — which is
    # not shell-safe.
    return re.fullmatch(r"[a-zA-Z0-9_-]+", filename) is not None
@dataclass
class DatasetConfig:
    """Configuration for a specific dataset"""

    # Name/identifier of the dataset to evaluate.
    dataset: str
    # Number of examples to generate for this dataset.
    size: int = 500
    # Optional RNG seed; None presumably means unseeded — confirm with dataset factory.
    seed: Optional[int] = None
    # Allow any additional dataset-specific parameters
    params: dict[str, Any] = field(default_factory=dict)
@dataclass
class CategoryConfig:
    """Configuration for a category of datasets"""

    # Category name; validated elsewhere to be a shell-safe Unix filename,
    # so it is presumably used to build output paths — confirm with eval script.
    category: str
    # Datasets evaluated under this category.
    datasets: list[DatasetConfig]
@dataclass
class EvalConfig:
    """Global evaluation configuration.

    Built from a YAML or JSON file via :meth:`from_yaml` / :meth:`from_json`;
    both funnel through :meth:`_process_config_data`, which constructs the
    nested ``CategoryConfig`` / ``DatasetConfig`` structure and validates
    category names.
    """

    # Model identifier to evaluate (required).
    model: str
    # Inference provider; file loading falls back to "openai" when absent.
    provider: Optional[str] = None
    system_prompt: str = SYSTEM_PROMPTS["default"]
    # Chat role used when sending the system prompt.
    system_role: str = "system"
    # Directory where evaluation results are written.
    output_dir: str = "results"
    # Maximum number of concurrent evaluation requests.
    max_concurrent: int = 10
    # Fallback size/seed applied to datasets that do not set their own.
    default_size: int = 500
    default_seed: Optional[int] = None
    save_metadata: bool = False
    save_full_results: bool = False
    categories: list[CategoryConfig] = field(default_factory=list)

    @classmethod
    def from_json(cls, json_path: str) -> "EvalConfig":
        """Load configuration from JSON file"""
        with open(json_path, "r") as f:
            config_data = json.load(f)
        return cls._process_config_data(config_data)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "EvalConfig":
        """Load configuration from YAML file"""
        with open(yaml_path, "r") as f:
            config_data = yaml.safe_load(f)
        return cls._process_config_data(config_data)

    @classmethod
    def _process_config_data(cls, config_data: dict[str, Any]) -> "EvalConfig":
        """Process configuration data from either JSON or YAML.

        Args:
            config_data: Parsed top-level mapping from the config file.

        Returns:
            A fully populated ``EvalConfig``.

        Raises:
            ValueError: If a category name is not a valid Unix filename.
        """
        categories = []
        for category_data in config_data.pop("categories", []):
            category_name = category_data.get("category")
            if not is_valid_unix_filename(category_name):
                raise ValueError(
                    f"Invalid category name '{category_name}'. Category names must be valid Unix filenames."
                )

            # Process datasets in this category.
            datasets = []
            for dataset_data in category_data.get("datasets", []):
                # A bare string is shorthand for {"dataset": <name>}.
                # Bug fix: the shorthand used to be stored under the key
                # "name", which the lookup below never read, silently
                # yielding dataset=None.
                if isinstance(dataset_data, str):
                    dataset_data = {"dataset": dataset_data}

                dataset_name = dataset_data.get("dataset")
                # Per-dataset size/seed fall back to the global defaults.
                size = dataset_data.get("size", config_data.get("default_size", 500))
                seed = dataset_data.get("seed", config_data.get("default_seed"))

                # Everything except dataset/size/seed becomes a dataset
                # parameter; a nested "params" dict is flattened in.
                params = {}
                for k, v in dataset_data.items():
                    if k not in ("dataset", "size", "seed"):
                        if k == "params" and isinstance(v, dict):
                            params.update(v)
                        else:
                            params[k] = v

                datasets.append(
                    DatasetConfig(
                        dataset=dataset_name,
                        size=size,
                        seed=seed,
                        params=params,
                    )
                )

            categories.append(CategoryConfig(category=category_name, datasets=datasets))

        # Create main config. Bug fix: default_size/default_seed are now
        # forwarded, so the instance fields agree with the values that the
        # per-dataset fallbacks above actually used.
        return cls(
            model=config_data.get("model"),
            provider=config_data.get("provider", "openai"),
            system_prompt=config_data.get("system_prompt", SYSTEM_PROMPTS["default"]),
            system_role=config_data.get("system_role", "system"),
            output_dir=config_data.get("output_dir", "results"),
            max_concurrent=config_data.get("max_concurrent", 10),
            default_size=config_data.get("default_size", 500),
            default_seed=config_data.get("default_seed"),
            save_metadata=config_data.get("save_metadata", False),
            save_full_results=config_data.get("save_full_results", False),
            categories=categories,
        )