mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
logprob wandb
This commit is contained in:
parent
b6fca510d2
commit
35d4a0781b
3 changed files with 208 additions and 25 deletions
|
|
@ -3,12 +3,14 @@ Data processing utilities for GRPO trainer.
|
|||
|
||||
Handles data retrieval from Atropos API, padding, batching,
|
||||
and advantage normalization.
|
||||
|
||||
Also extracts inference logprobs for alignment validation with training logprobs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -16,8 +18,16 @@ import torch
|
|||
from .api import get_batch
|
||||
|
||||
|
||||
def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
|
||||
List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
|
||||
def pad_data_to_good_offset(
|
||||
data: dict,
|
||||
batch_size: int,
|
||||
extract_inference_logprobs: bool = True,
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
Optional[List[np.ndarray]],
|
||||
]:
|
||||
"""
|
||||
Pad and batch data from the Atropos API.
|
||||
|
|
@ -26,13 +36,16 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
|
|||
- Pads token sequences to nearest multiple of 64
|
||||
- Normalizes advantage scores
|
||||
- Extracts temperature values
|
||||
- Optionally extracts inference logprobs for alignment validation
|
||||
|
||||
Args:
|
||||
data: Raw batch data from Atropos API
|
||||
batch_size: Size of each training batch
|
||||
extract_inference_logprobs: Whether to extract inference logprobs
|
||||
|
||||
Returns:
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches)
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs)
|
||||
inference_logprobs is None if extract_inference_logprobs=False or no logprobs in data
|
||||
"""
|
||||
max_token_len = max(
|
||||
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
|
||||
|
|
@ -53,6 +66,7 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
|
|||
advantages = []
|
||||
lengths = []
|
||||
temperatures = []
|
||||
inference_logprobs_list: List[np.ndarray] = []
|
||||
|
||||
for item in data["batch"]:
|
||||
# Normalize advantage scores
|
||||
|
|
@ -97,6 +111,14 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
|
|||
labels.append(label_item[1:]) # Shift by 1 for causal
|
||||
advantages.append(item["scores"][i])
|
||||
|
||||
# Extract inference logprobs for alignment validation
|
||||
# These come from vLLM during rollout generation
|
||||
if extract_inference_logprobs and "inference_logprobs" in item:
|
||||
if i < len(item["inference_logprobs"]):
|
||||
inference_logprobs_list.append(
|
||||
np.array(item["inference_logprobs"][i], dtype=np.float32)
|
||||
)
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
if (
|
||||
|
|
@ -137,16 +159,21 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
|
|||
).view(-1, 1, 1)
|
||||
)
|
||||
|
||||
return token_batches, label_batches, advantage_batches, temperature_batches
|
||||
# Return inference logprobs if available
|
||||
inference_logprobs = inference_logprobs_list if inference_logprobs_list else None
|
||||
|
||||
return token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs
|
||||
|
||||
|
||||
def get_data(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
atropos_url: str = "http://localhost:8000",
|
||||
) -> List[Tuple[
|
||||
List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
|
||||
]]:
|
||||
extract_inference_logprobs: bool = True,
|
||||
) -> Tuple[
|
||||
List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]],
|
||||
Optional[List[np.ndarray]],
|
||||
]:
|
||||
"""
|
||||
Fetch and process training data from the Atropos API.
|
||||
|
||||
|
|
@ -157,11 +184,15 @@ def get_data(
|
|||
batch_size: Size of each training batch
|
||||
seq_len: Maximum sequence length (for reference, not used directly)
|
||||
atropos_url: URL of the Atropos API server
|
||||
extract_inference_logprobs: Whether to extract inference logprobs for alignment
|
||||
|
||||
Returns:
|
||||
List of processed batch tuples
|
||||
Tuple of (batches, all_inference_logprobs)
|
||||
- batches: List of processed batch tuples
|
||||
- all_inference_logprobs: List of inference logprob arrays for alignment validation
|
||||
"""
|
||||
batches = []
|
||||
all_inference_logprobs: List[np.ndarray] = []
|
||||
|
||||
while True:
|
||||
data = get_batch(url=atropos_url)
|
||||
|
|
@ -172,10 +203,17 @@ def get_data(
|
|||
json.dump(data, f)
|
||||
|
||||
# Process and accumulate batches
|
||||
batches.append(pad_data_to_good_offset(data, batch_size))
|
||||
token_batches, label_batches, adv_batches, temp_batches, inf_logprobs = \
|
||||
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
|
||||
batches.append((token_batches, label_batches, adv_batches, temp_batches))
|
||||
|
||||
if inf_logprobs:
|
||||
all_inference_logprobs.extend(inf_logprobs)
|
||||
|
||||
elif len(batches) > 0:
|
||||
# Return accumulated batches when no more data
|
||||
return batches
|
||||
return batches, all_inference_logprobs if all_inference_logprobs else None
|
||||
else:
|
||||
# Wait for data
|
||||
time.sleep(1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue