Make the masked-logprob sentinel value consistent (0.0 → 1.0) across docs and padding

This commit is contained in:
Dakota 2025-10-29 10:52:38 -05:00
parent e57c396f86
commit d5400460e8

View file

@ -31,7 +31,7 @@ class SequenceNode(BaseModel):
full_text: Complete text (prompt + completion)
tokens: Full token sequence (actual token IDs)
masked_tokens: Tokens with -100 for prompt positions, actual IDs for completion
logprobs: Logprobs with 0.0 for prompt positions, actual values for completion
logprobs: Logprobs with 1.0 for prompt positions, actual values for completion
metadata: Optional metadata (e.g., role information, finish_reason, etc.)
"""
@ -48,7 +48,7 @@ class ManagedServer:
Maintains a tree structure keyed by input text, where each completion creates
new branches. Provides proper masking for training (prompt tokens masked with -100,
logprobs set to 0.0).
logprobs set to 1.0).
Uses the clean tokens_and_logprobs_completion interface internally.
"""
@ -234,10 +234,10 @@ class ManagedServer:
# Create masked tokens: -100 for prompt, actual IDs for completion
masked_tokens = [-100] * prompt_len + output_tokens
# Create masked logprobs: 0.0 for prompt, actual for completion
# Create masked logprobs: 1.0 for prompt, actual for completion
# Pad logprobs to match token length if needed
if len(output_logprobs) < len(output_tokens):
output_logprobs = output_logprobs + [0.0] * (
output_logprobs = output_logprobs + [1.0] * (
len(output_tokens) - len(output_logprobs)
)
elif len(output_logprobs) > len(output_tokens):