From b4207162ffbb4e21b0b1c661aed0a7f755a6c87d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 28 Feb 2025 11:48:37 +0100 Subject: [PATCH] Eval sampling settings for generation (temperature, top-p, max_tokens) (#242) * feat: Add sampling parameters to eval configuration and API call * feat: Add support for system_prompt_id and optional system_prompt configuration --- eval/README.md | 5 ++++ eval/eval.py | 36 ++++++++++++++++++------- eval/eval_config.py | 64 ++++++++++++++++++++++++++++++++++++--------- 3 files changed, 83 insertions(+), 22 deletions(-) diff --git a/eval/README.md b/eval/README.md index 218cd7e5..8b18fec0 100644 --- a/eval/README.md +++ b/eval/README.md @@ -44,6 +44,11 @@ output_dir: "results" max_concurrent: 10 default_size: 20 # Default size for all datasets default_seed: 42 # Default seed for all datasets +max_tokens: 32768 # Maximum generation length (optional) +temperature: 0.6 # Generation temperature (optional) +top_p: 0.95 # Top-p sampling parameter (optional) +system_prompt_id: "default" # Use a predefined system prompt by ID (optional) +# system_prompt: "Your custom system prompt here" # Or specify a custom system prompt directly categories: - category: "algebra" diff --git a/eval/eval.py b/eval/eval.py index e517b96c..8c3f2f96 100755 --- a/eval/eval.py +++ b/eval/eval.py @@ -121,11 +121,19 @@ class AsyncModelEvaluator: params = { "model": self.config.model, "messages": [ - {"role": self.config.system_role, "content": self.config.system_prompt}, + {"role": self.config.system_role, "content": self.config.get_system_prompt()}, {"role": "user", "content": prompt}, ], } + # Add sampling parameters if specified + if self.config.max_tokens is not None: + params["max_tokens"] = self.config.max_tokens + if self.config.temperature is not None: + params["temperature"] = self.config.temperature + if self.config.top_p is not None: + params["top_p"] = self.config.top_p + # Add provider configuration if specified if self.config.provider: params["extra_body"] = {"provider": {"order": [self.config.provider], "allow_fallbacks": False}} @@ -253,7 +261,7 @@ class AsyncModelEvaluator: "average_score": average_score, "total_examples": len(results), "config": {"size": dataset_config.size, "seed": dataset_config.seed, **dataset_config.params}, - "system_prompt": self.config.system_prompt, + "system_prompt": self.config.get_system_prompt(), "results": results, } @@ -265,7 +273,7 @@ class AsyncModelEvaluator: "average_score": 0.0, "total_examples": 0, "config": {"size": dataset_config.size, "seed": dataset_config.seed, **dataset_config.params}, - "system_prompt": self.config.system_prompt, + "system_prompt": self.config.get_system_prompt(), "error": str(e), "results": [], } @@ -310,6 +318,9 @@ class AsyncModelEvaluator: "provider": self.config.provider, "git_hash": self.git_hash, "duration_seconds": (datetime.now() - self.start_time).total_seconds(), + "max_tokens": self.config.max_tokens, + "temperature": self.config.temperature, + "top_p": self.config.top_p, }, "categories": category_results, } @@ -384,13 +395,18 @@ class AsyncModelEvaluator: with open(results_path, "w") as f: json.dump(results, f, indent=2) - # Add timestamp, git hash, model, provider, and duration to summary + # Add timestamp, git hash, model, provider, sampling parameters, and duration to summary summary_data = results["summary"].copy() summary_data["timestamp"] = self.start_time.isoformat() summary_data["git_hash"] = self.git_hash summary_data["model"] = self.config.model summary_data["provider"] = self.config.provider - summary_data["system_prompt"] = self.config.system_prompt + summary_data["system_prompt"] = self.config.get_system_prompt() + if self.config.system_prompt_id: + summary_data["system_prompt_id"] = self.config.system_prompt_id + summary_data["max_tokens"] = self.config.max_tokens + summary_data["temperature"] = self.config.temperature + summary_data["top_p"] = self.config.top_p summary_data["duration_seconds"] = results["metadata"]["duration_seconds"] # Save summary @@ -422,11 +438,11 @@ class AsyncModelEvaluator: print("------------------") print(f"Model: {self.config.model}") print(f"Provider: {self.config.provider}") - print( - f"System Prompt: {self.config.system_prompt[:50]}..." - if len(self.config.system_prompt) > 50 - else self.config.system_prompt - ) + system_prompt = self.config.get_system_prompt() + print(f"System Prompt: {system_prompt[:50]}..." if len(system_prompt) > 50 else system_prompt) + print(f"Max Tokens: {self.config.max_tokens}") + print(f"Temperature: {self.config.temperature}") + print(f"Top-p: {self.config.top_p}") print(f"Git Hash: {self.git_hash}") print(f"Duration: {results['metadata']['duration_seconds']:.2f} seconds") print() diff --git a/eval/eval_config.py b/eval/eval_config.py index 6099069c..db0fe519 100644 --- a/eval/eval_config.py +++ b/eval/eval_config.py @@ -1,6 +1,7 @@ """Configuration classes for the evaluation script""" import json +import logging import re from dataclasses import dataclass, field from typing import Any, Optional @@ -43,17 +44,51 @@ class CategoryConfig: class EvalConfig: """Global evaluation configuration""" - model: str - provider: Optional[str] = None - system_prompt: str = SYSTEM_PROMPTS["default"] - system_role: str = "system" - output_dir: str = "results" - max_concurrent: int = 10 - default_size: int = 500 - default_seed: Optional[int] = None - save_metadata: bool = False - save_full_results: bool = False - categories: list[CategoryConfig] = field(default_factory=list) + model: str # Model identifier (e.g., "meta-llama/llama-3.3-70b-instruct") + provider: Optional[str] = None # Provider name for OpenRouter (e.g., "Anthropic", "OpenAI") + system_prompt: Optional[str] = None # Custom system prompt text (overrides system_prompt_id) + system_prompt_id: Optional[str] = None # ID of predefined system prompt from SYSTEM_PROMPTS + system_role: str = "system" # Role for the system message (usually "system") + output_dir: str = "results" # Directory to save evaluation results + max_concurrent: int = 10 # Maximum number of concurrent API calls + default_size: int = 500 # Default dataset size if not specified for a dataset + default_seed: Optional[int] = None # Default random seed if not specified for a dataset + save_metadata: bool = False # Whether to include dataset entry metadata in results + save_full_results: bool = False # Whether to save the full results file + # Sampling parameters + max_tokens: Optional[int] = 32768 # Maximum number of tokens to generate + temperature: Optional[float] = 0.6 # Sampling temperature (higher = more random) + top_p: Optional[float] = 0.95 # Nucleus sampling parameter (lower = more deterministic) + + categories: list[CategoryConfig] = field(default_factory=list) # List of category configurations + + def get_system_prompt(self) -> str: + """Get the system prompt to use for evaluation. + + Returns: + The system prompt string to use + """ + if self.system_prompt is not None and self.system_prompt_id is not None: + logging.warning( + "Both system_prompt and system_prompt_id are specified in the configuration. " + "Using system_prompt and ignoring system_prompt_id." + ) + return self.system_prompt + + if self.system_prompt is not None: + return self.system_prompt + + if self.system_prompt_id is not None: + if self.system_prompt_id in SYSTEM_PROMPTS: + return SYSTEM_PROMPTS[self.system_prompt_id] + else: + logging.warning( + f"System prompt ID '{self.system_prompt_id}' not found in SYSTEM_PROMPTS. " + f"Using default system prompt instead." + ) + + # Default case: use the default system prompt + return SYSTEM_PROMPTS["default"] @classmethod def from_json(cls, json_path: str) -> "EvalConfig": @@ -129,11 +164,16 @@ class EvalConfig: return cls( model=config_data.get("model"), provider=config_data.get("provider", "openai"), - system_prompt=config_data.get("system_prompt", SYSTEM_PROMPTS["default"]), + system_prompt=config_data.get("system_prompt"), + system_prompt_id=config_data.get("system_prompt_id"), system_role=config_data.get("system_role", "system"), output_dir=config_data.get("output_dir", "results"), max_concurrent=config_data.get("max_concurrent", 10), save_metadata=config_data.get("save_metadata", False), save_full_results=config_data.get("save_full_results", False), + # Sampling parameters + max_tokens=config_data.get("max_tokens", 32768), + temperature=config_data.get("temperature", 0.6), + top_p=config_data.get("top_p", 0.95), categories=categories, )