mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
readd multistep masking
This commit is contained in:
parent
4d0f919fd1
commit
7fe1a40368
3 changed files with 109 additions and 6 deletions
|
|
@ -7,6 +7,106 @@ from atroposlib.type_definitions import Message
|
|||
# Roles whose final message contributes to the training loss; everything else
# (system, user, tools, earlier assistant/agent turns) is masked with -100.
UNMASKED_ROLES = ["assistant", "agent"]


def tokenize_for_trainer_multistep(
    tokenizer: "PreTrainedTokenizer",
    chat: "list[Message]",
    include_messages: bool = False,
    finish_reason: str = "",
) -> dict:
    """
    Tokenize a chat for multistep-RL training, unmasking only the final
    assistant/agent message.

    Earlier assistant turns in a multistep rollout may have been truncated or
    summarized to manage context length, so training on them would inject
    noisy or incomplete targets.  Only the content tokens of the *last*
    message whose role is in UNMASKED_ROLES are used as labels; every other
    token gets the ignore label -100.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer providing
            ``apply_chat_template`` and ``eos_token_id``.
        chat (list[Message]): Full chat history; earlier assistant messages
            may have been truncated or summarized.
        include_messages (bool): If True, echo ``chat`` back under
            "messages". Defaults to False.
        finish_reason (str): If "length", a trailing EOS token is stripped
            from both tokens and labels, since it is an artifact of hitting
            the generation limit rather than a deliberate stop. Defaults to "".

    Returns:
        dict: ``{"tokens": list[int], "masks": list[int]}`` where "masks"
        holds the labels (token ids inside the unmasked span, -100
        elsewhere), plus "messages" when ``include_messages`` is True.
    """
    input_ids = tokenizer.apply_chat_template(chat, tokenize=True)
    if not isinstance(input_ids, list):  # some tokenizers return tensors
        input_ids = input_ids.tolist()

    # Start fully masked; only the final unmasked-role message is revealed.
    labels = [-100] * len(input_ids)

    # Index of the last assistant/agent message, or -1 if none exists.
    target_idx = next(
        (i for i in range(len(chat) - 1, -1, -1) if chat[i]["role"] in UNMASKED_ROLES),
        -1,
    )

    if target_idx != -1:
        # Locate the target message's content span by tokenizing two
        # prefixes: everything before it (plus the generation prompt that
        # opens its turn) gives the start; everything through it gives the
        # end.
        # NOTE(review): if the target is the very first message, chat[:0] is
        # an empty list — some chat templates reject an empty message list;
        # confirm against the tokenizers used in this repo.
        prefix_ids = tokenizer.apply_chat_template(
            chat[:target_idx], tokenize=True, add_generation_prompt=True
        )
        if not isinstance(prefix_ids, list):
            prefix_ids = prefix_ids.tolist()

        # `add_generation_prompt=False` (default) so no extra prompt tokens
        # trail the target message.
        through_ids = tokenizer.apply_chat_template(
            chat[: target_idx + 1], tokenize=True
        )
        if not isinstance(through_ids, list):
            through_ids = through_ids.tolist()

        start_idx = len(prefix_ids)
        end_idx = len(through_ids)

        # Unmask only when the computed span is sane; otherwise everything
        # stays masked (a degenerate but safe training example).
        if 0 <= start_idx < end_idx <= len(input_ids):
            labels[start_idx:end_idx] = input_ids[start_idx:end_idx]

    if finish_reason == "length":
        # An EOS emitted at the length cap is a truncation artifact, not a
        # deliberate stop — drop it from both tokens and labels.
        if input_ids and input_ids[-1] == tokenizer.eos_token_id:
            input_ids = input_ids[:-1]
            labels = labels[:-1]

    return {
        "tokens": input_ids,
        "masks": labels,  # "masks" carries the training labels here
    } | ({"messages": chat} if include_messages else {})
|
||||
|
||||
|
||||
def tokenize_for_trainer(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
chat: list[Message],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue