manual testing

Jai Suphavadeeprasit 2026-02-02 15:40:24 -05:00
parent da046d3d3b
commit c1bb4f33f0
5 changed files with 329 additions and 766 deletions


@@ -4,7 +4,9 @@ Data processing utilities for GRPO trainer.
 Handles data retrieval from Atropos API, padding, batching,
 and advantage normalization.
 
-Also extracts inference logprobs for alignment validation with training logprobs.
+Also extracts inference logprobs for proper GRPO loss computation:
+- Inference logprobs serve as π_old (the old policy that generated the rollouts) for importance sampling
+- They are batched and padded to align token-by-token with training labels
 """
 
 import json
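
Background on the π_old role named in the docstring: in GRPO-style training, the per-token importance ratio compares the current policy's logprobs against those recorded at rollout time. A minimal sketch, with illustrative names that are not from this repo:

import torch

def importance_ratio(train_logprobs, inference_logprobs, labels):
    # Per-token ratio pi_theta / pi_old, computed in log space for stability.
    # inference_logprobs come from the rollout engine (pi_old); padded slots
    # (labels == -100) hold 0.0 and are zeroed out of the result here.
    mask = (labels != -100).float()
    return torch.exp(train_logprobs - inference_logprobs) * mask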
@@ -23,11 +25,11 @@ def pad_data_to_good_offset(
     batch_size: int,
     extract_inference_logprobs: bool = True,
 ) -> Tuple[
-    List[torch.Tensor],
-    List[torch.Tensor],
-    List[torch.Tensor],
-    List[torch.Tensor],
-    Optional[List[np.ndarray]],
+    List[torch.Tensor],  # token_batches
+    List[torch.Tensor],  # label_batches
+    List[torch.Tensor],  # advantage_batches
+    List[torch.Tensor],  # temperature_batches
+    Optional[List[torch.Tensor]],  # inference_logprob_batches (aligned with labels)
 ]:
     """
     Pad and batch data from the Atropos API.
@@ -36,7 +38,7 @@ def pad_data_to_good_offset(
     - Pads token sequences to nearest multiple of 64
     - Normalizes advantage scores
     - Extracts temperature values
-    - Optionally extracts inference logprobs for alignment validation
+    - Extracts and pads inference logprobs for proper GRPO loss computation
 
     Args:
         data: Raw batch data from Atropos API
@@ -44,8 +46,12 @@ def pad_data_to_good_offset(
         extract_inference_logprobs: Whether to extract inference logprobs
 
     Returns:
-        Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs)
-        inference_logprobs is None if extract_inference_logprobs=False or no logprobs in data
+        Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
+        inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
+
+    Note:
+        inference_logprob_batches are padded with 0.0 at positions where labels == -100.
+        This allows token-by-token alignment during GRPO loss computation.
     """
     max_token_len = max(
         [max([len(x) for x in item["tokens"]]) for item in data["batch"]]
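
The Note's padding contract can be checked in isolation. A small self-contained example (values are made up) showing that the 0.0 placeholders sit exactly where labels are -100 and never reach the loss:

import numpy as np

labels = np.array([-100, -100, 201, 202, -100, -100])        # 2 generated tokens
inference_logprobs = np.array([0.0, 0.0, -0.41, -1.27, 0.0, 0.0], dtype=np.float32)

valid = labels != -100
assert np.all(inference_logprobs[~valid] == 0.0)             # placeholders only
print(inference_logprobs[valid])                             # [-0.41 -1.27]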
@@ -66,7 +72,8 @@ def pad_data_to_good_offset(
     advantages = []
     lengths = []
     temperatures = []
-    inference_logprobs_list: List[np.ndarray] = []
+    inference_logprobs_padded: List[np.ndarray] = []  # Padded to match labels shape
+    has_any_logprobs = False
 
     for item in data["batch"]:
         # Normalize advantage scores
@@ -84,15 +91,16 @@ def pad_data_to_good_offset(
         # Process each sample in the item
         for i in range(len(item["tokens"])):
+            seq_len = len(item["tokens"][i])
             lengths.append(
-                math.ceil((len(item["tokens"][i]) - 1) / good_multiple) * good_multiple
+                math.ceil((seq_len - 1) / good_multiple) * good_multiple
             )
 
-            # Create labels with padding
+            # Create labels with padding (-100 for masked positions)
             label_item = np.concatenate([
                 np.array(item["masks"][i]),
                 np.full(
-                    max(0, token_setup_len - len(item["tokens"][i])),
+                    max(0, token_setup_len - seq_len),
                     -100,
                     dtype=np.int32,
                 ),
@@ -102,7 +110,7 @@ def pad_data_to_good_offset(
             item["tokens"][i] = np.concatenate([
                 np.array(item["tokens"][i]),
                 np.zeros(
-                    max(0, token_setup_len - len(item["tokens"][i])),
+                    max(0, token_setup_len - seq_len),
                     dtype=np.int32,
                 ),
             ])
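
For reference, the rounding in lengths.append pads each causally shifted sequence up to the next multiple of good_multiple (64 per the docstring). A quick worked check, assuming good_multiple = 64:

import math

good_multiple = 64
for seq_len in (1, 65, 128, 200):
    # seq_len - 1 because tokens[:-1] predict labels[1:] under the causal shift
    print(seq_len, "->", math.ceil((seq_len - 1) / good_multiple) * good_multiple)
# 1 -> 0, 65 -> 64, 128 -> 128, 200 -> 256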
@@ -111,13 +119,36 @@ def pad_data_to_good_offset(
             labels.append(label_item[1:])  # Shift by 1 for causal
             advantages.append(item["scores"][i])
 
-            # Extract inference logprobs for alignment validation
-            # These come from vLLM during rollout generation
+            # Extract and pad inference logprobs to match labels shape
+            # Inference logprobs exist ONLY for generated tokens (where the mask != -100)
+            # We need to create a padded array that aligns position-by-position
             if extract_inference_logprobs and "inference_logprobs" in item:
                 if i < len(item["inference_logprobs"]):
-                    inference_logprobs_list.append(
-                        np.array(item["inference_logprobs"][i], dtype=np.float32)
-                    )
+                    raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
+                    has_any_logprobs = True
+                    # Create padded logprobs array matching label_item shape
+                    # Fill with 0.0 (will be masked out during loss computation)
+                    padded_logprobs = np.zeros(token_setup_len, dtype=np.float32)
+                    # The inference logprobs correspond to generated tokens
+                    # Find positions where the mask != -100 (generated positions)
+                    mask_arr = np.array(item["masks"][i])
+                    generated_positions = np.where(mask_arr != -100)[0]
+                    # Fill in inference logprobs at generated positions
+                    n_to_fill = min(len(raw_logprobs), len(generated_positions))
+                    if n_to_fill > 0:
+                        padded_logprobs[generated_positions[:n_to_fill]] = raw_logprobs[:n_to_fill]
+                    # Shift by 1 to match causal label shift
+                    inference_logprobs_padded.append(padded_logprobs[1:])
+                else:
+                    # No logprobs for this sample, use zeros
+                    inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
+            elif extract_inference_logprobs:
+                # No inference_logprobs in item, use zeros
+                inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
 
             # Extract temperature (priority: override > generation_params > group_overrides > 1.0)
             t = 1.0
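
A worked example of the position mapping and the trailing [1:] shift (hypothetical values; mask convention as in the code above, -100 for prompt tokens):

import numpy as np

token_setup_len = 8
mask = np.array([-100, -100, 201, 202, 203])                 # 2 prompt, 3 generated
raw_logprobs = np.array([-0.5, -0.9, -1.3], dtype=np.float32)

padded = np.zeros(token_setup_len, dtype=np.float32)
generated = np.where(mask != -100)[0]                        # positions 2, 3, 4
padded[generated[:len(raw_logprobs)]] = raw_logprobs

label_item = np.concatenate([mask, np.full(token_setup_len - len(mask), -100)])
print(label_item[1:])   # [-100  201  202  203 -100 -100 -100]
print(padded[1:])       # [ 0.  -0.5 -0.9 -1.3  0.   0.   0. ]
# Index p in both shifted arrays refers to original position p + 1,
# so each label stays paired with its inference logprob token-by-token.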
@@ -139,6 +170,7 @@ def pad_data_to_good_offset(
     label_batches = []
     advantage_batches = []
     temperature_batches = []
+    inference_logprob_batches = []
 
     for i in range(len(input_ids) // batch_size):
         start = i * batch_size
@@ -158,11 +190,17 @@ def pad_data_to_good_offset(
                 np.array(temperatures[start:end], dtype=np.float32)
             ).view(-1, 1, 1)
         )
+        # Batch inference logprobs (same shape as labels)
+        if extract_inference_logprobs and inference_logprobs_padded:
+            inference_logprob_batches.append(
+                torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
+            )
 
-    # Return inference logprobs if available
-    inference_logprobs = inference_logprobs_list if inference_logprobs_list else None
+    # Return inference logprob batches if we have any real logprobs
+    final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None
 
-    return token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs
+    return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches
 
 
 def get_data(
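
A usage sketch of the new five-element return. The payload below is a made-up stand-in for a real Atropos batch and may omit fields the full function expects; it only illustrates the shape contract:

data = {"batch": [{
    "tokens": [[1, 2, 3, 4]],
    "masks": [[-100, -100, 3, 4]],
    "scores": [0.7],
    "inference_logprobs": [[-0.2, -1.1]],
}]}

tokens, labels, advs, temps, inf_lp = pad_data_to_good_offset(data, batch_size=1)
if inf_lp is not None:
    # Each logprob batch is expected to mirror its label batch shape
    print(inf_lp[0].shape, labels[0].shape)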
@@ -171,8 +209,14 @@ def get_data(
     atropos_url: str = "http://localhost:8000",
     extract_inference_logprobs: bool = True,
 ) -> Tuple[
-    List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]],
-    Optional[List[np.ndarray]],
+    List[Tuple[
+        List[torch.Tensor],  # token_batches
+        List[torch.Tensor],  # label_batches
+        List[torch.Tensor],  # advantage_batches
+        List[torch.Tensor],  # temperature_batches
+        Optional[List[torch.Tensor]],  # inference_logprob_batches
+    ]],
+    None,  # Legacy return (no longer used)
 ]:
     """
     Fetch and process training data from the Atropos API.
@@ -184,15 +228,15 @@ def get_data(
         batch_size: Size of each training batch
         seq_len: Maximum sequence length (for reference, not used directly)
         atropos_url: URL of the Atropos API server
-        extract_inference_logprobs: Whether to extract inference logprobs for alignment
+        extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss
 
     Returns:
-        Tuple of (batches, all_inference_logprobs)
-        - batches: List of processed batch tuples
-        - all_inference_logprobs: List of inference logprob arrays for alignment validation
+        Tuple of (batches, None)
+        - batches: List of processed batch tuples, each containing:
+          (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
+        - inference_logprob_batches are aligned with labels for proper GRPO loss computation
     """
     batches = []
-    all_inference_logprobs: List[np.ndarray] = []
 
     while True:
         data = get_batch(url=atropos_url)
@@ -202,18 +246,16 @@ def get_data(
             with open("temp.json", "w", encoding="utf-8") as f:
                 json.dump(data, f)
 
-            # Process and accumulate batches
-            token_batches, label_batches, adv_batches, temp_batches, inf_logprobs = \
+            # Process and accumulate batches (now includes batched inference logprobs)
+            token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \
                 pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
-            batches.append((token_batches, label_batches, adv_batches, temp_batches))
-            if inf_logprobs:
-                all_inference_logprobs.extend(inf_logprobs)
+            # Include inference logprob batches in the tuple
+            batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches))
         elif len(batches) > 0:
             # Return accumulated batches when no more data
-            return batches, all_inference_logprobs if all_inference_logprobs else None
+            return batches, None
         else:
            # Wait for data
            time.sleep(1)
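
Downstream, a trainer would unpack the five-element tuples; a hedged consumption sketch (grpo_step is a hypothetical stand-in for trainer code not shown in this diff):

batches, _ = get_data(batch_size=8, seq_len=4096)
for token_bs, label_bs, adv_bs, temp_bs, inf_lp_bs in batches:
    for j in range(len(token_bs)):
        old_lp = inf_lp_bs[j] if inf_lp_bs is not None else None
        # old_lp aligns elementwise with label_bs[j]; positions with label -100
        # carry 0.0 and must be masked out of the loss
        grpo_step(token_bs[j], label_bs[j], adv_bs[j], temp_bs[j], old_lp)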