mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
manual testing
This commit is contained in:
parent
da046d3d3b
commit
c1bb4f33f0
5 changed files with 329 additions and 766 deletions
|
|
@@ -4,7 +4,9 @@ Data processing utilities for GRPO trainer.
|
|||
Handles data retrieval from Atropos API, padding, batching,
|
||||
and advantage normalization.
|
||||
|
||||
Also extracts inference logprobs for alignment validation with training logprobs.
|
||||
Also extracts inference logprobs for proper GRPO loss computation:
|
||||
- Inference logprobs serve as π_old (reference policy) for importance sampling
|
||||
- They are batched and padded to align token-by-token with training labels
|
||||
"""
|
||||
|
||||
import json
|
||||
|
|
@@ -23,11 +25,11 @@ def pad_data_to_good_offset(
|
|||
batch_size: int,
|
||||
extract_inference_logprobs: bool = True,
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
Optional[List[np.ndarray]],
|
||||
List[torch.Tensor], # token_batches
|
||||
List[torch.Tensor], # label_batches
|
||||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels)
|
||||
]:
|
||||
"""
|
||||
Pad and batch data from the Atropos API.
|
||||
|
|
@@ -36,7 +38,7 @@ def pad_data_to_good_offset(
|
|||
- Pads token sequences to nearest multiple of 64
|
||||
- Normalizes advantage scores
|
||||
- Extracts temperature values
|
||||
- Optionally extracts inference logprobs for alignment validation
|
||||
- Extracts and pads inference logprobs for proper GRPO loss computation
|
||||
|
||||
Args:
|
||||
data: Raw batch data from Atropos API
|
||||
|
|
@@ -44,8 +46,12 @@ def pad_data_to_good_offset(
|
|||
extract_inference_logprobs: Whether to extract inference logprobs
|
||||
|
||||
Returns:
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs)
|
||||
inference_logprobs is None if extract_inference_logprobs=False or no logprobs in data
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
|
||||
inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
|
||||
|
||||
Note:
|
||||
inference_logprob_batches are padded with 0.0 at positions where labels == -100.
|
||||
This allows token-by-token alignment during GRPO loss computation.
|
||||
"""
|
||||
max_token_len = max(
|
||||
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
|
||||
|
|
@@ -66,7 +72,8 @@ def pad_data_to_good_offset(
|
|||
advantages = []
|
||||
lengths = []
|
||||
temperatures = []
|
||||
inference_logprobs_list: List[np.ndarray] = []
|
||||
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
|
||||
has_any_logprobs = False
|
||||
|
||||
for item in data["batch"]:
|
||||
# Normalize advantage scores
|
||||
|
|
@@ -84,15 +91,16 @@ def pad_data_to_good_offset(
|
|||
|
||||
# Process each sample in the item
|
||||
for i in range(len(item["tokens"])):
|
||||
seq_len = len(item["tokens"][i])
|
||||
lengths.append(
|
||||
math.ceil((len(item["tokens"][i]) - 1) / good_multiple) * good_multiple
|
||||
math.ceil((seq_len - 1) / good_multiple) * good_multiple
|
||||
)
|
||||
|
||||
# Create labels with padding
|
||||
# Create labels with padding (-100 for masked positions)
|
||||
label_item = np.concatenate([
|
||||
np.array(item["masks"][i]),
|
||||
np.full(
|
||||
max(0, token_setup_len - len(item["tokens"][i])),
|
||||
max(0, token_setup_len - seq_len),
|
||||
-100,
|
||||
dtype=np.int32,
|
||||
),
|
||||
|
|
@@ -102,7 +110,7 @@ def pad_data_to_good_offset(
|
|||
item["tokens"][i] = np.concatenate([
|
||||
np.array(item["tokens"][i]),
|
||||
np.zeros(
|
||||
max(0, token_setup_len - len(item["tokens"][i])),
|
||||
max(0, token_setup_len - seq_len),
|
||||
dtype=np.int32,
|
||||
),
|
||||
])
|
||||
|
|
@@ -111,13 +119,36 @@ def pad_data_to_good_offset(
|
|||
labels.append(label_item[1:]) # Shift by 1 for causal
|
||||
advantages.append(item["scores"][i])
|
||||
|
||||
# Extract inference logprobs for alignment validation
|
||||
# These come from vLLM during rollout generation
|
||||
# Extract and pad inference logprobs to match labels shape
|
||||
# Inference logprobs are ONLY for generated tokens (where labels != -100)
|
||||
# We need to create a padded array that aligns position-by-position
|
||||
if extract_inference_logprobs and "inference_logprobs" in item:
|
||||
if i < len(item["inference_logprobs"]):
|
||||
inference_logprobs_list.append(
|
||||
np.array(item["inference_logprobs"][i], dtype=np.float32)
|
||||
)
|
||||
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
|
||||
has_any_logprobs = True
|
||||
|
||||
# Create padded logprobs array matching label_item shape
|
||||
# Fill with 0.0 (will be masked out during loss computation)
|
||||
padded_logprobs = np.zeros(token_setup_len, dtype=np.float32)
|
||||
|
||||
# The inference logprobs correspond to generated tokens
|
||||
# Find positions where labels != -100 (generated positions)
|
||||
mask_arr = np.array(item["masks"][i])
|
||||
generated_positions = np.where(mask_arr != -100)[0]
|
||||
|
||||
# Fill in inference logprobs at generated positions
|
||||
n_to_fill = min(len(raw_logprobs), len(generated_positions))
|
||||
if n_to_fill > 0:
|
||||
padded_logprobs[generated_positions[:n_to_fill]] = raw_logprobs[:n_to_fill]
|
||||
|
||||
# Shift by 1 to match causal label shift
|
||||
inference_logprobs_padded.append(padded_logprobs[1:])
|
||||
else:
|
||||
# No logprobs for this sample, use zeros
|
||||
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
|
||||
elif extract_inference_logprobs:
|
||||
# No inference_logprobs in item, use zeros
|
||||
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
|
|
@@ -139,6 +170,7 @@ def pad_data_to_good_offset(
|
|||
label_batches = []
|
||||
advantage_batches = []
|
||||
temperature_batches = []
|
||||
inference_logprob_batches = []
|
||||
|
||||
for i in range(len(input_ids) // batch_size):
|
||||
start = i * batch_size
|
||||
|
|
@@ -158,11 +190,17 @@ def pad_data_to_good_offset(
|
|||
np.array(temperatures[start:end], dtype=np.float32)
|
||||
).view(-1, 1, 1)
|
||||
)
|
||||
|
||||
# Batch inference logprobs (same shape as labels)
|
||||
if extract_inference_logprobs and inference_logprobs_padded:
|
||||
inference_logprob_batches.append(
|
||||
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
|
||||
)
|
||||
|
||||
# Return inference logprobs if available
|
||||
inference_logprobs = inference_logprobs_list if inference_logprobs_list else None
|
||||
# Return inference logprob batches if we have any real logprobs
|
||||
final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None
|
||||
|
||||
return token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs
|
||||
return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches
|
||||
|
||||
|
||||
def get_data(
|
||||
|
|
@@ -171,8 +209,14 @@ def get_data(
|
|||
atropos_url: str = "http://localhost:8000",
|
||||
extract_inference_logprobs: bool = True,
|
||||
) -> Tuple[
|
||||
List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]],
|
||||
Optional[List[np.ndarray]],
|
||||
List[Tuple[
|
||||
List[torch.Tensor], # token_batches
|
||||
List[torch.Tensor], # label_batches
|
||||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches
|
||||
]],
|
||||
None, # Legacy return (no longer used)
|
||||
]:
|
||||
"""
|
||||
Fetch and process training data from the Atropos API.
|
||||
|
|
@@ -184,15 +228,15 @@ def get_data(
|
|||
batch_size: Size of each training batch
|
||||
seq_len: Maximum sequence length (for reference, not used directly)
|
||||
atropos_url: URL of the Atropos API server
|
||||
extract_inference_logprobs: Whether to extract inference logprobs for alignment
|
||||
extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss
|
||||
|
||||
Returns:
|
||||
Tuple of (batches, all_inference_logprobs)
|
||||
- batches: List of processed batch tuples
|
||||
- all_inference_logprobs: List of inference logprob arrays for alignment validation
|
||||
Tuple of (batches, None)
|
||||
- batches: List of processed batch tuples, each containing:
|
||||
(token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
|
||||
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
|
||||
"""
|
||||
batches = []
|
||||
all_inference_logprobs: List[np.ndarray] = []
|
||||
|
||||
while True:
|
||||
data = get_batch(url=atropos_url)
|
||||
|
|
@@ -202,18 +246,16 @@ def get_data(
|
|||
with open("temp.json", "w", encoding="utf-8") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
# Process and accumulate batches
|
||||
token_batches, label_batches, adv_batches, temp_batches, inf_logprobs = \
|
||||
# Process and accumulate batches (now includes batched inference logprobs)
|
||||
token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \
|
||||
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
|
||||
batches.append((token_batches, label_batches, adv_batches, temp_batches))
|
||||
|
||||
if inf_logprobs:
|
||||
all_inference_logprobs.extend(inf_logprobs)
|
||||
# Include inference logprob batches in the tuple
|
||||
batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches))
|
||||
|
||||
elif len(batches) > 0:
|
||||
# Return accumulated batches when no more data
|
||||
return batches, all_inference_logprobs if all_inference_logprobs else None
|
||||
return batches, None
|
||||
else:
|
||||
# Wait for data
|
||||
time.sleep(1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue