mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Change on-policy distillation (OPD) data format
This commit is contained in:
parent
33f5696171
commit
527433b5bc
10 changed files with 452 additions and 90 deletions
|
|
@ -145,9 +145,10 @@ class ScoredData(BaseModel):
|
|||
group_overrides: Optional[dict] = None
|
||||
images: Optional[Any] = None
|
||||
env_id: Optional[int] = None # ID of the environment that generated this data
|
||||
# On-policy distillation: top-K logprobs from teacher model
|
||||
# Structure: [sequence][position][top_k] = [token_id, logprob]
|
||||
onpolicydistill_logprobs: Optional[List[List[List[List]]]] = None
|
||||
# On-policy distillation (new format): parallel token ids + logprobs.
|
||||
# Shape for both: [sequence][position][top_k]
|
||||
distill_token_ids: Optional[List[List[List[int]]]] = None
|
||||
distill_logprobs: Optional[List[List[List[float]]]] = None
|
||||
|
||||
@field_validator("messages", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -185,7 +186,8 @@ def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]:
|
|||
"group_overrides": scored_data.group_overrides,
|
||||
"images": scored_data.images,
|
||||
"env_id": scored_data.env_id,
|
||||
"onpolicydistill_logprobs": scored_data.onpolicydistill_logprobs,
|
||||
"distill_token_ids": scored_data.distill_token_ids,
|
||||
"distill_logprobs": scored_data.distill_logprobs,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue