mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
Eval sampling settings for generation (temperature, top-p, max_tokens) (#242)
* feat: Add sampling parameters to eval configuration and API call
* feat: Add support for system_prompt_id and optional system_prompt configuration
This commit is contained in:
parent
b1c8840129
commit
b4207162ff
3 changed files with 83 additions and 22 deletions
|
|
@ -44,6 +44,11 @@ output_dir: "results"
|
|||
max_concurrent: 10
|
||||
default_size: 20 # Default size for all datasets
|
||||
default_seed: 42 # Default seed for all datasets
|
||||
max_tokens: 32768 # Maximum generation length (optional)
|
||||
temperature: 0.6 # Generation temperature (optional)
|
||||
top_p: 0.95 # Top-p sampling parameter (optional)
|
||||
system_prompt_id: "default" # Use a predefined system prompt by ID (optional)
|
||||
# system_prompt: "Your custom system prompt here" # Or specify a custom system prompt directly (takes precedence over system_prompt_id if both are set)
|
||||
|
||||
categories:
|
||||
- category: "algebra"
|
||||
|
|
|
|||
36
eval/eval.py
36
eval/eval.py
|
|
@ -121,11 +121,19 @@ class AsyncModelEvaluator:
|
|||
params = {
|
||||
"model": self.config.model,
|
||||
"messages": [
|
||||
{"role": self.config.system_role, "content": self.config.system_prompt},
|
||||
{"role": self.config.system_role, "content": self.config.get_system_prompt()},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
}
|
||||
|
||||
# Add sampling parameters if specified
|
||||
if self.config.max_tokens is not None:
|
||||
params["max_tokens"] = self.config.max_tokens
|
||||
if self.config.temperature is not None:
|
||||
params["temperature"] = self.config.temperature
|
||||
if self.config.top_p is not None:
|
||||
params["top_p"] = self.config.top_p
|
||||
|
||||
# Add provider configuration if specified
|
||||
if self.config.provider:
|
||||
params["extra_body"] = {"provider": {"order": [self.config.provider], "allow_fallbacks": False}}
|
||||
|
|
@ -253,7 +261,7 @@ class AsyncModelEvaluator:
|
|||
"average_score": average_score,
|
||||
"total_examples": len(results),
|
||||
"config": {"size": dataset_config.size, "seed": dataset_config.seed, **dataset_config.params},
|
||||
"system_prompt": self.config.system_prompt,
|
||||
"system_prompt": self.config.get_system_prompt(),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
|
@ -265,7 +273,7 @@ class AsyncModelEvaluator:
|
|||
"average_score": 0.0,
|
||||
"total_examples": 0,
|
||||
"config": {"size": dataset_config.size, "seed": dataset_config.seed, **dataset_config.params},
|
||||
"system_prompt": self.config.system_prompt,
|
||||
"system_prompt": self.config.get_system_prompt(),
|
||||
"error": str(e),
|
||||
"results": [],
|
||||
}
|
||||
|
|
@ -310,6 +318,9 @@ class AsyncModelEvaluator:
|
|||
"provider": self.config.provider,
|
||||
"git_hash": self.git_hash,
|
||||
"duration_seconds": (datetime.now() - self.start_time).total_seconds(),
|
||||
"max_tokens": self.config.max_tokens,
|
||||
"temperature": self.config.temperature,
|
||||
"top_p": self.config.top_p,
|
||||
},
|
||||
"categories": category_results,
|
||||
}
|
||||
|
|
@ -384,13 +395,18 @@ class AsyncModelEvaluator:
|
|||
with open(results_path, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
# Add timestamp, git hash, model, provider, and duration to summary
|
||||
# Add timestamp, git hash, model, provider, sampling parameters, and duration to summary
|
||||
summary_data = results["summary"].copy()
|
||||
summary_data["timestamp"] = self.start_time.isoformat()
|
||||
summary_data["git_hash"] = self.git_hash
|
||||
summary_data["model"] = self.config.model
|
||||
summary_data["provider"] = self.config.provider
|
||||
summary_data["system_prompt"] = self.config.system_prompt
|
||||
summary_data["system_prompt"] = self.config.get_system_prompt()
|
||||
if self.config.system_prompt_id:
|
||||
summary_data["system_prompt_id"] = self.config.system_prompt_id
|
||||
summary_data["max_tokens"] = self.config.max_tokens
|
||||
summary_data["temperature"] = self.config.temperature
|
||||
summary_data["top_p"] = self.config.top_p
|
||||
summary_data["duration_seconds"] = results["metadata"]["duration_seconds"]
|
||||
|
||||
# Save summary
|
||||
|
|
@ -422,11 +438,11 @@ class AsyncModelEvaluator:
|
|||
print("------------------")
|
||||
print(f"Model: {self.config.model}")
|
||||
print(f"Provider: {self.config.provider}")
|
||||
print(
|
||||
f"System Prompt: {self.config.system_prompt[:50]}..."
|
||||
if len(self.config.system_prompt) > 50
|
||||
else self.config.system_prompt
|
||||
)
|
||||
system_prompt = self.config.get_system_prompt()
|
||||
print(f"System Prompt: {system_prompt[:50]}..." if len(system_prompt) > 50 else system_prompt)
|
||||
print(f"Max Tokens: {self.config.max_tokens}")
|
||||
print(f"Temperature: {self.config.temperature}")
|
||||
print(f"Top-p: {self.config.top_p}")
|
||||
print(f"Git Hash: {self.git_hash}")
|
||||
print(f"Duration: {results['metadata']['duration_seconds']:.2f} seconds")
|
||||
print()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Configuration classes for the evaluation script"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
|
@ -43,17 +44,51 @@ class CategoryConfig:
|
|||
class EvalConfig:
|
||||
"""Global evaluation configuration"""
|
||||
|
||||
model: str
|
||||
provider: Optional[str] = None
|
||||
system_prompt: str = SYSTEM_PROMPTS["default"]
|
||||
system_role: str = "system"
|
||||
output_dir: str = "results"
|
||||
max_concurrent: int = 10
|
||||
default_size: int = 500
|
||||
default_seed: Optional[int] = None
|
||||
save_metadata: bool = False
|
||||
save_full_results: bool = False
|
||||
categories: list[CategoryConfig] = field(default_factory=list)
|
||||
model: str # Model identifier (e.g., "meta-llama/llama-3.3-70b-instruct")
|
||||
provider: Optional[str] = None # Provider name for OpenRouter (e.g., "Anthropic", "OpenAI")
|
||||
system_prompt: Optional[str] = None # Custom system prompt text (overrides system_prompt_id)
|
||||
system_prompt_id: Optional[str] = None # ID of predefined system prompt from SYSTEM_PROMPTS
|
||||
system_role: str = "system" # Role for the system message (usually "system")
|
||||
output_dir: str = "results" # Directory to save evaluation results
|
||||
max_concurrent: int = 10 # Maximum number of concurrent API calls
|
||||
default_size: int = 500 # Default dataset size if not specified for a dataset
|
||||
default_seed: Optional[int] = None # Default random seed if not specified for a dataset
|
||||
save_metadata: bool = False # Whether to include dataset entry metadata in results
|
||||
save_full_results: bool = False # Whether to save the full results file
|
||||
# Sampling parameters
|
||||
max_tokens: Optional[int] = 32768 # Maximum number of tokens to generate
|
||||
temperature: Optional[float] = 0.6 # Sampling temperature (higher = more random)
|
||||
top_p: Optional[float] = 0.95 # Nucleus sampling parameter (lower = more deterministic)
|
||||
|
||||
categories: list[CategoryConfig] = field(default_factory=list) # List of category configurations
|
||||
|
||||
def get_system_prompt(self) -> str:
    """Resolve the system prompt to use for evaluation.

    Resolution order:
      1. ``system_prompt`` when it is set. If ``system_prompt_id`` is set
         as well, a warning is logged and the ID is ignored.
      2. The ``SYSTEM_PROMPTS`` entry for ``system_prompt_id`` when that
         ID is present in ``SYSTEM_PROMPTS``.
      3. ``SYSTEM_PROMPTS["default"]`` otherwise; a warning is logged
         first when an ID was given but is not in ``SYSTEM_PROMPTS``.

    Returns:
        The system prompt string to use
    """
    custom_prompt = self.system_prompt
    prompt_id = self.system_prompt_id

    # An explicitly supplied prompt always wins over a predefined prompt ID.
    if custom_prompt is not None:
        if prompt_id is not None:
            logging.warning(
                "Both system_prompt and system_prompt_id are specified in the configuration. "
                "Using system_prompt and ignoring system_prompt_id."
            )
        return custom_prompt

    # Nothing configured at all: use the default prompt.
    if prompt_id is None:
        return SYSTEM_PROMPTS["default"]

    if prompt_id in SYSTEM_PROMPTS:
        return SYSTEM_PROMPTS[prompt_id]

    # Unknown ID: warn, then fall back to the default prompt.
    logging.warning(
        f"System prompt ID '{prompt_id}' not found in SYSTEM_PROMPTS. "
        "Using default system prompt instead."
    )
    return SYSTEM_PROMPTS["default"]
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_path: str) -> "EvalConfig":
|
||||
|
|
@ -129,11 +164,16 @@ class EvalConfig:
|
|||
return cls(
|
||||
model=config_data.get("model"),
|
||||
provider=config_data.get("provider", "openai"),
|
||||
system_prompt=config_data.get("system_prompt", SYSTEM_PROMPTS["default"]),
|
||||
system_prompt=config_data.get("system_prompt"),
|
||||
system_prompt_id=config_data.get("system_prompt_id"),
|
||||
system_role=config_data.get("system_role", "system"),
|
||||
output_dir=config_data.get("output_dir", "results"),
|
||||
max_concurrent=config_data.get("max_concurrent", 10),
|
||||
save_metadata=config_data.get("save_metadata", False),
|
||||
save_full_results=config_data.get("save_full_results", False),
|
||||
# Sampling parameters
|
||||
max_tokens=config_data.get("max_tokens", 32768),
|
||||
temperature=config_data.get("temperature", 0.6),
|
||||
top_p=config_data.get("top_p", 0.95),
|
||||
categories=categories,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue