Eval sampling settings for generation (temperature, top-p, max_tokens) (#242)

* feat: Add sampling parameters to eval configuration and API call
* feat: Add support for system_prompt_id and optional system_prompt configuration
Andreas Köpf 2025-02-28 11:48:37 +01:00 committed by GitHub
parent b1c8840129
commit b4207162ff
3 changed files with 83 additions and 22 deletions
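
For reference, a configuration file exercising the new options might look like the sketch below. This is not part of the commit; the categories schema comes from CategoryConfig and is left empty, and the file name is made up:

import json

# Hypothetical eval config using the new sampling and system-prompt fields.
config = {
    "model": "meta-llama/llama-3.3-70b-instruct",
    "provider": "openai",
    "system_prompt_id": "default",   # or "system_prompt": "You are a careful solver."
    "output_dir": "results",
    "max_concurrent": 10,
    # Sampling parameters introduced by this change
    "max_tokens": 32768,
    "temperature": 0.6,
    "top_p": 0.95,
    "categories": [],                # CategoryConfig entries omitted for brevity
}

with open("eval_config.json", "w") as f:
    json.dump(config, f, indent=2)

# cfg = EvalConfig.from_json("eval_config.json")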


@@ -1,6 +1,7 @@
 """Configuration classes for the evaluation script"""

 import json
+import logging
 import re
 from dataclasses import dataclass, field
 from typing import Any, Optional
@@ -43,17 +44,51 @@ class CategoryConfig:
 class EvalConfig:
     """Global evaluation configuration"""
-    model: str
-    provider: Optional[str] = None
-    system_prompt: str = SYSTEM_PROMPTS["default"]
-    system_role: str = "system"
-    output_dir: str = "results"
-    max_concurrent: int = 10
-    default_size: int = 500
-    default_seed: Optional[int] = None
-    save_metadata: bool = False
-    save_full_results: bool = False
-    categories: list[CategoryConfig] = field(default_factory=list)
+    model: str  # Model identifier (e.g., "meta-llama/llama-3.3-70b-instruct")
+    provider: Optional[str] = None  # Provider name for OpenRouter (e.g., "Anthropic", "OpenAI")
+    system_prompt: Optional[str] = None  # Custom system prompt text (overrides system_prompt_id)
+    system_prompt_id: Optional[str] = None  # ID of predefined system prompt from SYSTEM_PROMPTS
+    system_role: str = "system"  # Role for the system message (usually "system")
+    output_dir: str = "results"  # Directory to save evaluation results
+    max_concurrent: int = 10  # Maximum number of concurrent API calls
+    default_size: int = 500  # Default dataset size if not specified for a dataset
+    default_seed: Optional[int] = None  # Default random seed if not specified for a dataset
+    save_metadata: bool = False  # Whether to include dataset entry metadata in results
+    save_full_results: bool = False  # Whether to save the full results file
+
+    # Sampling parameters
+    max_tokens: Optional[int] = 32768  # Maximum number of tokens to generate
+    temperature: Optional[float] = 0.6  # Sampling temperature (higher = more random)
+    top_p: Optional[float] = 0.95  # Nucleus sampling parameter (lower = more deterministic)
+
+    categories: list[CategoryConfig] = field(default_factory=list)  # List of category configurations
+
+    def get_system_prompt(self) -> str:
+        """Get the system prompt to use for evaluation.
+
+        Returns:
+            The system prompt string to use
+        """
+        if self.system_prompt is not None and self.system_prompt_id is not None:
+            logging.warning(
+                "Both system_prompt and system_prompt_id are specified in the configuration. "
+                "Using system_prompt and ignoring system_prompt_id."
+            )
+            return self.system_prompt
+        if self.system_prompt is not None:
+            return self.system_prompt
+        if self.system_prompt_id is not None:
+            if self.system_prompt_id in SYSTEM_PROMPTS:
+                return SYSTEM_PROMPTS[self.system_prompt_id]
+            else:
+                logging.warning(
+                    f"System prompt ID '{self.system_prompt_id}' not found in SYSTEM_PROMPTS. "
+                    f"Using default system prompt instead."
+                )
+        # Default case: use the default system prompt
+        return SYSTEM_PROMPTS["default"]

     @classmethod
     def from_json(cls, json_path: str) -> "EvalConfig":
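
Not part of the diff: a short sketch of how the precedence in get_system_prompt() plays out. The model name and prompt strings are placeholders, and only the "default" key of SYSTEM_PROMPTS is assumed to exist:

cfg = EvalConfig(
    model="meta-llama/llama-3.3-70b-instruct",
    system_prompt="Answer concisely.",
    system_prompt_id="default",
)
cfg.get_system_prompt()  # -> "Answer concisely." (explicit text wins; a warning is logged)

cfg = EvalConfig(model="meta-llama/llama-3.3-70b-instruct", system_prompt_id="missing-id")
cfg.get_system_prompt()  # -> SYSTEM_PROMPTS["default"] (unknown ID falls back; a warning is logged)
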
@@ -129,11 +164,16 @@ class EvalConfig:
         return cls(
             model=config_data.get("model"),
             provider=config_data.get("provider", "openai"),
-            system_prompt=config_data.get("system_prompt", SYSTEM_PROMPTS["default"]),
+            system_prompt=config_data.get("system_prompt"),
+            system_prompt_id=config_data.get("system_prompt_id"),
             system_role=config_data.get("system_role", "system"),
             output_dir=config_data.get("output_dir", "results"),
             max_concurrent=config_data.get("max_concurrent", 10),
             save_metadata=config_data.get("save_metadata", False),
             save_full_results=config_data.get("save_full_results", False),
+            # Sampling parameters
+            max_tokens=config_data.get("max_tokens", 32768),
+            temperature=config_data.get("temperature", 0.6),
+            top_p=config_data.get("top_p", 0.95),
             categories=categories,
         )
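
The companion change to the API call lives in one of the other changed files and is not shown in this hunk. A rough sketch of how these settings are typically forwarded to an OpenAI-compatible chat-completions request could look like the following; the client setup, function name, and message layout are assumptions, not the project's actual code:

from openai import AsyncOpenAI

client = AsyncOpenAI()  # base_url / api_key configuration omitted

async def generate_answer(cfg: EvalConfig, question: str) -> str:
    # Forward the configured sampling parameters alongside the resolved system prompt.
    response = await client.chat.completions.create(
        model=cfg.model,
        messages=[
            {"role": cfg.system_role, "content": cfg.get_system_prompt()},
            {"role": "user", "content": question},
        ],
        max_tokens=cfg.max_tokens,
        temperature=cfg.temperature,
        top_p=cfg.top_p,
    )
    return response.choices[0].message.content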