mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
remove training code
This commit is contained in:
parent
862cd3667d
commit
148a4fd5eb
6 changed files with 38 additions and 329 deletions
|
|
@ -29,8 +29,6 @@ def pad_data_to_good_offset(
|
|||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels)
|
||||
Optional[List[torch.Tensor]], # distill_token_id_batches [batch, seq, k]
|
||||
Optional[List[torch.Tensor]], # distill_logprob_batches [batch, seq, k]
|
||||
]:
|
||||
"""
|
||||
Pad and batch data from the Atropos API.
|
||||
|
|
@ -47,8 +45,7 @@ def pad_data_to_good_offset(
|
|||
extract_inference_logprobs: Whether to extract inference logprobs
|
||||
|
||||
Returns:
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches,
|
||||
inference_logprob_batches, distill_token_id_batches, distill_logprob_batches)
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
|
||||
inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
|
||||
|
||||
Note:
|
||||
|
|
@ -76,10 +73,6 @@ def pad_data_to_good_offset(
|
|||
temperatures = []
|
||||
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
|
||||
has_any_logprobs = False
|
||||
distill_token_ids_padded: List[np.ndarray] = []
|
||||
distill_logprobs_padded: List[np.ndarray] = []
|
||||
has_any_distill = False
|
||||
max_distill_k = 1
|
||||
|
||||
for item in data["batch"]:
|
||||
# Normalize advantage scores
|
||||
|
|
@ -160,85 +153,6 @@ def pad_data_to_good_offset(
|
|||
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
|
||||
)
|
||||
|
||||
# Extract teacher distillation top-k arrays if available.
|
||||
# Expected shape in incoming payload: [sequence][position][k].
|
||||
if "distill_token_ids" in item and "distill_logprobs" in item:
|
||||
seq_token_ids = item["distill_token_ids"]
|
||||
seq_logprobs = item["distill_logprobs"]
|
||||
if (
|
||||
isinstance(seq_token_ids, list)
|
||||
and isinstance(seq_logprobs, list)
|
||||
and i < len(seq_token_ids)
|
||||
and i < len(seq_logprobs)
|
||||
and seq_token_ids[i] is not None
|
||||
and seq_logprobs[i] is not None
|
||||
):
|
||||
per_pos_token_ids = seq_token_ids[i]
|
||||
per_pos_logprobs = seq_logprobs[i]
|
||||
if (
|
||||
isinstance(per_pos_token_ids, list)
|
||||
and isinstance(per_pos_logprobs, list)
|
||||
and len(per_pos_token_ids) == len(per_pos_logprobs)
|
||||
):
|
||||
local_k = 1
|
||||
for row_ids in per_pos_token_ids:
|
||||
if isinstance(row_ids, list):
|
||||
local_k = max(local_k, len(row_ids))
|
||||
max_distill_k = max(max_distill_k, local_k)
|
||||
has_any_distill = True
|
||||
|
||||
rows = max(0, token_setup_len - 1)
|
||||
token_mat = np.full((rows, local_k), -1, dtype=np.int64)
|
||||
logprob_mat = np.full((rows, local_k), -1e9, dtype=np.float32)
|
||||
|
||||
# Shift by one to align with causal labels like inference_logprobs.
|
||||
copy_positions = min(
|
||||
len(per_pos_token_ids),
|
||||
len(per_pos_logprobs),
|
||||
token_setup_len,
|
||||
)
|
||||
for pos in range(1, copy_positions):
|
||||
src_ids = per_pos_token_ids[pos]
|
||||
src_lps = per_pos_logprobs[pos]
|
||||
if not isinstance(src_ids, list) or not isinstance(
|
||||
src_lps, list
|
||||
):
|
||||
continue
|
||||
topk = min(local_k, len(src_ids), len(src_lps))
|
||||
if topk <= 0:
|
||||
continue
|
||||
token_mat[pos - 1, :topk] = np.array(
|
||||
src_ids[:topk], dtype=np.int64
|
||||
)
|
||||
logprob_mat[pos - 1, :topk] = np.array(
|
||||
src_lps[:topk], dtype=np.float32
|
||||
)
|
||||
|
||||
distill_token_ids_padded.append(token_mat)
|
||||
distill_logprobs_padded.append(logprob_mat)
|
||||
else:
|
||||
rows = max(0, token_setup_len - 1)
|
||||
distill_token_ids_padded.append(
|
||||
np.full((rows, 1), -1, dtype=np.int64)
|
||||
)
|
||||
distill_logprobs_padded.append(
|
||||
np.full((rows, 1), -1e9, dtype=np.float32)
|
||||
)
|
||||
else:
|
||||
rows = max(0, token_setup_len - 1)
|
||||
distill_token_ids_padded.append(
|
||||
np.full((rows, 1), -1, dtype=np.int64)
|
||||
)
|
||||
distill_logprobs_padded.append(
|
||||
np.full((rows, 1), -1e9, dtype=np.float32)
|
||||
)
|
||||
else:
|
||||
rows = max(0, token_setup_len - 1)
|
||||
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
|
||||
distill_logprobs_padded.append(
|
||||
np.full((rows, 1), -1e9, dtype=np.float32)
|
||||
)
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
if (
|
||||
|
|
@ -264,8 +178,6 @@ def pad_data_to_good_offset(
|
|||
advantage_batches = []
|
||||
temperature_batches = []
|
||||
inference_logprob_batches = []
|
||||
distill_token_id_batches = []
|
||||
distill_logprob_batches = []
|
||||
|
||||
for start in range(0, len(input_ids), batch_size):
|
||||
end = min(start + batch_size, len(input_ids))
|
||||
|
|
@ -287,46 +199,12 @@ def pad_data_to_good_offset(
|
|||
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
|
||||
)
|
||||
|
||||
if distill_token_ids_padded and distill_logprobs_padded:
|
||||
seq_slice_ids = distill_token_ids_padded[start:end]
|
||||
seq_slice_lps = distill_logprobs_padded[start:end]
|
||||
normalized_ids = []
|
||||
normalized_lps = []
|
||||
for ids_mat, lps_mat in zip(seq_slice_ids, seq_slice_lps):
|
||||
if ids_mat.shape[1] < max_distill_k:
|
||||
pad_cols = max_distill_k - ids_mat.shape[1]
|
||||
ids_mat = np.pad(
|
||||
ids_mat, ((0, 0), (0, pad_cols)), constant_values=-1
|
||||
)
|
||||
lps_mat = np.pad(
|
||||
lps_mat, ((0, 0), (0, pad_cols)), constant_values=-1e9
|
||||
)
|
||||
normalized_ids.append(ids_mat)
|
||||
normalized_lps.append(lps_mat)
|
||||
|
||||
distill_token_id_batches.append(
|
||||
torch.tensor(np.stack(normalized_ids, axis=0), dtype=torch.long)
|
||||
)
|
||||
distill_logprob_batches.append(
|
||||
torch.tensor(np.stack(normalized_lps, axis=0), dtype=torch.float32)
|
||||
)
|
||||
|
||||
# Return inference logprob batches if we have any real logprobs
|
||||
final_logprob_batches = (
|
||||
inference_logprob_batches
|
||||
if (has_any_logprobs and inference_logprob_batches)
|
||||
else None
|
||||
)
|
||||
final_distill_token_id_batches = (
|
||||
distill_token_id_batches
|
||||
if (has_any_distill and distill_token_id_batches)
|
||||
else None
|
||||
)
|
||||
final_distill_logprob_batches = (
|
||||
distill_logprob_batches
|
||||
if (has_any_distill and distill_logprob_batches)
|
||||
else None
|
||||
)
|
||||
|
||||
return (
|
||||
token_batches,
|
||||
|
|
@ -334,8 +212,6 @@ def pad_data_to_good_offset(
|
|||
advantage_batches,
|
||||
temperature_batches,
|
||||
final_logprob_batches,
|
||||
final_distill_token_id_batches,
|
||||
final_distill_logprob_batches,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -352,8 +228,6 @@ def get_data(
|
|||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches
|
||||
Optional[List[torch.Tensor]], # distill_token_id_batches
|
||||
Optional[List[torch.Tensor]], # distill_logprob_batches
|
||||
]
|
||||
],
|
||||
None, # Legacy return (no longer used)
|
||||
|
|
@ -377,11 +251,47 @@ def get_data(
|
|||
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
|
||||
"""
|
||||
batches = []
|
||||
_logged_logprob_warning = False
|
||||
|
||||
while True:
|
||||
data = get_batch(url=atropos_url)
|
||||
|
||||
if data["batch"] is not None:
|
||||
# DEBUG: Check if inference_logprobs exists in the data
|
||||
if not _logged_logprob_warning:
|
||||
has_logprobs = any(
|
||||
"inference_logprobs" in item for item in data["batch"]
|
||||
)
|
||||
if has_logprobs:
|
||||
# Check if they're non-empty
|
||||
sample_item = next(
|
||||
(
|
||||
item
|
||||
for item in data["batch"]
|
||||
if "inference_logprobs" in item
|
||||
),
|
||||
None,
|
||||
)
|
||||
if sample_item and sample_item.get("inference_logprobs"):
|
||||
sample_lp = (
|
||||
sample_item["inference_logprobs"][0]
|
||||
if sample_item["inference_logprobs"]
|
||||
else []
|
||||
)
|
||||
print(
|
||||
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
" [Data] ⚠ inference_logprobs key exists but is empty!"
|
||||
)
|
||||
else:
|
||||
print(" [Data] ⚠ NO inference_logprobs in batch data!")
|
||||
print(
|
||||
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
|
||||
)
|
||||
_logged_logprob_warning = True
|
||||
|
||||
# Process and accumulate batches (now includes batched inference logprobs)
|
||||
(
|
||||
token_batches,
|
||||
|
|
@ -389,8 +299,6 @@ def get_data(
|
|||
adv_batches,
|
||||
temp_batches,
|
||||
inf_logprob_batches,
|
||||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
|
||||
# Include inference logprob batches in the tuple
|
||||
|
|
@ -401,8 +309,6 @@ def get_data(
|
|||
adv_batches,
|
||||
temp_batches,
|
||||
inf_logprob_batches,
|
||||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue