mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Eval script consolidation (#238)
The script now supports: - YAML and JSON configurations - Dataset-specific parameters - Overriding configuration via command line - Detailed logging and error handling
This commit is contained in:
parent
8a66d2a216
commit
850c1cf6f4
40 changed files with 1111 additions and 670 deletions
|
|
@ -1,25 +1,139 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
"""Configuration classes for the evaluation script"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from reasoning_gym.utils import SYSTEM_PROMPTS
|
||||
|
||||
|
||||
def is_valid_unix_filename(filename: str) -> bool:
    """
    Check for shell-safe filenames.

    Only allows alphanumeric characters, hyphens, and underscores.

    Args:
        filename: Candidate name (e.g. a category used as an output directory).

    Returns:
        True if the name is non-empty and contains only [a-zA-Z0-9_-].
    """
    if not filename:
        return False
    # fullmatch instead of re.match + "$": "$" also matches just before a
    # trailing newline, so r"^...$" wrongly accepted names like "abc\n".
    return re.fullmatch(r"[a-zA-Z0-9_-]+", filename) is not None
|
||||
|
||||
|
||||
@dataclass
class DatasetConfig:
    """Configuration for a single dataset to be evaluated."""

    dataset: str  # registered dataset name
    size: int = 500  # number of examples to generate
    seed: Optional[int] = None  # generation seed; None defers to the dataset
    # Arbitrary extra dataset-specific parameters, passed through as-is.
    params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class CategoryConfig:
    """Configuration for a named group of datasets."""

    category: str  # category name (also used for output file naming)
    datasets: list[DatasetConfig]  # datasets evaluated under this category
|
||||
|
||||
|
||||
@dataclass
class EvalConfig:
    """Global evaluation configuration"""

    model: str
    provider: Optional[str] = None
    system_prompt: str = SYSTEM_PROMPTS["default"]
    system_role: str = "system"
    output_dir: str = "results"
    max_concurrent: int = 10
    default_size: int = 500
    default_seed: Optional[int] = None
    save_metadata: bool = False
    save_full_results: bool = False
    categories: list[CategoryConfig] = field(default_factory=list)

    @classmethod
    def from_json(cls, json_path: str) -> "EvalConfig":
        """Load configuration from JSON file"""
        with open(json_path, "r") as f:
            config_data = json.load(f)

        return cls._process_config_data(config_data)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "EvalConfig":
        """Load configuration from YAML file"""
        with open(yaml_path, "r") as f:
            config_data = yaml.safe_load(f)

        return cls._process_config_data(config_data)

    @classmethod
    def _process_config_data(cls, config_data: dict[str, Any]) -> "EvalConfig":
        """Process configuration data from either JSON or YAML.

        Raises:
            ValueError: if a category name is not a valid Unix filename.
        """
        # Extract categories (popped so they are not treated as plain fields)
        categories_data = config_data.pop("categories", [])
        categories = []

        for category_data in categories_data:
            category_name = category_data.get("category")
            # Category names become file/directory names, so they must be shell-safe.
            if not is_valid_unix_filename(category_name):
                raise ValueError(
                    f"Invalid category name '{category_name}'. Category names must be valid Unix filenames."
                )

            # Process datasets in this category
            datasets_data = category_data.get("datasets", [])
            datasets = []

            for dataset_data in datasets_data:
                # If it's just a string, convert to dict form.
                # BUG FIX: the key must be "dataset" (not "name") so the
                # lookup below finds it; "name" left the dataset unnamed and
                # leaked into params.
                if isinstance(dataset_data, str):
                    dataset_data = {"dataset": dataset_data}

                # Extract dataset name
                dataset_name = dataset_data.get("dataset")

                # Extract size and seed, falling back to the global defaults
                size = dataset_data.get("size", config_data.get("default_size", 500))
                seed = dataset_data.get("seed", config_data.get("default_seed"))

                # Extract all other parameters (everything except dataset, size, and seed).
                # If there's a nested 'params' dictionary, use its contents directly.
                params = {}
                for k, v in dataset_data.items():
                    if k not in ("dataset", "size", "seed"):
                        if k == "params" and isinstance(v, dict):
                            # Flatten nested params dictionary
                            params.update(v)
                        else:
                            params[k] = v

                # Create dataset config
                dataset_config = DatasetConfig(
                    dataset=dataset_name,
                    size=size,
                    seed=seed,
                    params=params,
                )
                datasets.append(dataset_config)

            # Create category config
            category_config = CategoryConfig(category=category_name, datasets=datasets)
            categories.append(category_config)

        # Create main config.
        # BUG FIX: default_size/default_seed were consulted above but never
        # stored on the instance, leaving the fields at their class defaults.
        return cls(
            model=config_data.get("model"),
            provider=config_data.get("provider", "openai"),
            system_prompt=config_data.get("system_prompt", SYSTEM_PROMPTS["default"]),
            system_role=config_data.get("system_role", "system"),
            output_dir=config_data.get("output_dir", "results"),
            max_concurrent=config_data.get("max_concurrent", 10),
            default_size=config_data.get("default_size", 500),
            default_seed=config_data.get("default_seed"),
            save_metadata=config_data.get("save_metadata", False),
            save_full_results=config_data.get("save_full_results", False),
            categories=categories,
        )
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue