From 247464a47d7db5405c587cec10edf9bad0b563a8 Mon Sep 17 00:00:00 2001 From: rishabhranawat Date: Mon, 10 Feb 2025 21:35:46 -0800 Subject: [PATCH] [eval-v1] async to speed up inference/evaluation --- eval/eval.py | 181 ++++++++++-------- ..._gemini-2.0-flash-001_20250210_212645.json | 39 ++++ ..._gemini-2.0-flash-001_20250210_212924.json | 39 ++++ ..._gemini-2.0-flash-001_20250210_213352.json | 39 ++++ ..._gemini-2.0-flash-001_20250210_213418.json | 39 ++++ 5 files changed, 261 insertions(+), 76 deletions(-) create mode 100644 eval/results/summary_google_gemini-2.0-flash-001_20250210_212645.json create mode 100644 eval/results/summary_google_gemini-2.0-flash-001_20250210_212924.json create mode 100644 eval/results/summary_google_gemini-2.0-flash-001_20250210_213352.json create mode 100644 eval/results/summary_google_gemini-2.0-flash-001_20250210_213418.json diff --git a/eval/eval.py b/eval/eval.py index f8952e10..e31bf25f 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -1,89 +1,111 @@ +import time import argparse +import asyncio import json import os from datetime import datetime from typing import Any, Dict, List +from tqdm.asyncio import tqdm_asyncio +from openai import AsyncOpenAI -from openai import OpenAI +from reasoning_gym.factory import create_dataset -from reasoning_gym.factory import DATASETS, create_dataset - - -class OpenRouterEvaluator: - def __init__(self, model: str): - self.client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY")) +class AsyncOpenRouterEvaluator: + def __init__(self, model: str, max_concurrent: int = 10): + self.client = AsyncOpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY") + ) self.model = model self.extra_headers = {} + self.max_concurrent = max_concurrent + self.semaphore = asyncio.Semaphore(max_concurrent) - def get_model_response(self, prompt: str) -> str: - """Get response from the model via OpenRouter API.""" - try: - completion = self.client.chat.completions.create( - extra_headers=self.extra_headers, model=self.model, messages=[{"role": "user", "content": prompt}] - ) - return completion.choices[0].message.content - except Exception as e: - print(f"Error calling OpenRouter API: {str(e)}") - raise - - def evaluate_datasets(self, dataset_configs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Evaluate model on multiple datasets with their respective configurations.""" - all_results = [] - - for dataset_config in dataset_configs: - dataset_name = dataset_config.pop("name") - print(f"\nEvaluating dataset: {dataset_name}") - + async def get_model_response(self, prompt: str) -> str: + """Get response from the model via OpenRouter API with rate limiting.""" + async with self.semaphore: try: - # Create dataset with its specific configuration - data = create_dataset(dataset_name, **dataset_config) - results = [] - - for entry in data: - try: - response = self.get_model_response(entry["question"]) - score = data.score_answer(answer=response, entry=entry) - - result = { - "question": entry["question"], - "expected_answer": entry["answer"], - "model_answer": response, - "score": score, - "metadata": entry["metadata"], - } - results.append(result) - print(f"Processed question {len(results)}/{len(data)}. Score: {score}") - - except Exception as e: - print(f"Error processing question: {entry['question']}") - print(f"Error: {str(e)}") - - # Calculate aggregate metrics - total_score = sum(r["score"] for r in results) - metrics = { - "dataset_name": dataset_name, - "model": self.model, - "size": len(data), - "average_score": total_score / len(results) if results else 0, - "total_examples": len(results), - "timestamp": datetime.now().isoformat(), - "config": dataset_config, - } - - all_results.append({"metrics": metrics, "results": results}) - + completion = await self.client.chat.completions.create( + extra_headers=self.extra_headers, + model=self.model, + messages=[{"role": "user", "content": prompt}] + ) + return completion.choices[0].message.content except Exception as e: - print(f"Error evaluating dataset {dataset_name}: {str(e)}") - continue + print(f"Error calling OpenRouter API: {str(e)}") + raise - return all_results + async def process_single_question(self, entry: Dict, dataset) -> Dict: + """Process a single question and return the result.""" + response = await self.get_model_response(entry["question"]) + score = dataset.score_answer(answer=response, entry=entry) + + return { + "question": entry["question"], + "expected_answer": entry["answer"], + "model_answer": response, + "score": score, + "metadata": entry["metadata"], + } + async def evaluate_dataset(self, dataset_config: Dict[str, Any]) -> Dict[str, Any]: + """Evaluate a single dataset with concurrent question processing.""" + dataset_name = dataset_config.pop("name") + print(f"\nEvaluating dataset: {dataset_name}") + + try: + # Create dataset with its specific configuration + data = create_dataset(dataset_name, **dataset_config) + all_entries = list(data) + + # Process all questions concurrently + tasks = [ + self.process_single_question(entry, data) + for entry in all_entries + ] + + # Use tqdm to track progress + results = await tqdm_asyncio.gather( + *tasks, + desc=f"Processing {dataset_name}" + ) + + # Calculate aggregate metrics + total_score = sum(r["score"] for r in results) + metrics = { + "dataset_name": dataset_name, + "model": self.model, + "size": len(data), + "average_score": total_score / len(results) if results else 0, + "total_examples": len(results), + "timestamp": datetime.now().isoformat(), + "config": dataset_config, + } + + return {"metrics": metrics, "results": results} + + except Exception as e: + print(f"Error evaluating dataset {dataset_name}: {str(e)}") + return None -def main(): + async def evaluate_datasets(self, dataset_configs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Evaluate multiple datasets concurrently.""" + tasks = [ + self.evaluate_dataset(config) + for config in dataset_configs + ] + + # Process all datasets concurrently + results = await asyncio.gather(*tasks) + return [r for r in results if r is not None] + +async def main_async(): parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets") parser.add_argument("--model", required=True, help="Model to evaluate") parser.add_argument("--config", required=True, help="Path to JSON configuration file") parser.add_argument("--output-dir", default="results", help="Output directory") + parser.add_argument("--max-concurrent", type=int, default=10, + help="Maximum number of concurrent API calls") args = parser.parse_args() @@ -94,19 +116,24 @@ def main(): with open(args.config, "r") as f: dataset_configs = json.load(f) - evaluator = OpenRouterEvaluator(model=args.model) - all_results = evaluator.evaluate_datasets(dataset_configs) - + evaluator = AsyncOpenRouterEvaluator( + model=args.model, + max_concurrent=args.max_concurrent + ) + + eval_start_time = time.time() + all_results = await evaluator.evaluate_datasets(dataset_configs) + print(f'Time taken to collect evaluation data: {time.time() - eval_start_time}') # Save results output_file = os.path.join( - args.output_dir, f"evaluation_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + args.output_dir, + f"evaluation_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" ) - # Save detailed results with open(output_file, "w") as f: json.dump(all_results, f, indent=2) - # Create summary + # Create and save summary summary = [] for result in all_results: metrics = result["metrics"] @@ -120,9 +147,9 @@ def main(): } summary.append(summary_entry) - # Save summary to a separate file summary_file = os.path.join( - args.output_dir, f"summary_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + args.output_dir, + f"summary_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" ) with open(summary_file, "w") as f: @@ -138,6 +165,8 @@ def main(): print(f"\nDetailed results saved to: {output_file}") print(f"Summary saved to: {summary_file}") +def main(): + asyncio.run(main_async()) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_212645.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_212645.json new file mode 100644 index 00000000..67a304d3 --- /dev/null +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_212645.json @@ -0,0 +1,39 @@ +[ + { + "dataset_name": "letter_counting", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.10800000000000001, + "total_examples": 10, + "timestamp": "2025-02-10T21:26:40.575060", + "config": { + "min_words": 5, + "max_words": 15, + "size": 10, + "seed": 42 + } + }, + { + "dataset_name": "propositional_logic", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.059, + "total_examples": 10, + "timestamp": "2025-02-10T21:26:44.955201", + "config": { + "size": 10, + "seed": 42 + } + }, + { + "dataset_name": "leg_counting", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.40199999999999997, + "total_examples": 10, + "timestamp": "2025-02-10T21:26:45.852518", + "config": { + "min_animals": 3, + "max_animals": 8, + "size": 10, + "seed": 42 + } + } +] \ No newline at end of file diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_212924.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_212924.json new file mode 100644 index 00000000..274f3f8d --- /dev/null +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_212924.json @@ -0,0 +1,39 @@ +[ + { + "dataset_name": "letter_counting", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.157, + "total_examples": 10, + "timestamp": "2025-02-10T21:29:18.766288", + "config": { + "min_words": 5, + "max_words": 15, + "size": 10, + "seed": 42 + } + }, + { + "dataset_name": "propositional_logic", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.059, + "total_examples": 10, + "timestamp": "2025-02-10T21:29:24.026918", + "config": { + "size": 10, + "seed": 42 + } + }, + { + "dataset_name": "leg_counting", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.40199999999999997, + "total_examples": 10, + "timestamp": "2025-02-10T21:29:23.650182", + "config": { + "min_animals": 3, + "max_animals": 8, + "size": 10, + "seed": 42 + } + } +] \ No newline at end of file diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_213352.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_213352.json new file mode 100644 index 00000000..d85fb0be --- /dev/null +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_213352.json @@ -0,0 +1,39 @@ +[ + { + "dataset_name": "letter_counting", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.157, + "total_examples": 10, + "timestamp": "2025-02-10T21:33:46.747429", + "config": { + "min_words": 5, + "max_words": 15, + "size": 10, + "seed": 42 + } + }, + { + "dataset_name": "propositional_logic", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.059, + "total_examples": 10, + "timestamp": "2025-02-10T21:33:51.422633", + "config": { + "size": 10, + "seed": 42 + } + }, + { + "dataset_name": "leg_counting", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.40199999999999997, + "total_examples": 10, + "timestamp": "2025-02-10T21:33:52.022623", + "config": { + "min_animals": 3, + "max_animals": 8, + "size": 10, + "seed": 42 + } + } +] \ No newline at end of file diff --git a/eval/results/summary_google_gemini-2.0-flash-001_20250210_213418.json b/eval/results/summary_google_gemini-2.0-flash-001_20250210_213418.json new file mode 100644 index 00000000..8febdc63 --- /dev/null +++ b/eval/results/summary_google_gemini-2.0-flash-001_20250210_213418.json @@ -0,0 +1,39 @@ +[ + { + "dataset_name": "letter_counting", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.157, + "total_examples": 10, + "timestamp": "2025-02-10T21:34:13.347168", + "config": { + "min_words": 5, + "max_words": 15, + "size": 10, + "seed": 42 + } + }, + { + "dataset_name": "propositional_logic", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.10800000000000001, + "total_examples": 10, + "timestamp": "2025-02-10T21:34:18.146056", + "config": { + "size": 10, + "seed": 42 + } + }, + { + "dataset_name": "leg_counting", + "model": "google/gemini-2.0-flash-001", + "average_score": 0.40199999999999997, + "total_examples": 10, + "timestamp": "2025-02-10T21:34:18.315364", + "config": { + "min_animals": 3, + "max_animals": 8, + "size": 10, + "seed": 42 + } + } +] \ No newline at end of file