mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge branch 'main' into blackjack2-env
This commit is contained in:
commit
36f6822d71
9 changed files with 455 additions and 61 deletions
|
|
@@ -1,4 +1,4 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from atroposlib.type_definitions import Message
|
||||
|
|
@@ -39,7 +39,7 @@ def tokenize_for_trainer(
|
|||
# (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses
|
||||
# midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens.
|
||||
|
||||
masks = torch.ones(len(tokens), dtype=torch.long) * -100
|
||||
masks = np.ones(len(tokens), dtype=np.int64) * -100
|
||||
|
||||
for i, msg in enumerate(chat):
|
||||
if msg["role"] in UNMASKED_ROLES:
|
||||
|
|
@@ -51,7 +51,7 @@ def tokenize_for_trainer(
|
|||
)
|
||||
start_idx = len(prefix_tokens)
|
||||
end_idx = len(unmasked_tokens)
|
||||
masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:])
|
||||
masks[start_idx:end_idx] = np.array(unmasked_tokens[start_idx:])
|
||||
|
||||
masks = masks.tolist()
|
||||
if finish_reason == "length":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue