atropos/example_trainer/data.py
Jai Suphavadeeprasit 6833d4d820 major refactor
2026-03-02 11:18:52 -05:00

182 lines
5.8 KiB
Python

"""
Data processing utilities for GRPO trainer.
Handles data retrieval from Atropos API, padding, batching,
and advantage normalization.
"""
import json
import math
import time
from typing import List, Tuple
import numpy as np
import torch
from .api import get_batch
def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
    List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
]:
    """
    Pad and batch data from the Atropos API.

    Processes raw batch data into padded tensors suitable for training:
    - Pads every token sequence with 0 and every mask with -100 to a common
      length chosen so that, after the one-position causal shift, the
      sequence length is a multiple of 64
    - Normalizes the scores of each item to zero mean / unit std (only when
      the item has more than one score) and zeroes any score whose override
      sets ``set_advantage_to_zero``
    - Extracts a temperature per sample
      (priority: per-sample override > generation_params > group_overrides > 1.0)

    The input ``data`` is not modified. Samples are grouped in order into
    batches of exactly ``batch_size``; trailing samples that do not fill a
    complete batch are dropped.

    Args:
        data: Raw batch data from Atropos API. ``data["batch"]`` is a list of
            items, each with ``"tokens"``, ``"masks"`` and ``"scores"`` and
            optionally ``"overrides"``, ``"generation_params"`` and
            ``"group_overrides"``.
        batch_size: Size of each training batch

    Returns:
        Tuple of (token_batches, label_batches, advantage_batches,
        temperature_batches). Per batch the tensors are shaped
        (batch_size, L), (batch_size, L), (batch_size, 1) and
        (batch_size, 1, 1), where L is the padded, shifted length. All four
        lists are empty when the batch contains no sequences.
    """
    # Pad to nearest multiple of 64 for GPU efficiency
    good_multiple = 64
    groups = data["batch"]

    seq_lengths = [len(seq) for item in groups for seq in item["tokens"]]
    if not seq_lengths:
        # Nothing to pad or batch (empty batch, or items without sequences).
        return [], [], [], []

    # One token is consumed by the causal shift (inputs drop the last
    # position, labels drop the first), so it is the shifted length
    # (token_setup_len - 1) that is rounded up to the multiple.
    token_setup_len = (
        math.ceil((max(seq_lengths) - 1) / good_multiple) * good_multiple + 1
    )

    # Process all items
    input_ids = []
    labels = []
    advantages = []
    temperatures = []
    for item in groups:
        scores = _normalized_scores(item)
        for i, seq in enumerate(item["tokens"]):
            pad_len = max(0, token_setup_len - len(seq))
            # -100 marks padded label positions (presumably the loss's
            # ignore index -- confirm against the training loop).
            label_item = np.concatenate([
                np.array(item["masks"][i]),
                np.full(pad_len, -100, dtype=np.int32),
            ])
            token_item = np.concatenate([
                np.array(seq),
                np.zeros(pad_len, dtype=np.int32),
            ])
            input_ids.append(token_item[:-1])  # Remove last for causal
            labels.append(label_item[1:])  # Shift by 1 for causal
            advantages.append(scores[i])
            temperatures.append(_sample_temperature(item, i))

    # Batch the data (full batches only)
    token_batches = []
    label_batches = []
    advantage_batches = []
    temperature_batches = []
    for batch_idx in range(len(input_ids) // batch_size):
        start = batch_idx * batch_size
        end = start + batch_size
        token_batches.append(
            torch.tensor(np.stack(input_ids[start:end], axis=0))
        )
        label_batches.append(
            torch.tensor(np.stack(labels[start:end], axis=0))
        )
        advantage_batches.append(
            torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
        )
        temperature_batches.append(
            torch.tensor(
                np.array(temperatures[start:end], dtype=np.float32)
            ).view(-1, 1, 1)
        )
    return token_batches, label_batches, advantage_batches, temperature_batches


def _normalized_scores(item: dict) -> np.ndarray:
    """Return a normalized copy of ``item["scores"]`` with zero-overrides applied."""
    # np.array copies, so the caller's score list is left untouched.
    scores = np.array(item["scores"])
    if len(scores) > 1:
        scores = scores - scores.mean()
        # Guard against division by zero when all scores are identical.
        scores = scores / max(scores.std(), 1e-8)

    overrides = item.get("overrides")
    if overrides:
        # zip stops at the shorter of the two, so an overrides list longer
        # than the scores cannot index out of range.
        for i, override in zip(range(len(scores)), overrides):
            if isinstance(override, dict) and override.get(
                "set_advantage_to_zero", False
            ):
                scores[i] = 0
    return scores


def _sample_temperature(item: dict, index: int) -> float:
    """Return the sampling temperature recorded for sample ``index`` of ``item``."""
    # Priority: override > generation_params > group_overrides > 1.0
    overrides = item.get("overrides")
    if (
        overrides
        and index < len(overrides)
        and isinstance(overrides[index], dict)
        and "temperature" in overrides[index]
    ):
        return float(overrides[index]["temperature"])

    generation_params = item.get("generation_params")
    if generation_params and "temperature" in generation_params:
        return float(generation_params["temperature"])

    group_overrides = item.get("group_overrides")
    if group_overrides and "temperature" in group_overrides:
        return float(group_overrides["temperature"])

    return 1.0
def get_data(
    batch_size: int,
    seq_len: int,
    atropos_url: str = "http://localhost:8000",
) -> List[Tuple[
    List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
]]:
    """
    Poll the Atropos API and return every batch it hands out, processed.

    Keeps calling ``get_batch`` in a loop. Each response that carries a
    batch is written to ``temp.json`` and then run through
    ``pad_data_to_good_offset``. The first response without a batch ends
    the loop if at least one batch has been collected; otherwise the
    function sleeps for one second and polls again.

    Args:
        batch_size: Size of each training batch
        seq_len: Maximum sequence length (kept in the signature; this
            function does not read it)
        atropos_url: URL of the Atropos API server

    Returns:
        List with one ``pad_data_to_good_offset`` result tuple per response
        that contained a batch, in arrival order
    """
    collected = []
    while True:
        payload = get_batch(url=atropos_url)

        if payload["batch"] is None:
            if collected:
                # The API has run dry: hand back what was gathered.
                return collected
            # Nothing gathered yet: wait before polling again.
            time.sleep(1)
            continue

        # Keep the latest raw response on disk for debugging. This is done
        # before processing, on the response exactly as received.
        with open("temp.json", "w", encoding="utf-8") as dump_file:
            json.dump(payload, dump_file)

        collected.append(pad_data_to_good_offset(payload, batch_size))