logprob wandb

This commit is contained in:
Jai Suphavadeeprasit 2026-01-27 13:33:48 -05:00
parent b6fca510d2
commit 35d4a0781b
3 changed files with 208 additions and 25 deletions

View file

@ -3,12 +3,14 @@ Data processing utilities for GRPO trainer.
Handles data retrieval from Atropos API, padding, batching,
and advantage normalization.
Also extracts inference logprobs for alignment validation with training logprobs.
"""
import json
import math
import time
from typing import List, Tuple
from typing import List, Optional, Tuple
import numpy as np
import torch
@ -16,8 +18,16 @@ import torch
from .api import get_batch
def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
def pad_data_to_good_offset(
data: dict,
batch_size: int,
extract_inference_logprobs: bool = True,
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[torch.Tensor],
List[torch.Tensor],
Optional[List[np.ndarray]],
]:
"""
Pad and batch data from the Atropos API.
@ -26,13 +36,16 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
- Pads token sequences to nearest multiple of 64
- Normalizes advantage scores
- Extracts temperature values
- Optionally extracts inference logprobs for alignment validation
Args:
data: Raw batch data from Atropos API
batch_size: Size of each training batch
extract_inference_logprobs: Whether to extract inference logprobs
Returns:
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches)
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs)
inference_logprobs is None if extract_inference_logprobs=False or no logprobs in data
"""
max_token_len = max(
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
@ -53,6 +66,7 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
advantages = []
lengths = []
temperatures = []
inference_logprobs_list: List[np.ndarray] = []
for item in data["batch"]:
# Normalize advantage scores
@ -97,6 +111,14 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
labels.append(label_item[1:]) # Shift by 1 for causal
advantages.append(item["scores"][i])
# Extract inference logprobs for alignment validation
# These come from vLLM during rollout generation
if extract_inference_logprobs and "inference_logprobs" in item:
if i < len(item["inference_logprobs"]):
inference_logprobs_list.append(
np.array(item["inference_logprobs"][i], dtype=np.float32)
)
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0
if (
@ -137,16 +159,21 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
).view(-1, 1, 1)
)
return token_batches, label_batches, advantage_batches, temperature_batches
# Return inference logprobs if available
inference_logprobs = inference_logprobs_list if inference_logprobs_list else None
return token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs
def get_data(
batch_size: int,
seq_len: int,
atropos_url: str = "http://localhost:8000",
) -> List[Tuple[
List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
]]:
extract_inference_logprobs: bool = True,
) -> Tuple[
List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]],
Optional[List[np.ndarray]],
]:
"""
Fetch and process training data from the Atropos API.
@ -157,11 +184,15 @@ def get_data(
batch_size: Size of each training batch
seq_len: Maximum sequence length (for reference, not used directly)
atropos_url: URL of the Atropos API server
extract_inference_logprobs: Whether to extract inference logprobs for alignment
Returns:
List of processed batch tuples
Tuple of (batches, all_inference_logprobs)
- batches: List of processed batch tuples
- all_inference_logprobs: List of inference logprob arrays for alignment validation
"""
batches = []
all_inference_logprobs: List[np.ndarray] = []
while True:
data = get_batch(url=atropos_url)
@ -172,10 +203,17 @@ def get_data(
json.dump(data, f)
# Process and accumulate batches
batches.append(pad_data_to_good_offset(data, batch_size))
token_batches, label_batches, adv_batches, temp_batches, inf_logprobs = \
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
batches.append((token_batches, label_batches, adv_batches, temp_batches))
if inf_logprobs:
all_inference_logprobs.extend(inf_logprobs)
elif len(batches) > 0:
# Return accumulated batches when no more data
return batches
return batches, all_inference_logprobs if all_inference_logprobs else None
else:
# Wait for data
time.sleep(1)