Re-add multistep masking

This commit is contained in:
Shannon Sands 2025-05-10 09:24:55 +10:00
parent 4d0f919fd1
commit 7fe1a40368
3 changed files with 109 additions and 6 deletions

View file

@ -7,6 +7,106 @@ from atroposlib.type_definitions import Message
# Roles whose *final* message content contributes to the training loss;
# every other token in the chat is masked out.
UNMASKED_ROLES = ["assistant", "agent"]


def tokenize_for_trainer_multistep(
    tokenizer: PreTrainedTokenizer,
    chat: list[Message],
    include_messages: bool = False,
    finish_reason: str = "",
) -> dict:
    """
    Tokenize a chat for training in a multistep RL environment.

    Earlier assistant/agent turns may have been truncated or summarized to
    manage context length, so only the content of the *last* message whose
    role is in UNMASKED_ROLES is unmasked for loss (its labels are the token
    ids); everything else — system prompts, user messages, tool
    calls/responses, and all prior assistant/agent turns — is masked (-100).
    This focuses the learning signal on the final, presumably complete,
    assistant generation instead of possibly noisy summarized turns.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer providing
            ``apply_chat_template``.
        chat (list[Message]): Chat messages; earlier assistant messages may
            have been truncated or summarized.
        include_messages (bool): Whether to echo ``chat`` back under
            "messages". Defaults to False.
        finish_reason (str): If "length", a trailing EOS token is dropped,
            since it can be an artifact of hitting max length rather than a
            deliberate stop. Defaults to "".

    Returns:
        dict: A dictionary containing:
            - "tokens" (list[int]): Token ids for the entire chat.
            - "masks" (list[int]): Labels — token ids inside the last
              unmasked message's content span, -100 elsewhere.
            - "messages" (list[Message], optional): The input chat, only
              when ``include_messages`` is True.
    """
    input_ids = tokenizer.apply_chat_template(chat, tokenize=True)
    if not isinstance(input_ids, list):  # some tokenizers return tensors
        input_ids = input_ids.tolist()

    # Default: everything masked. A plain list is enough here — no need for
    # a torch tensor round-trip just to build a constant-filled sequence.
    labels = [-100] * len(input_ids)

    # Find the last message whose role should be trained on.
    last_unmasked_message_idx = -1
    for i in range(len(chat) - 1, -1, -1):
        if chat[i]["role"] in UNMASKED_ROLES:
            last_unmasked_message_idx = i
            break

    if last_unmasked_message_idx != -1:
        # The target message's content starts where the template for all
        # preceding messages *plus* the generation prompt (the target's
        # role header) ends. `add_generation_prompt=True` supplies that
        # header.
        prefix_ids = tokenizer.apply_chat_template(
            chat[:last_unmasked_message_idx],
            tokenize=True,
            add_generation_prompt=True,
        )
        if not isinstance(prefix_ids, list):
            prefix_ids = prefix_ids.tolist()

        # ...and ends where the template up to and including the target
        # message ends (`add_generation_prompt` left at its default False,
        # so no extra prompt is appended).
        through_ids = tokenizer.apply_chat_template(
            chat[: last_unmasked_message_idx + 1], tokenize=True
        )
        if not isinstance(through_ids, list):
            through_ids = through_ids.tolist()

        start_idx = len(prefix_ids)
        end_idx = len(through_ids)
        # Only unmask when the computed span is sane; otherwise leave the
        # whole sequence masked rather than train on a misaligned span.
        if 0 <= start_idx < end_idx <= len(input_ids):
            labels[start_idx:end_idx] = input_ids[start_idx:end_idx]

    if finish_reason == "length":
        # An EOS emitted at max length is a truncation artifact, not a
        # deliberate stop — drop it from both tokens and labels.
        if input_ids and input_ids[-1] == tokenizer.eos_token_id:
            input_ids = input_ids[:-1]
            labels = labels[:-1]

    return {
        "tokens": input_ids,
        "masks": labels,  # "masks" carries the training labels here
    } | ({"messages": chat} if include_messages else {})
def tokenize_for_trainer(
tokenizer: PreTrainedTokenizer,
chat: list[Message],