Add reasoning configuration support across server implementations

- Updated server classes (OpenAIServer, SGLangServer, TrlVllmServer, VLLMServer) to accept a ReasoningConfig parameter during initialization.
- Enhanced ReasoningConfig to allow flexible max_tokens without strict validation, accommodating varying provider limits.
- Implemented reasoning configuration injection in APIServer methods for chat and completion handling.
- Updated tests to reflect changes in max_tokens validation logic.

This commit integrates reasoning capabilities into the server handling architecture, improving compatibility with diverse reasoning models.
This commit is contained in:
teknium 2026-01-05 23:20:01 +00:00
parent 6763649c3a
commit e1ece3e64e
7 changed files with 190 additions and 116 deletions

View file

@ -2,7 +2,7 @@ import asyncio
import inspect
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import AsyncGenerator, List, Optional, Union
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
@ -119,9 +119,15 @@ class ServerManager:
api_key="x",
)
)
self.servers = [server_class(config) for config in openai_configs]
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in openai_configs
]
elif not slurm:
self.servers = [server_class(config) for config in configs]
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in configs
]
else:
nodelist = (
os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
@ -134,7 +140,10 @@ class ServerManager:
"Not enough nodes to distribute to, assuming single node"
" and you've setup your sglang appropriately."
)
self.servers = [server_class(config) for config in configs]
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in configs
]
return
urls = []
num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
@ -149,67 +158,15 @@ class ServerManager:
new_conf = configs[0].model_copy(deep=True)
new_conf.base_url = urls[i]
new_configs.append(new_conf)
self.servers = [server_class(config) for config in new_configs]
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in new_configs
]
async def update_weight(self, weight: float):
    """Forward ``weight`` to every managed server, one server at a time.

    Each server's ``update_weight`` coroutine is awaited to completion
    before the next server is contacted, so updates are applied in the
    order the servers appear in ``self.servers``.

    Args:
        weight: Value passed unchanged to each server's ``update_weight``.
    """
    for managed_server in self.servers:
        await managed_server.update_weight(weight)
def _get_server_base_url(self, server_idx: int = 0) -> Optional[str]:
    """Return the ``base_url`` configured on the server at ``server_idx``.

    Args:
        server_idx: Position of the server in ``self.servers``.

    Returns:
        The server's ``config.base_url``, or ``None`` when there are no
        servers or the selected server has no ``config`` / ``base_url``
        attribute.
    """
    if not self.servers:
        return None

    selected = self.servers[server_idx]
    try:
        # EAFP: a missing ``config`` or ``base_url`` attribute maps to None.
        return selected.config.base_url
    except AttributeError:
        return None
def _inject_reasoning_extra_body(
self, kwargs: Dict[str, Any], server_idx: int = 0
) -> Dict[str, Any]:
"""
Inject reasoning extra_body into kwargs if reasoning is configured.
When ``self.reasoning_config`` is ``None`` or reports itself inactive via
``is_active()``, ``kwargs`` is returned untouched. Otherwise ``kwargs`` is
modified in place and the same dict object is returned.
This method handles the differences between OpenAI and other providers:
- OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0,
and uses max_completion_tokens instead of max_tokens
- Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}}
The provider-specific payload itself is produced by
``reasoning_config.build_extra_body(base_url)``; this method only merges it
and applies the OpenAI-specific kwarg adjustments.
Args:
kwargs: The kwargs dict to modify
server_idx: Index of the server to use for base_url detection
Returns:
Modified kwargs dict with extra_body injected if reasoning is active
"""
# Nothing to inject when reasoning is not configured or is switched off.
if self.reasoning_config is None or not self.reasoning_config.is_active():
return kwargs
# Get the base_url to determine provider type
base_url = self._get_server_base_url(server_idx)
# Substring match on the host; falsy (None / "") when no base_url is known.
is_openai_official = base_url and "api.openai.com" in base_url
# Build the extra_body for reasoning
reasoning_extra_body = self.reasoning_config.build_extra_body(base_url)
if reasoning_extra_body:
# Merge with any existing extra_body in kwargs
# (a missing or None extra_body is treated as empty; on key collisions
# the reasoning values win because they are unpacked last)
existing_extra_body = kwargs.get("extra_body", {}) or {}
kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body}
# OpenAI reasoning models have specific requirements
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether this branch is nested under `if reasoning_extra_body:` — confirm
# against the real file.
if is_openai_official:
# OpenAI reasoning models require temperature=1.0 (or unset)
# Override any temperature setting
kwargs["temperature"] = 1.0
# OpenAI reasoning models use max_completion_tokens instead of max_tokens
# Convert if max_tokens is set
# (a falsy max_tokens such as None or 0 is left in kwargs as-is)
if "max_tokens" in kwargs and kwargs["max_tokens"]:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
return kwargs
async def wait_for_sem(self, is_training: bool):
"""
Wait for a server to be available. This is used to prevent the client from
@ -246,7 +203,8 @@ class ServerManager:
"""
Route chat completion to the most available server.
Automatically injects reasoning extra_body if reasoning_config is active.
Reasoning config injection is handled by the individual servers.
Pass `skip_reasoning=True` to bypass reasoning injection for this call.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
@ -281,16 +239,14 @@ class ServerManager:
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].chat_completion(**kwargs)
async def completion(self, **kwargs) -> Completion:
"""
Route completion to the most available server.
Automatically injects reasoning extra_body if reasoning_config is active.
Reasoning config injection is handled by the individual servers.
Pass `skip_reasoning=True` to bypass reasoning injection for this call.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
@ -323,9 +279,6 @@ class ServerManager:
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].completion(**kwargs)
async def tokens_and_logprobs_completion(
@ -335,7 +288,8 @@ class ServerManager:
Get tokens and logprobs from completion.
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
Automatically injects reasoning extra_body if reasoning_config is active.
Note: Reasoning config is NOT injected here - this method is for extracting
raw token-level data for training, not for generating reasoned responses.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
@ -374,9 +328,6 @@ class ServerManager:
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].tokens_and_logprobs_completion(
**kwargs
)