diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index f972f7a3..c330514b 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -169,45 +169,45 @@ class ServerManager: ) -> Dict[str, Any]: """ Inject reasoning extra_body into kwargs if reasoning is configured. - + This method handles the differences between OpenAI and other providers: - OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0, and uses max_completion_tokens instead of max_tokens - Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}} - + Args: kwargs: The kwargs dict to modify server_idx: Index of the server to use for base_url detection - + Returns: Modified kwargs dict with extra_body injected if reasoning is active """ if self.reasoning_config is None or not self.reasoning_config.is_active(): return kwargs - + # Get the base_url to determine provider type base_url = self._get_server_base_url(server_idx) is_openai_official = base_url and "api.openai.com" in base_url - + # Build the extra_body for reasoning reasoning_extra_body = self.reasoning_config.build_extra_body(base_url) - + if reasoning_extra_body: # Merge with any existing extra_body in kwargs existing_extra_body = kwargs.get("extra_body", {}) or {} kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body} - + # OpenAI reasoning models have specific requirements if is_openai_official: # OpenAI reasoning models require temperature=1.0 (or unset) # Override any temperature setting kwargs["temperature"] = 1.0 - + # OpenAI reasoning models use max_completion_tokens instead of max_tokens # Convert if max_tokens is set if "max_tokens" in kwargs and kwargs["max_tokens"]: kwargs["max_completion_tokens"] = kwargs.pop("max_tokens") - + return kwargs async def wait_for_sem(self, is_training: bool): @@ -245,7 +245,7 @@ class ServerManager: async def chat_completion(self, **kwargs) -> ChatCompletion: """ Route chat completion to the most available server. - + Automatically injects reasoning extra_body if reasoning_config is active. """ n = kwargs.get("n", 1) @@ -280,16 +280,16 @@ class ServerManager: most_available_server_num_slots = ( server.sem._value if is_train else server.eval_sem._value ) - + # Inject reasoning extra_body if configured kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server) - + return await self.servers[most_available_server].chat_completion(**kwargs) async def completion(self, **kwargs) -> Completion: """ Route completion to the most available server. - + Automatically injects reasoning extra_body if reasoning_config is active. """ n = kwargs.get("n", 1) @@ -322,10 +322,10 @@ class ServerManager: most_available_server_num_slots = ( server.sem._value if is_train else server.eval_sem._value ) - + # Inject reasoning extra_body if configured kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server) - + return await self.servers[most_available_server].completion(**kwargs) async def tokens_and_logprobs_completion( @@ -334,7 +334,7 @@ class ServerManager: """ Get tokens and logprobs from completion. Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons). - + Automatically injects reasoning extra_body if reasoning_config is active. """ n = kwargs.get("n", 1) @@ -373,10 +373,10 @@ class ServerManager: most_available_server_num_slots = ( server.sem._value if is_train else server.eval_sem._value ) - + # Inject reasoning extra_body if configured kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server) - + return await self.servers[most_available_server].tokens_and_logprobs_completion( **kwargs ) diff --git a/atroposlib/tests/test_reasoning_models.py b/atroposlib/tests/test_reasoning_models.py index 4a9ca7cb..e3f0ecf8 100644 --- a/atroposlib/tests/test_reasoning_models.py +++ b/atroposlib/tests/test_reasoning_models.py @@ -923,4 +923,4 @@ if __name__ == "__main__": # Run all tests run_unit_tests() asyncio.run(run_server_manager_integration_test()) - asyncio.run(run_all_integration_tests()) + asyncio.run(run_all_integration_tests()) \ No newline at end of file