Merge branch 'main' into blackjack2-env

This commit is contained in:
Shannon Sands 2025-05-13 07:54:04 +10:00
commit 36f6822d71
9 changed files with 455 additions and 61 deletions

View file

@ -1,4 +1,4 @@
import torch
import numpy as np
from transformers import PreTrainedTokenizer
from atroposlib.type_definitions import Message
@ -39,7 +39,7 @@ def tokenize_for_trainer(
# (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses
# midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens.
masks = torch.ones(len(tokens), dtype=torch.long) * -100
masks = np.ones(len(tokens), dtype=np.int64) * -100
for i, msg in enumerate(chat):
if msg["role"] in UNMASKED_ROLES:
@ -51,7 +51,7 @@ def tokenize_for_trainer(
)
start_idx = len(prefix_tokens)
end_idx = len(unmasked_tokens)
masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:])
masks[start_idx:end_idx] = np.array(unmasked_tokens[start_idx:])
masks = masks.tolist()
if finish_reason == "length":