[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-10-24 20:10:25 +00:00
parent 7bf4cfbf80
commit 0d80da5146
4 changed files with 18 additions and 10 deletions

View file

@@ -29,9 +29,7 @@ class MockTokenizer:
]
return "".join([chr(t) if t > 31 else "" for t in tokens])
def apply_chat_template(
self, messages, tokenize=False, add_generation_prompt=True
):
def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
"""Simple chat template for testing."""
result = ""
for msg in messages:
@@ -48,6 +46,7 @@ def mock_server():
"""Create a mock server with a tokenizer."""
server = ServerHarness()
server.tokenizer = MockTokenizer()
# Add config for compatibility
class Config:
model_name = "test_model"
@@ -315,7 +314,9 @@ async def test_input_ids_extension(mock_server):
# Turn 1
prompt_1 = "Hello"
prompt_tokens_1 = mock_server.tokenizer.encode(prompt_1) # [1, 72, 101, 108, 108, 111]
prompt_tokens_1 = mock_server.tokenizer.encode(
prompt_1
) # [1, 72, 101, 108, 108, 111]
output_1 = " World"
output_tokens_1 = [ord(c) for c in output_1]
output_logprobs_1 = [-0.1] * len(output_tokens_1)
@@ -334,7 +335,9 @@ async def test_input_ids_extension(mock_server):
prompt_2 = "Hello World!" # Extends "Hello World" with "!"
# The input_ids should be: existing_node_tokens + tokenize("!")
node_1 = managed.current_nodes[0]
expected_input_ids = node_1.tokens + mock_server.tokenizer.encode("!", add_special_tokens=False)
expected_input_ids = node_1.tokens + mock_server.tokenizer.encode(
"!", add_special_tokens=False
)
output_2 = " Yay"
output_tokens_2 = [ord(c) for c in output_2]
@@ -413,7 +416,7 @@ async def test_multi_turn_chat_with_branching(mock_server):
# This prompt extends turn 1's output, so input_ids should use existing tokens
extending_node = state["nodes"][i]
# The new part is just the user turn
new_suffix = prompt_2[len(extending_node.full_text):]
new_suffix = prompt_2[len(extending_node.full_text) :]
expected_input_ids = extending_node.tokens + mock_server.tokenizer.encode(
new_suffix, add_special_tokens=False
)