mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
readme fixes
This commit is contained in:
parent
366ea72384
commit
fae3f5b09e
4 changed files with 63 additions and 19 deletions
|
|
@ -9,7 +9,6 @@ Also extracts inference logprobs for proper GRPO loss computation:
|
|||
- They are batched and padded to align token-by-token with training labels
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
|
@ -180,9 +179,8 @@ def pad_data_to_good_offset(
|
|||
temperature_batches = []
|
||||
inference_logprob_batches = []
|
||||
|
||||
for i in range(len(input_ids) // batch_size):
|
||||
start = i * batch_size
|
||||
end = (i + 1) * batch_size
|
||||
for start in range(0, len(input_ids), batch_size):
|
||||
end = min(start + batch_size, len(input_ids))
|
||||
|
||||
token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
|
||||
label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
|
||||
|
|
@ -294,10 +292,6 @@ def get_data(
|
|||
)
|
||||
_logged_logprob_warning = True
|
||||
|
||||
# Save batch for debugging
|
||||
with open("temp.json", "w", encoding="utf-8") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
# Process and accumulate batches (now includes batched inference logprobs)
|
||||
(
|
||||
token_batches,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue