diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 23ef0b73..cbc97268 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -31,7 +31,7 @@ class SequenceNode(BaseModel): full_text: Complete text (prompt + completion) tokens: Full token sequence (actual token IDs) masked_tokens: Tokens with -100 for prompt positions, actual IDs for completion - logprobs: Logprobs with 0.0 for prompt positions, actual values for completion + logprobs: Logprobs with 1.0 for prompt positions, actual values for completion metadata: Optional metadata (e.g., role information, finish_reason, etc.) """ @@ -48,7 +48,7 @@ class ManagedServer: Maintains a tree structure keyed by input text, where each completion creates new branches. Provides proper masking for training (prompt tokens masked with -100, - logprobs set to 0.0). + logprobs set to 1.0). Uses the clean tokens_and_logprobs_completion interface internally. """ @@ -234,10 +234,10 @@ class ManagedServer: # Create masked tokens: -100 for prompt, actual IDs for completion masked_tokens = [-100] * prompt_len + output_tokens - # Create masked logprobs: 0.0 for prompt, actual for completion + # Create masked logprobs: 1.0 for prompt, actual for completion # Pad logprobs to match token length if needed if len(output_logprobs) < len(output_tokens): - output_logprobs = output_logprobs + [0.0] * ( + output_logprobs = output_logprobs + [1.0] * ( len(output_tokens) - len(output_logprobs) ) elif len(output_logprobs) > len(output_tokens):