diff --git a/eval/eval.py b/eval/eval.py index ce117d4d..d2e24555 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -1,21 +1,20 @@ -import time import argparse import asyncio import json import os +import time from datetime import datetime from typing import Any, Dict, List -from tqdm.asyncio import tqdm_asyncio + from openai import AsyncOpenAI +from tqdm.asyncio import tqdm_asyncio from reasoning_gym.factory import create_dataset + class AsyncOpenRouterEvaluator: def __init__(self, model: str, max_concurrent: int = 10): - self.client = AsyncOpenAI( - base_url="https://openrouter.ai/api/v1", - api_key=os.getenv("OPENROUTER_API_KEY") - ) + self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY")) self.model = model self.extra_headers = {} self.max_concurrent = max_concurrent @@ -26,9 +25,7 @@ class AsyncOpenRouterEvaluator: async with self.semaphore: try: completion = await self.client.chat.completions.create( - extra_headers=self.extra_headers, - model=self.model, - messages=[{"role": "user", "content": prompt}] + extra_headers=self.extra_headers, model=self.model, messages=[{"role": "user", "content": prompt}] ) return completion.choices[0].message.content except Exception as e: @@ -39,7 +36,7 @@ class AsyncOpenRouterEvaluator: """Process a single question and return the result.""" response = await self.get_model_response(entry["question"]) score = dataset.score_answer(answer=response, entry=entry) - + return { "question": entry["question"], "expected_answer": entry["answer"], @@ -52,24 +49,18 @@ class AsyncOpenRouterEvaluator: """Evaluate a single dataset with concurrent question processing.""" dataset_name = dataset_config.pop("name") print(f"\nEvaluating dataset: {dataset_name}") - + try: # Create dataset with its specific configuration data = create_dataset(dataset_name, **dataset_config) all_entries = list(data) - + # Process all questions concurrently - tasks = [ - self.process_single_question(entry, data) - for entry in all_entries - ] - + tasks = [self.process_single_question(entry, data) for entry in all_entries] + # Use tqdm to track progress - results = await tqdm_asyncio.gather( - *tasks, - desc=f"Processing {dataset_name}" - ) - + results = await tqdm_asyncio.gather(*tasks, desc=f"Processing {dataset_name}") + # Calculate aggregate metrics total_score = sum(r["score"] for r in results) metrics = { @@ -81,31 +72,28 @@ class AsyncOpenRouterEvaluator: "timestamp": datetime.now().isoformat(), "config": dataset_config, } - + return {"metrics": metrics, "results": results} - + except Exception as e: print(f"Error evaluating dataset {dataset_name}: {str(e)}") return None async def evaluate_datasets(self, dataset_configs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Evaluate multiple datasets concurrently.""" - tasks = [ - self.evaluate_dataset(config) - for config in dataset_configs - ] - + tasks = [self.evaluate_dataset(config) for config in dataset_configs] + # Process all datasets concurrently results = await asyncio.gather(*tasks) return [r for r in results if r is not None] + async def main_async(): parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets") parser.add_argument("--model", required=True, help="Model to evaluate") parser.add_argument("--config", required=True, help="Path to JSON configuration file") parser.add_argument("--output-dir", default="results", help="Output directory") - parser.add_argument("--max-concurrent", type=int, default=10, - help="Maximum number of concurrent API calls") + parser.add_argument("--max-concurrent", type=int, default=10, help="Maximum number of concurrent API calls") args = parser.parse_args() @@ -116,18 +104,14 @@ async def main_async(): with open(args.config, "r") as f: dataset_configs = json.load(f) - evaluator = AsyncOpenRouterEvaluator( - model=args.model, - max_concurrent=args.max_concurrent - ) - + evaluator = AsyncOpenRouterEvaluator(model=args.model, max_concurrent=args.max_concurrent) + eval_start_time = time.time() all_results = await evaluator.evaluate_datasets(dataset_configs) - print(f'Time taken to collect evaluation data: {time.time() - eval_start_time:.2f} seconds') + print(f"Time taken to collect evaluation data: {time.time() - eval_start_time:.2f} seconds") # Save results output_file = os.path.join( - args.output_dir, - f"evaluation_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + args.output_dir, f"evaluation_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" ) with open(output_file, "w") as f: @@ -148,8 +132,7 @@ async def main_async(): summary.append(summary_entry) summary_file = os.path.join( - args.output_dir, - f"summary_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + args.output_dir, f"summary_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" ) with open(summary_file, "w") as f: @@ -165,8 +148,10 @@ async def main_async(): print(f"\nDetailed results saved to: {output_file}") print(f"Summary saved to: {summary_file}") + def main(): asyncio.run(main_async()) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_212645.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_212645.json index 67a304d3..d9d441f9 100644 --- a/eval/results/summary_google_gemini-2.0-flash-001_20250210_212645.json +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_212645.json @@ -36,4 +36,4 @@ "seed": 42 } } -] \ No newline at end of file +] diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_212924.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_212924.json index 274f3f8d..ab447f29 100644 --- a/eval/results/summary_google_gemini-2.0-flash-001_20250210_212924.json +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_212924.json @@ -36,4 +36,4 @@ "seed": 42 } } -] \ No newline at end of file +] diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_213352.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_213352.json index d85fb0be..1d067324 100644 --- a/eval/results/summary_google_gemini-2.0-flash-001_20250210_213352.json +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_213352.json @@ -36,4 +36,4 @@ "seed": 42 } } -] \ No newline at end of file +] diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_213418.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_213418.json index 8febdc63..611ecf28 100644 --- a/eval/results/summary_google_gemini-2.0-flash-001_20250210_213418.json +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_213418.json @@ -36,4 +36,4 @@ "seed": 42 } } -] \ No newline at end of file +] diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_214631.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_214631.json index 587afe36..55268914 100644 --- a/eval/results/summary_google_gemini-2.0-flash-001_20250210_214631.json +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_214631.json @@ -36,4 +36,4 @@ "seed": 42 } } -] \ No newline at end of file +]