Add reasoning configuration support across server implementations

- Updated server classes (OpenAIServer, SGLangServer, TrlVllmServer, VLLMServer) to accept a ReasoningConfig parameter during initialization.
- Enhanced ReasoningConfig to allow flexible max_tokens without strict validation, accommodating varying provider limits.
- Implemented reasoning configuration injection in APIServer methods for chat and completion handling.
- Updated tests to reflect changes in max_tokens validation logic.

This commit integrates reasoning capabilities into the server handling architecture, improving compatibility with diverse reasoning models.
This commit is contained in:
teknium 2026-01-05 23:20:01 +00:00
parent 6763649c3a
commit e1ece3e64e
7 changed files with 190 additions and 116 deletions

View file

@ -29,7 +29,9 @@ class ReasoningConfig:
max_tokens are specified.
effort: Reasoning effort level. One of: "none", "minimal", "low", "medium",
"high", "xhigh". Default None (not specified).
max_tokens: Maximum tokens for reasoning (1024-32000). Default None.
max_tokens: Maximum tokens for reasoning. No validation enforced - provider
limits vary (e.g., OpenRouter currently caps Anthropic at 1024-32000,
but native Anthropic supports up to 128k). Default None.
"""
enabled: bool = False
@ -45,13 +47,10 @@ class ReasoningConfig:
f"Must be one of: {VALID_REASONING_EFFORTS}"
)
# Validate max_tokens range if provided
if self.max_tokens is not None:
if self.max_tokens < 1024 or self.max_tokens > 32000:
raise ValueError(
f"max_reasoning_tokens must be between 1024 and 32000, "
f"got {self.max_tokens}"
)
# Note: As of 2024, OpenRouter caps Anthropic reasoning tokens at 1024-32000
# See: https://openrouter.ai/docs/guides/best-practices/reasoning-tokens
# However, we don't enforce this limit here since providers may extend ranges
# (e.g., Anthropic's latest models support up to 128k extended thinking)
# Auto-enable if effort or max_tokens are specified
if self.effort is not None or self.max_tokens is not None:
@ -61,21 +60,38 @@ class ReasoningConfig:
"""Check if reasoning is active (enabled with any settings)."""
return self.enabled
# Mapping from effort levels to approximate max_tokens values
# Based on OpenRouter's effort-to-budget_tokens formula percentages:
# https://openrouter.ai/docs/guides/best-practices/reasoning-tokens
# Calculated as percentage of 32k base: none=min, minimal=10%, low=20%,
# medium=50%, high=80%, xhigh=95%
EFFORT_TO_MAX_TOKENS = {
"none": 1024, # Minimum/disabled
"minimal": 3200, # ~10% of 32k
"low": 6400, # ~20% of 32k
"medium": 16000, # ~50% of 32k
"high": 25600, # ~80% of 32k
"xhigh": 30400, # ~95% of 32k
}
def build_extra_body(
self, base_url: Optional[str] = None
self, base_url: Optional[str] = None, use_max_tokens: bool = False
) -> Optional[Dict[str, Any]]:
"""
Build the extra_body dict for API requests based on provider.
Args:
base_url: The API base URL, used to detect OpenAI official endpoint.
use_max_tokens: If True, convert effort levels to max_tokens values
instead of passing effort strings. Useful for providers
that only support token-based reasoning limits.
Returns:
Dict to merge into extra_body, or None if reasoning not active.
Note:
OpenRouter only allows ONE of effort or max_tokens, not both.
When both are specified, effort takes priority.
When both are specified, effort takes priority (unless use_max_tokens=True).
"""
if not self.is_active():
return None
@ -99,14 +115,20 @@ class ReasoningConfig:
return {"reasoning_effort": openai_effort_map.get(effort, "medium")}
else:
# Standard format for OpenRouter, Nebius, Nous Portal, etc.
# Note: OpenRouter only allows ONE of effort or max_tokens, not both.
# When both are specified, effort takes priority.
reasoning = {"enabled": True}
if self.effort is not None:
# If use_max_tokens is True, convert effort to max_tokens
if use_max_tokens and self.effort is not None:
reasoning["max_tokens"] = self.EFFORT_TO_MAX_TOKENS.get(
self.effort, 8192
)
elif self.effort is not None:
# Pass effort string directly (provider may or may not support it)
reasoning["effort"] = self.effort
elif self.max_tokens is not None:
# Only add max_tokens if effort is not specified
# Use explicit max_tokens if provided
reasoning["max_tokens"] = self.max_tokens
return {"reasoning": reasoning}
@classmethod
@ -263,8 +285,13 @@ class APIServer(ABC):
Abstract class for API servers.
"""
def __init__(self, config: APIServerConfig):
def __init__(
self,
config: APIServerConfig,
reasoning_config: Optional[ReasoningConfig] = None,
):
self.config = config
self.reasoning_config = reasoning_config
self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
self.server_healthy = True
@ -276,6 +303,53 @@ class APIServer(ABC):
self.check_task = None
self.initialized = False
def _inject_reasoning_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
Inject reasoning configuration into kwargs if reasoning is enabled.
This method can be overridden by subclasses to handle implementation-specific
quirks for different server types (vLLM, SGLang, OpenAI, etc.).
The caller can pass `skip_reasoning=True` in kwargs to bypass injection.
Args:
kwargs: The kwargs dict to potentially modify
Returns:
Modified kwargs dict with reasoning config injected (if applicable)
"""
# Check if caller explicitly wants to skip reasoning injection
skip_reasoning = kwargs.pop("skip_reasoning", False)
if skip_reasoning:
return kwargs
# Check if reasoning is configured and active
if self.reasoning_config is None or not self.reasoning_config.is_active():
return kwargs
# Get base_url to determine provider type
base_url = getattr(self.config, "base_url", None)
is_openai_official = base_url and "api.openai.com" in base_url
# Build the extra_body for reasoning
reasoning_extra_body = self.reasoning_config.build_extra_body(base_url)
if reasoning_extra_body:
# Merge with any existing extra_body in kwargs
existing_extra_body = kwargs.get("extra_body", {}) or {}
kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body}
# OpenAI reasoning models have specific requirements
if is_openai_official:
# OpenAI reasoning models require temperature=1.0 (or unset)
kwargs["temperature"] = 1.0
# OpenAI reasoning models use max_completion_tokens instead of max_tokens
if "max_tokens" in kwargs and kwargs["max_tokens"]:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
return kwargs
async def update_weight(self, weight: float) -> None:
"""
Update the weight of the semaphores
@ -397,6 +471,9 @@ class APIServer(ABC):
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""
Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper.
Automatically injects reasoning config if configured. Pass `skip_reasoning=True`
to bypass reasoning injection for this specific call.
"""
if not self.initialized:
if self.config.health_check:
@ -413,6 +490,10 @@ class APIServer(ABC):
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
# Inject reasoning config if enabled (can be skipped via skip_reasoning=True)
kwargs = self._inject_reasoning_kwargs(kwargs)
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
@ -463,6 +544,9 @@ class APIServer(ABC):
async def completion(self, **kwargs) -> Completion:
"""
Completion handler, waits for the server to be healthy and then calls the completion wrapper.
Automatically injects reasoning config if configured. Pass `skip_reasoning=True`
to bypass reasoning injection for this specific call.
"""
if not self.initialized:
if self.config.health_check:
@ -480,6 +564,10 @@ class APIServer(ABC):
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
# Inject reasoning config if enabled (can be skipped via skip_reasoning=True)
kwargs = self._inject_reasoning_kwargs(kwargs)
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":