mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Merge pull request #109 from joesharratt1229/feat/r1-evals
added r1 evaluation logic
This commit is contained in:
commit
d56d2e25fb
7 changed files with 224 additions and 0 deletions
0
eval/r1/__init__.py
Normal file
0
eval/r1/__init__.py
Normal file
139
eval/r1/eval.py
Normal file
139
eval/r1/eval.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
from eval_config import EvalConfig
|
||||
from requests.exceptions import RequestException
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
||||
import reasoning_gym
|
||||
from reasoning_gym.utils import extract_answer
|
||||
|
||||
|
||||
class OpenRouterEvaluator:
    """Evaluate a model on reasoning-gym datasets via the OpenRouter chat-completions API.

    Per-dataset metrics (plus raw results) are written to
    ``{eval_dir}/{category}/{dataset_name}.json``.
    """

    def __init__(self, model: str, config: EvalConfig):
        """Set up the output directory, API credentials, and request headers.

        Args:
            model: OpenRouter model identifier (e.g. ``deepseek/deepseek-r1``).
            config: Evaluation configuration (datasets, size, seed, paths, prompts).

        Raises:
            ValueError: if the ``OPENROUTER_API_KEY`` environment variable is unset.
        """
        self.logger = logging.getLogger(f"OpenRouterEvaluator.{model}")
        self.config = config
        self.output_dir = f"{config.eval_dir}/{config.category}"
        os.makedirs(self.output_dir, exist_ok=True)
        self.base_url = "https://openrouter.ai/api/v1/chat/completions"
        self.api_key = os.getenv("OPENROUTER_API_KEY")
        if not self.api_key:
            # Fail fast: without a key every request would send "Bearer None"
            # and fail with 401 only after exhausting all retry attempts.
            raise ValueError("OPENROUTER_API_KEY environment variable is not set")
        self.model = model
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "HTTP-Referer": os.getenv("OR_SITE_URL", "localhost"),
            "X-Title": os.getenv("OR_APP_NAME", "Model Evaluation"),
            "Content-Type": "application/json",
        }

    def save_results(self, results: List[Dict[str, Any]], dataset, dataset_name) -> Dict[str, Any]:
        """Write per-dataset metrics (and the raw results) to JSON and return them."""
        file_name = f"{self.output_dir}/{dataset_name}.json"
        total_score = sum(r["score"] for r in results)

        metrics = {
            "dataset_name": dataset_name,
            "model": self.model,
            "size": dataset.size,
            "provider": self.config.provider,
            "average_score": total_score / len(results) if results else 0,
            "total_examples": len(results),
            "timestamp": datetime.now().isoformat(),
            "config": asdict(dataset.config),
            "results": results,  # save results to allow for performance recalculation
        }

        with open(file_name, "w") as f:
            json.dump(metrics, f, indent=2)
        return metrics

    def prepare_messages(self, prompt: str) -> Dict[str, Any]:
        """Build the chat-completions request payload for a single question prompt."""
        messages = [
            {"role": self.config.developer_role, "content": self.config.developer_prompt},
            {"role": "user", "content": prompt},
        ]
        payload = {
            "model": self.model,
            "messages": messages,
            "provider": {"order": ["Nebius"], "allow_fallbacks": False},
        }  # make sure only one provider is used

        return payload

    @retry(
        retry=retry_if_exception_type(RequestException),
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=60),
    )
    def get_model_response(self, prompt: str) -> str:
        """Get response from the model via OpenRouter API.

        Retries up to 5 times with exponential backoff on request failures.

        Raises:
            RequestException: if the request still fails after all retry attempts.
        """
        payload = self.prepare_messages(prompt)
        try:
            response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise RequestException(
                f"API request failed: {str(e)}", {"endpoint": self.base_url, "model": self.model}
            ) from e
        return response.json()["choices"][0]["message"]["content"]

    def evaluate_datasets(self) -> List[Dict[str, Any]]:
        """Evaluate model on multiple datasets with their respective configurations."""
        all_results = []

        for dataset_name in self.config.datasets:
            self.logger.info(f"\nEvaluating dataset: {dataset_name}")

            # Create dataset with its specific configuration
            dataset = reasoning_gym.create_dataset(
                dataset_name, size=self.config.dataset_size, seed=self.config.dataset_seed
            )
            results = []

            for i, entry in enumerate(dataset):
                # Use the named logger (not bare print) so progress follows
                # the same handlers/level as the rest of the evaluator.
                self.logger.info(f"On example {i+1} of {len(dataset)}")
                response = self.get_model_response(entry["question"])
                model_answer = extract_answer(response)

                score = dataset.score_answer(answer=model_answer, entry=entry)

                result = {
                    "question": entry["question"],
                    "expected_answer": str(entry["answer"]),
                    "model_answer": model_answer,
                    "score": score,
                    "metadata": str(entry["metadata"]),
                }
                results.append(result)

            metrics = self.save_results(results, dataset, dataset_name)

            all_results.append({"metrics": metrics, "results": results})

        return all_results
