sampling params

This commit is contained in:
Jai Suphavadeeprasit 2025-08-28 14:07:29 -04:00
parent 3944e7ef9b
commit 7462f45447

View file

@ -119,6 +119,16 @@ class RefusalBenchConfig(BaseEnvConfig):
description="Temperature for model evaluation completions.",
)
# Optional nucleus-sampling (top-p) setting for evaluation completions.
# Defaults to None; the eval request params only gain a "top_p" entry when
# this is not None, so None means the key is omitted from the request.
eval_top_p: Optional[float] = Field(
default=None,
description="Top-p (nucleus sampling) for model evaluation completions.",
)
# Optional top-k sampling setting for evaluation completions.
# Defaults to None; when set, the eval request params carry it nested under
# "extra_body" (as extra_body["top_k"]) rather than as a top-level key.
eval_top_k: Optional[int] = Field(
default=None,
description="Top-k sampling for model evaluation completions.",
)
rollout_temperature: float = Field(
default=1.0,
description="Temperature for training rollout completions.",
@ -894,6 +904,16 @@ Your answer:"""
"temperature": self.config.eval_temperature,
"split": "eval",
}
# Add optional sampling parameters
if self.config.eval_top_p is not None:
params["top_p"] = self.config.eval_top_p
if self.config.eval_top_k is not None:
# top_k needs to be passed in extra_body for some APIs
params["extra_body"] = params.get("extra_body", {})
params["extra_body"]["top_k"] = self.config.eval_top_k
return params
async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]:
@ -1255,6 +1275,8 @@ Your answer:"""
"end_time": end_time,
"generation_parameters": {
"temperature": self.config.eval_temperature,
"top_p": self.config.eval_top_p,
"top_k": self.config.eval_top_k,
"max_tokens": self.config.eval_max_tokens,
"thinking_mode": self.config.thinking_mode,
},