[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-02-06 06:46:14 +00:00 committed by Jai Suphavadeeprasit
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions

View file

@ -21,7 +21,7 @@ from .api import get_batch
def pad_data_to_good_offset(
data: dict,
data: dict,
batch_size: int,
extract_inference_logprobs: bool = True,
) -> Tuple[
@ -33,22 +33,22 @@ def pad_data_to_good_offset(
]:
"""
Pad and batch data from the Atropos API.
Processes raw batch data into properly padded tensors suitable for training:
- Pads token sequences to nearest multiple of 64
- Normalizes advantage scores
- Extracts temperature values
- Extracts and pads inference logprobs for proper GRPO loss computation
Args:
data: Raw batch data from Atropos API
batch_size: Size of each training batch
extract_inference_logprobs: Whether to extract inference logprobs
Returns:
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
Note:
inference_logprob_batches are padded with 0.0 at positions where labels == -100.
This allows token-by-token alignment during GRPO loss computation.
@ -56,7 +56,7 @@ def pad_data_to_good_offset(
max_token_len = max(
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
)
# Pad to nearest multiple of 64 for GPU efficiency
good_multiple = 64
if (max_token_len - 1) % (good_multiple) != 0:
@ -65,7 +65,7 @@ def pad_data_to_good_offset(
else:
token_setup_len = max_token_len
max_token_len = max_token_len - 1 # -1 for causal shift
# Process all items
input_ids = []
labels = []
@ -74,7 +74,7 @@ def pad_data_to_good_offset(
temperatures = []
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
has_any_logprobs = False
for item in data["batch"]:
# Normalize advantage scores
scores = np.array(item["scores"])
@ -82,43 +82,45 @@ def pad_data_to_good_offset(
scores = scores - scores.mean()
scores = scores / max(scores.std(), 1e-8)
item["scores"] = scores
# Handle score overrides
if item["overrides"] is not None:
for i in range(len(item["overrides"])):
if item["overrides"][i].get("set_advantage_to_zero", False):
item["scores"][i] = 0
# Process each sample in the item
for i in range(len(item["tokens"])):
seq_len = len(item["tokens"][i])
lengths.append(
math.ceil((seq_len - 1) / good_multiple) * good_multiple
)
lengths.append(math.ceil((seq_len - 1) / good_multiple) * good_multiple)
# Create labels with padding (-100 for masked positions)
label_item = np.concatenate([
np.array(item["masks"][i]),
np.full(
max(0, token_setup_len - seq_len),
-100,
dtype=np.int32,
),
])
label_item = np.concatenate(
[
np.array(item["masks"][i]),
np.full(
max(0, token_setup_len - seq_len),
-100,
dtype=np.int32,
),
]
)
# Pad tokens
item["tokens"][i] = np.concatenate([
np.array(item["tokens"][i]),
np.zeros(
max(0, token_setup_len - seq_len),
dtype=np.int32,
),
])
item["tokens"][i] = np.concatenate(
[
np.array(item["tokens"][i]),
np.zeros(
max(0, token_setup_len - seq_len),
dtype=np.int32,
),
]
)
input_ids.append(item["tokens"][i][:-1]) # Remove last for causal
labels.append(label_item[1:]) # Shift by 1 for causal
advantages.append(item["scores"][i])
# Extract and pad inference logprobs to match labels shape
# IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks:
# - 1.0 for prompt tokens (masked positions)
@ -126,26 +128,32 @@ def pad_data_to_good_offset(
# We just need to pad to match the sequence length
if extract_inference_logprobs and "inference_logprobs" in item:
if i < len(item["inference_logprobs"]):
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
raw_logprobs = np.array(
item["inference_logprobs"][i], dtype=np.float32
)
has_any_logprobs = True
# Create padded logprobs array matching token_setup_len
# Fill with 1.0 (the masked token placeholder value) for padding
padded_logprobs = np.full(token_setup_len, 1.0, dtype=np.float32)
# Copy raw_logprobs directly - they're already aligned with tokens
n_to_copy = min(len(raw_logprobs), token_setup_len)
padded_logprobs[:n_to_copy] = raw_logprobs[:n_to_copy]
# Shift by 1 to match causal label shift
inference_logprobs_padded.append(padded_logprobs[1:])
else:
# No logprobs for this sample, use 1.0
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
# No logprobs for this sample, use 1.0
inference_logprobs_padded.append(
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
)
elif extract_inference_logprobs:
# No inference_logprobs in item, use 1.0
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
inference_logprobs_padded.append(
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
)
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0
if (
@ -155,48 +163,58 @@ def pad_data_to_good_offset(
and ("temperature" in item["overrides"][i])
):
t = float(item["overrides"][i]["temperature"])
elif item.get("generation_params") and ("temperature" in item["generation_params"]):
elif item.get("generation_params") and (
"temperature" in item["generation_params"]
):
t = float(item["generation_params"]["temperature"])
elif item.get("group_overrides") and ("temperature" in item["group_overrides"]):
elif item.get("group_overrides") and (
"temperature" in item["group_overrides"]
):
t = float(item["group_overrides"]["temperature"])
temperatures.append(t)
# Batch the data
token_batches = []
label_batches = []
advantage_batches = []
temperature_batches = []
inference_logprob_batches = []
for i in range(len(input_ids) // batch_size):
start = i * batch_size
end = (i + 1) * batch_size
token_batches.append(
torch.tensor(np.stack(input_ids[start:end], axis=0))
)
label_batches.append(
torch.tensor(np.stack(labels[start:end], axis=0))
)
token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
advantage_batches.append(
torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
)
temperature_batches.append(
torch.tensor(
np.array(temperatures[start:end], dtype=np.float32)
).view(-1, 1, 1)
torch.tensor(np.array(temperatures[start:end], dtype=np.float32)).view(
-1, 1, 1
)
)
# Batch inference logprobs (same shape as labels)
if extract_inference_logprobs and inference_logprobs_padded:
inference_logprob_batches.append(
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
)
# Return inference logprob batches if we have any real logprobs
final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None
return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches
final_logprob_batches = (
inference_logprob_batches
if (has_any_logprobs and inference_logprob_batches)
else None
)
return (
token_batches,
label_batches,
advantage_batches,
temperature_batches,
final_logprob_batches,
)
def get_data(
@ -205,27 +223,29 @@ def get_data(
atropos_url: str = "http://localhost:8000",
extract_inference_logprobs: bool = True,
) -> Tuple[
List[Tuple[
List[torch.Tensor], # token_batches
List[torch.Tensor], # label_batches
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches
]],
List[
Tuple[
List[torch.Tensor], # token_batches
List[torch.Tensor], # label_batches
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches
]
],
None, # Legacy return (no longer used)
]:
"""
Fetch and process training data from the Atropos API.
Continuously polls the API until data is available, then processes
all available batches.
Args:
batch_size: Size of each training batch
seq_len: Maximum sequence length (for reference, not used directly)
atropos_url: URL of the Atropos API server
extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss
Returns:
Tuple of (batches, None)
- batches: List of processed batch tuples, each containing:
@ -234,42 +254,73 @@ def get_data(
"""
batches = []
_logged_logprob_warning = False
while True:
data = get_batch(url=atropos_url)
if data["batch"] is not None:
# DEBUG: Check if inference_logprobs exists in the data
if not _logged_logprob_warning:
has_logprobs = any("inference_logprobs" in item for item in data["batch"])
has_logprobs = any(
"inference_logprobs" in item for item in data["batch"]
)
if has_logprobs:
# Check if they're non-empty
sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None)
sample_item = next(
(
item
for item in data["batch"]
if "inference_logprobs" in item
),
None,
)
if sample_item and sample_item.get("inference_logprobs"):
sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else []
print(f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})")
sample_lp = (
sample_item["inference_logprobs"][0]
if sample_item["inference_logprobs"]
else []
)
print(
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
)
else:
print(" [Data] ⚠ inference_logprobs key exists but is empty!")
print(
" [Data] ⚠ inference_logprobs key exists but is empty!"
)
else:
print(" [Data] ⚠ NO inference_logprobs in batch data!")
print(f" [Data] Keys in first item: {list(data['batch'][0].keys())}")
print(
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
)
_logged_logprob_warning = True
# Save batch for debugging
with open("temp.json", "w", encoding="utf-8") as f:
json.dump(data, f)
# Process and accumulate batches (now includes batched inference logprobs)
token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
(
token_batches,
label_batches,
adv_batches,
temp_batches,
inf_logprob_batches,
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
# Include inference logprob batches in the tuple
batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches))
batches.append(
(
token_batches,
label_batches,
adv_batches,
temp_batches,
inf_logprob_batches,
)
)
elif len(batches) > 0:
# Return accumulated batches when no more data
return batches, None
else:
# Wait for data
time.sleep(1)