|
||||
|
||||
|
||||
def main():
    """CLI entry point: load a YAML eval config, run the evaluation, write a summary."""
    parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets")
    parser.add_argument("--yaml", required=True, help="Path to YAML configuration file")
    args = parser.parse_args()

    config = EvalConfig.from_yaml(args.yaml)
    output_dir = f"{config.eval_dir}/{config.category}"
    os.makedirs(output_dir, exist_ok=True)

    summary = OpenRouterEvaluator(model=config.model, config=config).evaluate_datasets()

    # Persist the combined metrics/results for all datasets in one file.
    with open(f"{output_dir}/summary.json", "w") as f:
        json.dump(summary, f, indent=2)


if __name__ == "__main__":
    main()
|
||||
25
eval/r1/eval_config.py
Normal file
25
eval/r1/eval_config.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from reasoning_gym.utils import SYSTEM_PROMPTS
|
||||
|
||||
|
||||
@dataclass
class EvalConfig:
    """Configuration for an evaluation run, typically loaded from a YAML file."""

    category: str  # output sub-directory name under eval_dir
    datasets: Union[str, List[str]]  # one dataset name, or a list of names
    eval_dir: str  # root directory for evaluation outputs
    dataset_size: int  # number of examples generated per dataset
    dataset_seed: int  # RNG seed for reproducible dataset generation
    model: str = "deepseek/deepseek-r1"
    provider: str = "Nebius"
    developer_role: str = "system"
    developer_prompt: str = SYSTEM_PROMPTS["DeepSeekZero"]

    def __post_init__(self):
        # Normalize a single dataset name to a one-element list so consumers
        # that iterate `datasets` never walk the characters of a string.
        if isinstance(self.datasets, str):
            self.datasets = [self.datasets]

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "EvalConfig":
        """Build an EvalConfig from a YAML mapping whose keys match the field names.

        Raises:
            TypeError: if the YAML contains keys that are not EvalConfig fields.
        """
        with open(yaml_path, "r") as f:
            config = yaml.safe_load(f)
        return cls(**config)
|
||||
13
eval/r1/yaml/algebra.yaml
Normal file
13
eval/r1/yaml/algebra.yaml
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
model: deepseek/deepseek-r1
|
||||
category: algebra
|
||||
datasets:
|
||||
- intermediate_integration
|
||||
- polynomial_equations
|
||||
- polynomial_multiplication
|
||||
- simple_equations
|
||||
- simple_integration
|
||||
- complex_arithmetic
|
||||
eval_dir: eval/r1
|
||||
dataset_size: 50
|
||||
dataset_seed: 42
|
||||
developer_role: system
|
||||
25
eval/r1/yaml/algorithmic.yaml
Normal file
25
eval/r1/yaml/algorithmic.yaml
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
model: deepseek/deepseek-r1
|
||||
category: algorithmic
|
||||
datasets:
|
||||
- base_conversion
|
||||
- binary_matrix
|
||||
- caesar_cipher
|
||||
- group_anagrams
|
||||
- isomorphic_strings
|
||||
- letter_counting
|
||||
- letter_jumble
|
||||
- number_filtering
|
||||
- number_sorting
|
||||
- palindrome
|
||||
- ransom_note
|
||||
- rotate_matrix
|
||||
- sentence_reordering
|
||||
- spell_backward
|
||||
- spiral_matrix
|
||||
- word_ladder
|
||||
- word_sequence_reversal
|
||||
- word_sorting
|
||||
eval_dir: eval/r1
|
||||
dataset_size: 50
|
||||
dataset_seed: 42
|
||||
developer_role: system
|
||||
11
eval/r1/yaml/cognition.yaml
Normal file
11
eval/r1/yaml/cognition.yaml
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
model: deepseek/deepseek-r1
|
||||
category: cognition
|
||||
datasets:
|
||||
- color_cube_rotation
|
||||
- figlet_font
|
||||
- number_sequence
|
||||
- rubiks_cube
|
||||
eval_dir: eval/r1
|
||||
dataset_size: 50
|
||||
dataset_seed: 42
|
||||
developer_role: system
|
||||
11
eval/r1/yaml/logic.yaml
Normal file
11
eval/r1/yaml/logic.yaml
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
model: deepseek/deepseek-r1
|
||||
category: logic
|
||||
datasets:
|
||||
- propositional_logic
|
||||
- self_reference
|
||||
- syllogism
|
||||
- zebra_puzzles
|
||||
eval_dir: eval/r1
|
||||
dataset_size: 50
|
||||
dataset_seed: 42
|
||||
developer_role: system
|
||||
Loading…
Add table
Add a link
Reference in a new issue