readme fixes

This commit is contained in:
Jai Suphavadeeprasit 2026-02-17 13:44:48 -05:00
parent 396491ab72
commit 16ac332880
4 changed files with 63 additions and 19 deletions

View file

@ -9,7 +9,6 @@ Also extracts inference logprobs for proper GRPO loss computation:
- They are batched and padded to align token-by-token with training labels
"""
import json
import math
import time
from typing import List, Optional, Tuple
@ -180,9 +179,8 @@ def pad_data_to_good_offset(
temperature_batches = []
inference_logprob_batches = []
for i in range(len(input_ids) // batch_size):
start = i * batch_size
end = (i + 1) * batch_size
for start in range(0, len(input_ids), batch_size):
end = min(start + batch_size, len(input_ids))
token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
@ -294,10 +292,6 @@ def get_data(
)
_logged_logprob_warning = True
# Save batch for debugging
with open("temp.json", "w", encoding="utf-8") as f:
json.dump(data, f)
# Process and accumulate batches (now includes batched inference logprobs)
(
token_batches,