mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
7bf4cfbf80
commit
0d80da5146
4 changed files with 18 additions and 10 deletions
|
|
@ -29,9 +29,7 @@ class MockTokenizer:
|
|||
]
|
||||
return "".join([chr(t) if t > 31 else "" for t in tokens])
|
||||
|
||||
def apply_chat_template(
|
||||
self, messages, tokenize=False, add_generation_prompt=True
|
||||
):
|
||||
def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
|
||||
"""Simple chat template for testing."""
|
||||
result = ""
|
||||
for msg in messages:
|
||||
|
|
@ -48,6 +46,7 @@ def mock_server():
|
|||
"""Create a mock server with a tokenizer."""
|
||||
server = ServerHarness()
|
||||
server.tokenizer = MockTokenizer()
|
||||
|
||||
# Add config for compatibility
|
||||
class Config:
|
||||
model_name = "test_model"
|
||||
|
|
@ -315,7 +314,9 @@ async def test_input_ids_extension(mock_server):
|
|||
|
||||
# Turn 1
|
||||
prompt_1 = "Hello"
|
||||
prompt_tokens_1 = mock_server.tokenizer.encode(prompt_1) # [1, 72, 101, 108, 108, 111]
|
||||
prompt_tokens_1 = mock_server.tokenizer.encode(
|
||||
prompt_1
|
||||
) # [1, 72, 101, 108, 108, 111]
|
||||
output_1 = " World"
|
||||
output_tokens_1 = [ord(c) for c in output_1]
|
||||
output_logprobs_1 = [-0.1] * len(output_tokens_1)
|
||||
|
|
@ -334,7 +335,9 @@ async def test_input_ids_extension(mock_server):
|
|||
prompt_2 = "Hello World!" # Extends "Hello World" with "!"
|
||||
# The input_ids should be: existing_node_tokens + tokenize("!")
|
||||
node_1 = managed.current_nodes[0]
|
||||
expected_input_ids = node_1.tokens + mock_server.tokenizer.encode("!", add_special_tokens=False)
|
||||
expected_input_ids = node_1.tokens + mock_server.tokenizer.encode(
|
||||
"!", add_special_tokens=False
|
||||
)
|
||||
|
||||
output_2 = " Yay"
|
||||
output_tokens_2 = [ord(c) for c in output_2]
|
||||
|
|
@ -413,7 +416,7 @@ async def test_multi_turn_chat_with_branching(mock_server):
|
|||
# This prompt extends turn 1's output, so input_ids should use existing tokens
|
||||
extending_node = state["nodes"][i]
|
||||
# The new part is just the user turn
|
||||
new_suffix = prompt_2[len(extending_node.full_text):]
|
||||
new_suffix = prompt_2[len(extending_node.full_text) :]
|
||||
expected_input_ids = extending_node.tokens + mock_server.tokenizer.encode(
|
||||
new_suffix, add_special_tokens=False
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue