mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
sampling params
This commit is contained in:
parent
3944e7ef9b
commit
7462f45447
1 changed files with 22 additions and 0 deletions
|
|
@ -119,6 +119,16 @@ class RefusalBenchConfig(BaseEnvConfig):
|
|||
description="Temperature for model evaluation completions.",
|
||||
)
|
||||
|
||||
eval_top_p: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Top-p (nucleus sampling) for model evaluation completions.",
|
||||
)
|
||||
|
||||
eval_top_k: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Top-k sampling for model evaluation completions.",
|
||||
)
|
||||
|
||||
rollout_temperature: float = Field(
|
||||
default=1.0,
|
||||
description="Temperature for training rollout completions.",
|
||||
|
|
@ -894,6 +904,16 @@ Your answer:"""
|
|||
"temperature": self.config.eval_temperature,
|
||||
"split": "eval",
|
||||
}
|
||||
|
||||
# Add optional sampling parameters
|
||||
if self.config.eval_top_p is not None:
|
||||
params["top_p"] = self.config.eval_top_p
|
||||
|
||||
if self.config.eval_top_k is not None:
|
||||
# top_k needs to be passed in extra_body for some APIs
|
||||
params["extra_body"] = params.get("extra_body", {})
|
||||
params["extra_body"]["top_k"] = self.config.eval_top_k
|
||||
|
||||
return params
|
||||
|
||||
async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]:
|
||||
|
|
@ -1255,6 +1275,8 @@ Your answer:"""
|
|||
"end_time": end_time,
|
||||
"generation_parameters": {
|
||||
"temperature": self.config.eval_temperature,
|
||||
"top_p": self.config.eval_top_p,
|
||||
"top_k": self.config.eval_top_k,
|
||||
"max_tokens": self.config.eval_max_tokens,
|
||||
"thinking_mode": self.config.thinking_mode,
|
||||
},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue