[eval-basic] run precommit formatting

rishabhranawat 2025-02-09 22:40:45 -08:00
parent 75cfd31ec2
commit c214724a46
10 changed files with 53 additions and 67 deletions
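The hunks below are consistent with an automated formatting pass: isort regroups imports, a black-style formatter normalizes string quotes and collapses wrapped calls, and an end-of-file fixer appends a trailing newline (which is why several hunks mark a line changed without any visible text difference). As a hedged sketch, such a pass is typically invoked as follows; the exact hook set wired into this repository is an assumption, though the tools themselves are pinned in the requirements diff at the end of this commit:

    # Standard pre-commit CLI; runs every configured hook over the whole
    # tree rather than only staged files. The specific hooks (isort,
    # black-style formatting, end-of-file-fixer) are inferred from the
    # diffs, not shown in this commit.
    pip install pre-commit
    pre-commit run --all-files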

View file

@@ -18,4 +18,4 @@
     "size": 10,
     "seed": 42
   }
-]
+]

View file

@@ -1,18 +1,17 @@
 import argparse
-from datetime import datetime
 import json
 import os
-from openai import OpenAI
+from datetime import datetime
 from typing import Any, Dict, List
+
+from openai import OpenAI
+
 from reasoning_gym.factory import DATASETS, create_dataset
 
 
 class OpenRouterEvaluator:
     def __init__(self, model: str):
-        self.client = OpenAI(
-            base_url="https://openrouter.ai/api/v1",
-            api_key=os.getenv('OPENROUTER_API_KEY')
-        )
+        self.client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
         self.model = model
         self.extra_headers = {}
@@ -20,12 +19,7 @@ class OpenRouterEvaluator:
         """Get response from the model via OpenRouter API."""
         try:
             completion = self.client.chat.completions.create(
-                extra_headers=self.extra_headers,
-                model=self.model,
-                messages=[{
-                    "role": "user",
-                    "content": prompt
-                }]
+                extra_headers=self.extra_headers, model=self.model, messages=[{"role": "user", "content": prompt}]
             )
             return completion.choices[0].message.content
         except Exception as e:
@@ -35,27 +29,27 @@ class OpenRouterEvaluator:
     def evaluate_datasets(self, dataset_configs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Evaluate model on multiple datasets with their respective configurations."""
         all_results = []
         for dataset_config in dataset_configs:
-            dataset_name = dataset_config.pop('name')
+            dataset_name = dataset_config.pop("name")
             print(f"\nEvaluating dataset: {dataset_name}")
             try:
                 # Create dataset with its specific configuration
                 data = create_dataset(dataset_name, **dataset_config)
                 results = []
                 for entry in data:
                     try:
-                        response = self.get_model_response(entry['question'])
+                        response = self.get_model_response(entry["question"])
                         score = data.score_answer(answer=response, entry=entry)
                         result = {
-                            'question': entry['question'],
-                            'expected_answer': entry['answer'],
-                            'model_answer': response,
-                            'score': score,
-                            'metadata': entry['metadata']
+                            "question": entry["question"],
+                            "expected_answer": entry["answer"],
+                            "model_answer": response,
+                            "score": score,
+                            "metadata": entry["metadata"],
                         }
                         results.append(result)
                         print(f"Processed question {len(results)}/{len(data)}. Score: {score}")
@@ -65,21 +59,18 @@ class OpenRouterEvaluator:
                         print(f"Error: {str(e)}")
                 # Calculate aggregate metrics
-                total_score = sum(r['score'] for r in results)
+                total_score = sum(r["score"] for r in results)
                 metrics = {
-                    'dataset_name': dataset_name,
-                    'model': self.model,
-                    'size': len(data),
-                    'average_score': total_score / len(results) if results else 0,
-                    'total_examples': len(results),
-                    'timestamp': datetime.now().isoformat(),
-                    'config': dataset_config
+                    "dataset_name": dataset_name,
+                    "model": self.model,
+                    "size": len(data),
+                    "average_score": total_score / len(results) if results else 0,
+                    "total_examples": len(results),
+                    "timestamp": datetime.now().isoformat(),
+                    "config": dataset_config,
                 }
-                all_results.append({
-                    'metrics': metrics,
-                    'results': results
-                })
+                all_results.append({"metrics": metrics, "results": results})
             except Exception as e:
                 print(f"Error evaluating dataset {dataset_name}: {str(e)}")
@@ -89,13 +80,10 @@ class OpenRouterEvaluator:
 def main():
-    parser = argparse.ArgumentParser(
-        description='Evaluate models on reasoning datasets')
-    parser.add_argument('--model', required=True, help='Model to evaluate')
-    parser.add_argument('--config', required=True,
-                        help='Path to JSON configuration file')
-    parser.add_argument('--output-dir', default='results',
-                        help='Output directory')
+    parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets")
+    parser.add_argument("--model", required=True, help="Model to evaluate")
+    parser.add_argument("--config", required=True, help="Path to JSON configuration file")
+    parser.add_argument("--output-dir", default="results", help="Output directory")
     args = parser.parse_args()
@@ -103,7 +91,7 @@ def main():
     os.makedirs(args.output_dir, exist_ok=True)
     # Load dataset configurations
-    with open(args.config, 'r') as f:
+    with open(args.config, "r") as f:
        dataset_configs = json.load(f)
     evaluator = OpenRouterEvaluator(model=args.model)
@@ -111,35 +99,33 @@ def main():
     # Save results
     output_file = os.path.join(
-        args.output_dir,
-        f"evaluation_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+        args.output_dir, f"evaluation_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
     )
     # Save detailed results
-    with open(output_file, 'w') as f:
+    with open(output_file, "w") as f:
         json.dump(all_results, f, indent=2)
     # Create summary
     summary = []
     for result in all_results:
-        metrics = result['metrics']
+        metrics = result["metrics"]
         summary_entry = {
-            'dataset_name': metrics['dataset_name'],
-            'model': metrics['model'],
-            'average_score': metrics['average_score'],
-            'total_examples': metrics['total_examples'],
-            'timestamp': metrics['timestamp'],
-            'config': metrics['config']
+            "dataset_name": metrics["dataset_name"],
+            "model": metrics["model"],
+            "average_score": metrics["average_score"],
+            "total_examples": metrics["total_examples"],
+            "timestamp": metrics["timestamp"],
+            "config": metrics["config"],
         }
         summary.append(summary_entry)
     # Save summary to a separate file
     summary_file = os.path.join(
-        args.output_dir,
-        f"summary_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+        args.output_dir, f"summary_{args.model.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
     )
-    with open(summary_file, 'w') as f:
+    with open(summary_file, "w") as f:
         json.dump(summary, f, indent=2)
     # Print summary
@@ -148,10 +134,10 @@ def main():
         print(f"\nDataset: {entry['dataset_name']}")
         print(f"Average Score: {entry['average_score']:.2%}")
         print(f"Total Examples: {entry['total_examples']}")
     print(f"\nDetailed results saved to: {output_file}")
     print(f"Summary saved to: {summary_file}")
 if __name__ == "__main__":
-    main()
+    main()
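For reference, a usage sketch of the CLI defined above. The script and config filenames, the dataset name, and the model ID are illustrative assumptions; the flags match the argparse definition, "name" is popped by evaluate_datasets, and "size"/"seed" mirror the config diff earlier in this commit:

    # The evaluator reads the API key from the environment (os.getenv above).
    export OPENROUTER_API_KEY=sk-or-...

    # A config file is a JSON list of dataset entries; the dataset name
    # below is hypothetical.
    printf '%s\n' '[{"name": "example_dataset", "size": 10, "seed": 42}]' > datasets.json

    python eval.py --model openai/gpt-4o --config datasets.json --output-dir results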

View file

@@ -27,4 +27,4 @@ for model in "${MODELS[@]}"; do
         --output-dir "$OUTPUT_DIR"
 done
-echo "All evaluations completed!"
+echo "All evaluations completed!"

View file

@@ -644,4 +644,4 @@
       }
     ]
   }
-]
+]

View file

@@ -9,4 +9,4 @@
     "timestamp": "2025-02-10T06:06:08.539389"
   },
   "results": []
-}
+}

View file

@@ -183,4 +183,4 @@
       }
     }
   ]
-}
+}

View file

@@ -9,4 +9,4 @@
     "timestamp": "2025-02-10T06:06:10.638347"
   },
   "results": []
-}
+}

View file

@@ -36,4 +36,4 @@
       "seed": 42
     }
   }
-]
+]

View file

@@ -5,4 +5,4 @@ isort>=5.13.2
 flake8>=7.1.1
 mypy>=1.14.1
 pre-commit>=4.1.0
-openai>=1.61.1
+openai>=1.61.1