mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
first commit
This commit is contained in:
commit
621d00dd80
89 changed files with 15315 additions and 0 deletions
192
atroposlib/utils/tokenize_for_trainer.py
Normal file
192
atroposlib/utils/tokenize_for_trainer.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from atroposlib.type_definitions import Message
|
||||
|
||||
# Roles that should be masked in the loss calculation (not used for training)
|
||||
UNMASKED_ROLES = ["assistant"]
|
||||
|
||||
|
||||
def tokenize_for_trainer(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
chat: list[Message],
|
||||
include_messages: bool = False,
|
||||
train_on_all_assistant_turns: bool = False,
|
||||
finish_reason: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Tokenize a list of chat messages for the trainer.
|
||||
|
||||
Args:
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer to use.
|
||||
chat (list): A list of chat messages.
|
||||
include_messages (bool): Whether to include the messages in the output.
|
||||
train_on_all_assistant_turns (bool): If True, mask out system/user/tool roles.
|
||||
If False, use the original prefix masking.
|
||||
Returns:
|
||||
dict: A dictionary containing the tokenized chat messages.
|
||||
"""
|
||||
|
||||
tokens = tokenizer.apply_chat_template(chat)
|
||||
|
||||
if not train_on_all_assistant_turns:
|
||||
prefix_len = len(
|
||||
tokenizer.apply_chat_template(chat[:-1], add_generation_prompt=True)
|
||||
)
|
||||
masks = [-100] * prefix_len + tokens[prefix_len:]
|
||||
else:
|
||||
# NOTE: This implementation will break if the default system prompt is used and depends on world state
|
||||
# (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses
|
||||
# midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens.
|
||||
|
||||
masks = torch.ones(len(tokens), dtype=torch.long) * -100
|
||||
|
||||
for i, msg in enumerate(chat):
|
||||
if msg["role"] in UNMASKED_ROLES:
|
||||
prefix_tokens = tokenizer.apply_chat_template(
|
||||
chat[:i], tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
unmasked_tokens = tokenizer.apply_chat_template(
|
||||
chat[: i + 1], tokenize=True
|
||||
)
|
||||
start_idx = len(prefix_tokens)
|
||||
end_idx = len(unmasked_tokens)
|
||||
masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:])
|
||||
|
||||
masks = masks.tolist()
|
||||
if finish_reason == "length":
|
||||
if tokens[-1] == tokenizer.eos_token_id:
|
||||
print("bad token\n")
|
||||
# truncate the last token
|
||||
tokens = tokens[:-1]
|
||||
masks = masks[:-1]
|
||||
|
||||
return {
|
||||
"tokens": tokens,
|
||||
"masks": masks,
|
||||
} | ({"messages": chat} if include_messages else {})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Inspired by `preprocess --debug`` of https://github.com/axolotl-ai-cloud/axolotl
|
||||
def decode_token_ids(
|
||||
token_ids: list, mask, tokenizer, use_rich: bool = False
|
||||
) -> str:
|
||||
"""Convert a list of token IDs to a formatted string using tokenizer.decode,
|
||||
with an option to highlight masked tokens in red using rich markup.
|
||||
|
||||
Each token is represented as decoded(tokenid, mask). If decoding a token returns an empty string
|
||||
and the token is a known special token, it is replaced with a descriptive placeholder.
|
||||
When use_rich is True, any token whose corresponding mask is -100 is wrapped with red highlighting.
|
||||
|
||||
Args:
|
||||
token_ids (list[int]): A list of integer token IDs,
|
||||
e.g. [50256, 329].
|
||||
mask (list[int]): A list of masks corresponding to token_ids.
|
||||
A mask value of -100 indicates the token is masked.
|
||||
tokenizer: The Hugging Face tokenizer.
|
||||
use_rich (bool): If True, wrap tokens with a mask of -100 in red highlighting.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: A space-separated string where each token is represented as decoded(tokenid, mask).
|
||||
|
||||
Raises:
|
||||
ValueError: If any element in token_ids is not an integer.
|
||||
|
||||
Example:
|
||||
>>> decode_token_ids([50256, 329], mask=[-100, 329], tokenizer=tokenizer, use_rich=True)
|
||||
'[red]<|eos|>(50256, -100)[/red] '
|
||||
'tokenX(329, 329)' # (actual output will vary based on the model's tokenizer)
|
||||
"""
|
||||
# Validate that all token_ids are integers.
|
||||
if not all(isinstance(t, int) for t in token_ids):
|
||||
raise ValueError("All token IDs must be integers.")
|
||||
|
||||
tokens_str_list = []
|
||||
for tid, mid in zip(token_ids, mask):
|
||||
# Use decode with flags to include special tokens.
|
||||
decoded = tokenizer.decode(
|
||||
[tid], skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
).strip()
|
||||
# If the decoded string is empty and it's a special token, replace with a placeholder.
|
||||
if not decoded and tid in tokenizer.all_special_ids:
|
||||
if tid == tokenizer.eos_token_id:
|
||||
decoded = "<|eos|>"
|
||||
else:
|
||||
decoded = f"<SPECIAL_{tid}>"
|
||||
|
||||
# Highlight token in red if use_rich is True and the token is masked (mid == -100)
|
||||
if use_rich:
|
||||
if mid == -100:
|
||||
token_str = f"[pink3][bold]{decoded}[/bold][/pink3][steel_blue]({tid}, {mid})[/steel_blue]"
|
||||
else:
|
||||
token_str = (
|
||||
f"[pale_green3][bold]{decoded}[/bold][/pale_green3]"
|
||||
f"[steel_blue]({tid}, {mid})[/steel_blue]"
|
||||
)
|
||||
else:
|
||||
token_str = f"{decoded}({tid}, {mid})"
|
||||
|
||||
tokens_str_list.append(token_str)
|
||||
|
||||
return " ".join(tokens_str_list)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant that provides accurate information.",
|
||||
},
|
||||
{"role": "user", "content": "What's the capital of France?"},
|
||||
{"role": "assistant", "content": "The capital of France is Paris."},
|
||||
{"role": "user", "content": "Can you tell me more about Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<tool_call>{'tool_name': 'web_search', 'args': {'query': 'Paris'}}</tool_call>",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": (
|
||||
"Paris is the capital and most populous city of France. "
|
||||
"It has an estimated population of 2,165,423 residents in 2019 "
|
||||
"in an area of more than 105 km²."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": (
|
||||
"Paris is indeed the capital of France and its most populous city with over 2 million residents. "
|
||||
"It's known for its iconic landmarks like the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral. "
|
||||
"The city is a global center for art, fashion, gastronomy, and culture."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
last_turn_only = tokenize_for_trainer(
|
||||
tokenizer, messages, train_on_all_assistant_turns=False
|
||||
)
|
||||
last_turn_only["repr"] = decode_token_ids(
|
||||
last_turn_only["tokens"], last_turn_only["masks"], tokenizer, use_rich=True
|
||||
)
|
||||
all_assistant_turns = tokenize_for_trainer(
|
||||
tokenizer, messages, train_on_all_assistant_turns=True
|
||||
)
|
||||
all_assistant_turns["repr"] = decode_token_ids(
|
||||
all_assistant_turns["tokens"],
|
||||
all_assistant_turns["masks"],
|
||||
tokenizer,
|
||||
use_rich=True,
|
||||
)
|
||||
|
||||
from rich import print
|
||||
|
||||
print("[bold cyan]last turn only[/]")
|
||||
print(last_turn_only["repr"])
|
||||
print()
|
||||
print("[bold cyan]all assistant turns[/]")
|
||||
print(all_assistant_turns["repr"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue