Eval sampling settings for generation (temperature, top-p, max_tokens) (#242)

* feat: Add sampling parameters to eval configuration and API call
* feat: Add support for system_prompt_id and optional system_prompt configuration
Authored by Andreas Köpf on 2025-02-28 11:48:37 +01:00; committed by GitHub.
parent b1c8840129
commit b4207162ff
Signature: no known key found for this signature in database (GPG key ID: B5690EEEBB952194)
3 changed files with 83 additions and 22 deletions

View file

@ -121,11 +121,19 @@ class AsyncModelEvaluator:
params = {
"model": self.config.model,
"messages": [
{"role": self.config.system_role, "content": self.config.system_prompt},
{"role": self.config.system_role, "content": self.config.get_system_prompt()},
{"role": "user", "content": prompt},
],
}
# Add sampling parameters if specified
if self.config.max_tokens is not None:
params["max_tokens"] = self.config.max_tokens
if self.config.temperature is not None:
params["temperature"] = self.config.temperature
if self.config.top_p is not None:
params["top_p"] = self.config.top_p
# Add provider configuration if specified
if self.config.provider:
params["extra_body"] = {"provider": {"order": [self.config.provider], "allow_fallbacks": False}}
@ -253,7 +261,7 @@ class AsyncModelEvaluator:
"average_score": average_score,
"total_examples": len(results),
"config": {"size": dataset_config.size, "seed": dataset_config.seed, **dataset_config.params},
"system_prompt": self.config.system_prompt,
"system_prompt": self.config.get_system_prompt(),
"results": results,
}
@ -265,7 +273,7 @@ class AsyncModelEvaluator:
"average_score": 0.0,
"total_examples": 0,
"config": {"size": dataset_config.size, "seed": dataset_config.seed, **dataset_config.params},
"system_prompt": self.config.system_prompt,
"system_prompt": self.config.get_system_prompt(),
"error": str(e),
"results": [],
}
@ -310,6 +318,9 @@ class AsyncModelEvaluator:
"provider": self.config.provider,
"git_hash": self.git_hash,
"duration_seconds": (datetime.now() - self.start_time).total_seconds(),
"max_tokens": self.config.max_tokens,
"temperature": self.config.temperature,
"top_p": self.config.top_p,
},
"categories": category_results,
}
@ -384,13 +395,18 @@ class AsyncModelEvaluator:
with open(results_path, "w") as f:
json.dump(results, f, indent=2)
# Add timestamp, git hash, model, provider, and duration to summary
# Add timestamp, git hash, model, provider, sampling parameters, and duration to summary
summary_data = results["summary"].copy()
summary_data["timestamp"] = self.start_time.isoformat()
summary_data["git_hash"] = self.git_hash
summary_data["model"] = self.config.model
summary_data["provider"] = self.config.provider
summary_data["system_prompt"] = self.config.system_prompt
summary_data["system_prompt"] = self.config.get_system_prompt()
if self.config.system_prompt_id:
summary_data["system_prompt_id"] = self.config.system_prompt_id
summary_data["max_tokens"] = self.config.max_tokens
summary_data["temperature"] = self.config.temperature
summary_data["top_p"] = self.config.top_p
summary_data["duration_seconds"] = results["metadata"]["duration_seconds"]
# Save summary
@ -422,11 +438,11 @@ class AsyncModelEvaluator:
print("------------------")
print(f"Model: {self.config.model}")
print(f"Provider: {self.config.provider}")
print(
f"System Prompt: {self.config.system_prompt[:50]}..."
if len(self.config.system_prompt) > 50
else self.config.system_prompt
)
system_prompt = self.config.get_system_prompt()
print(f"System Prompt: {system_prompt[:50]}..." if len(system_prompt) > 50 else system_prompt)
print(f"Max Tokens: {self.config.max_tokens}")
print(f"Temperature: {self.config.temperature}")
print(f"Top-p: {self.config.top_p}")
print(f"Git Hash: {self.git_hash}")
print(f"Duration: {results['metadata']['duration_seconds']:.2f} seconds")
print()