mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-27 17:23:19 +00:00
[eval-v1] pre commit formatting
This commit is contained in:
parent
df5438498e
commit
9e4870125d
6 changed files with 32 additions and 47 deletions
69
eval/eval.py
69
eval/eval.py
|
|
@@ -1,21 +1,20 @@
|
|||
import time
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from reasoning_gym.factory import create_dataset
|
||||
|
||||
|
||||
class AsyncOpenRouterEvaluator:
|
||||
def __init__(self, model: str, max_concurrent: int = 10):
|
||||
self.client = AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY")
|
||||
)
|
||||
self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
|
||||
self.model = model
|
||||
self.extra_headers = {}
|
||||
self.max_concurrent = max_concurrent
|
||||
|
|
@@ -26,9 +25,7 @@ class AsyncOpenRouterEvaluator:
|
|||
async with self.semaphore:
|
||||
try:
|
||||
completion = await self.client.chat.completions.create(
|
||||
extra_headers=self.extra_headers,
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
extra_headers=self.extra_headers, model=self.model, messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
except Exception as e:
|
||||
|
|
@@ -39,7 +36,7 @@ class AsyncOpenRouterEvaluator:
|
|||
"""Process a single question and return the result."""
|
||||
response = await self.get_model_response(entry["question"])
|
||||
score = dataset.score_answer(answer=response, entry=entry)
|
||||
|
||||
|
||||
return {
|
||||
"question": entry["question"],
|
||||
"expected_answer": entry["answer"],
|
||||
|
|
@@ -52,24 +49,18 @@ class AsyncOpenRouterEvaluator:
|
|||
"""Evaluate a single dataset with concurrent question processing."""
|
||||
dataset_name = dataset_config.pop("name")
|
||||
print(f"\nEvaluating dataset: {dataset_name}")
|
||||
|
||||
|
||||
try:
|
||||
# Create dataset with its specific configuration
|
||||
data = create_dataset(dataset_name, **dataset_config)
|
||||
all_entries = list(data)
|
||||
|
||||
|
||||
# Process all questions concurrently
|
||||
tasks = [
|
||||
self.process_single_question(entry, data)
|
||||
for entry in all_entries
|
||||
]
|
||||
|
||||
tasks = [self.process_single_question(entry, data) for entry in all_entries]
|
||||
|
||||
# Use tqdm to track progress
|
||||
results = await tqdm_asyncio.gather(
|
||||
*tasks,
|
||||
desc=f"Processing {dataset_name}"
|
||||
)
|
||||
|
||||
results = await tqdm_asyncio.gather(*tasks, desc=f"Processing {dataset_name}")
|
||||
|
||||
# Calculate aggregate metrics
|
||||
total_score = sum(r["score"] for r in results)
|
||||
metrics = {
|
||||
|
|
@@ -81,31 +72,28 @@ class AsyncOpenRouterEvaluator:
|
|||
"timestamp": datetime.now().isoformat(),
|
||||
"config": dataset_config,
|
||||
}
|
||||
|
||||
|
||||
return {"metrics": metrics, "results": results}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error evaluating dataset {dataset_name}: {str(e)}")
|
||||
return None
|
||||
|
||||
async def evaluate_datasets(self, dataset_configs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Evaluate multiple datasets concurrently."""
|
||||
tasks = [
|
||||
self.evaluate_dataset(config)
|
||||
for config in dataset_configs
|
||||
]
|
||||
|
||||
tasks = [self.evaluate_dataset(config) for config in dataset_configs]
|
||||
|
||||
# Process all datasets concurrently
|
||||
results = await asyncio.gather(*tasks)
|
||||
return [r for r in results if r is not None]
|
||||
|
||||
|
||||
async def main_async():
|
||||
parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets")
|
||||
parser.add_argument("--model", required=True, help="Model to evaluate")
|
||||
parser.add_argument("--config", required=True, help="Path to JSON configuration file")
|
||||
parser.add_argument("--output-dir", default="results", help="Output directory")
|
||||
parser.add_argument("--max-concurrent", type=int, default=10,
|
||||
help="Maximum number of concurrent API calls")
|
||||
parser.add_argument("--max-concurrent", type=int, default=10, help="Maximum number of concurrent API calls")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@@ -116,18 +104,14 @@ async def main_async():
|
|||
with open(args.config, "r") as f:
|
||||
dataset_configs = json.load(f)
|
||||
|
||||
evaluator = AsyncOpenRouterEvaluator(
|
||||
model=args.model,
|
||||
max_concurrent=args.max_concurrent
|
||||
)
|
||||
|
||||
evaluator = AsyncOpenRouterEvaluator(model=args.model, max_concurrent=args.max_concurrent)
|
||||
|
||||
eval_start_time = time.time()
|
||||
all_results = await evaluator.evaluate_datasets(dataset_configs)
|
||||
print(f'Time taken to collect evaluation data: {time.time() - eval_start_time:.2f} seconds')
|
||||
print(f"Time taken to collect evaluation data: {time.time() - eval_start_time:.2f} seconds")
|
||||
# Save results
|
||||
output_file = os.path.join(
|
||||
args.output_dir,
|
||||
f"evaluation_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
args.output_dir, f"evaluation_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
)
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
|
|
@@ -148,8 +132,7 @@ async def main_async():
|
|||
summary.append(summary_entry)
|
||||
|
||||
summary_file = os.path.join(
|
||||
args.output_dir,
|
||||
f"summary_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
args.output_dir, f"summary_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
)
|
||||
|
||||
with open(summary_file, "w") as f:
|
||||
|
|
@@ -165,8 +148,10 @@ async def main_async():
|
|||
print(f"\nDetailed results saved to: {output_file}")
|
||||
print(f"Summary saved to: {summary_file}")
|
||||
|
||||
|
||||
def main():
|
||||
asyncio.run(main_async())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue