diff --git a/atroposlib/envs/server_handling/sglang_server.py b/atroposlib/envs/server_handling/sglang_server.py index 1e838374..6078a401 100644 --- a/atroposlib/envs/server_handling/sglang_server.py +++ b/atroposlib/envs/server_handling/sglang_server.py @@ -148,7 +148,8 @@ class SGLangServer(APIServer): kwargs.get("model", None) is not None ), "Model is required for completion!" assert ( - kwargs.get("prompt", None) is not None or kwargs.get("input_ids", None) is not None + kwargs.get("prompt", None) is not None + or kwargs.get("input_ids", None) is not None ), "Prompt or input_ids is required for completion!" # Get n parameter for number of completions diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index b8a14046..f14f6a86 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -29,9 +29,7 @@ class MockTokenizer: ] return "".join([chr(t) if t > 31 else "" for t in tokens]) - def apply_chat_template( - self, messages, tokenize=False, add_generation_prompt=True - ): + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): """Simple chat template for testing.""" result = "" for msg in messages: @@ -48,6 +46,7 @@ def mock_server(): """Create a mock server with a tokenizer.""" server = ServerHarness() server.tokenizer = MockTokenizer() + # Add config for compatibility class Config: model_name = "test_model" @@ -315,7 +314,9 @@ async def test_input_ids_extension(mock_server): # Turn 1 prompt_1 = "Hello" - prompt_tokens_1 = mock_server.tokenizer.encode(prompt_1) # [1, 72, 101, 108, 108, 111] + prompt_tokens_1 = mock_server.tokenizer.encode( + prompt_1 + ) # [1, 72, 101, 108, 108, 111] output_1 = " World" output_tokens_1 = [ord(c) for c in output_1] output_logprobs_1 = [-0.1] * len(output_tokens_1) @@ -334,7 +335,9 @@ async def test_input_ids_extension(mock_server): prompt_2 = "Hello World!" # Extends "Hello World" with "!" # The input_ids should be: existing_node_tokens + tokenize("!") node_1 = managed.current_nodes[0] - expected_input_ids = node_1.tokens + mock_server.tokenizer.encode("!", add_special_tokens=False) + expected_input_ids = node_1.tokens + mock_server.tokenizer.encode( + "!", add_special_tokens=False + ) output_2 = " Yay" output_tokens_2 = [ord(c) for c in output_2] @@ -413,7 +416,7 @@ async def test_multi_turn_chat_with_branching(mock_server): # This prompt extends turn 1's output, so input_ids should use existing tokens extending_node = state["nodes"][i] # The new part is just the user turn - new_suffix = prompt_2[len(extending_node.full_text):] + new_suffix = prompt_2[len(extending_node.full_text) :] expected_input_ids = extending_node.tokens + mock_server.tokenizer.encode( new_suffix, add_special_tokens=False ) diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 93b7fda1..239d60b8 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -382,7 +382,7 @@ class MathEnv(BaseEnv): finish_reason = item[2] # Now a clean string like "stop" or "length" # ManagedServer already provides properly formatted data tokens = item[3] # Full token sequence - masks = item[4] # Masked tokens (already formatted) + masks = item[4] # Masked tokens (already formatted) inf_logp = item[5] # Logprobs (already formatted) if finish_reason == "length": @@ -406,7 +406,9 @@ class MathEnv(BaseEnv): # remove obviously bad examples if len([1 for i in masks if i != -100]) < 10: continue - if (finish_reason == "length") and (not self.config.mask_too_long_completions): + if (finish_reason == "length") and ( + not self.config.mask_too_long_completions + ): scores["overrides"][-1]["set_advantage_to_zero"] = True scores["tokens"].append(tokens) scores["masks"].append(masks)