mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
Add reasoning configuration support across server implementations
- Updated server classes (OpenAIServer, SGLangServer, TrlVllmServer, VLLMServer) to accept a ReasoningConfig parameter during initialization. - Enhanced ReasoningConfig to allow flexible max_tokens without strict validation, accommodating varying provider limits. - Implemented reasoning configuration injection in APIServer methods for chat and completion handling. - Updated tests to reflect changes in max_tokens validation logic. This commit integrates reasoning capabilities into the server handling architecture, improving compatibility with diverse reasoning models.
This commit is contained in:
parent
6763649c3a
commit
e1ece3e64e
7 changed files with 190 additions and 116 deletions
|
|
@ -8,7 +8,11 @@ from openai.types.completion import Completion
|
|||
from pydantic_cli import FailedExecutionException
|
||||
|
||||
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIServer(APIServer):
|
||||
|
|
@ -16,13 +20,17 @@ class OpenAIServer(APIServer):
|
|||
OpenAI server handling.
|
||||
"""
|
||||
|
||||
def __init__(self, config: APIServerConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: APIServerConfig,
|
||||
reasoning_config: ReasoningConfig = None,
|
||||
):
|
||||
self.openai = openai.AsyncClient(
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
super().__init__(config)
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ class ReasoningConfig:
|
|||
max_tokens are specified.
|
||||
effort: Reasoning effort level. One of: "none", "minimal", "low", "medium",
|
||||
"high", "xhigh". Default None (not specified).
|
||||
max_tokens: Maximum tokens for reasoning (1024-32000). Default None.
|
||||
max_tokens: Maximum tokens for reasoning. No validation enforced - provider
|
||||
limits vary (e.g., OpenRouter currently caps Anthropic at 1024-32000,
|
||||
but native Anthropic supports up to 128k). Default None.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
|
|
@ -45,13 +47,10 @@ class ReasoningConfig:
|
|||
f"Must be one of: {VALID_REASONING_EFFORTS}"
|
||||
)
|
||||
|
||||
# Validate max_tokens range if provided
|
||||
if self.max_tokens is not None:
|
||||
if self.max_tokens < 1024 or self.max_tokens > 32000:
|
||||
raise ValueError(
|
||||
f"max_reasoning_tokens must be between 1024 and 32000, "
|
||||
f"got {self.max_tokens}"
|
||||
)
|
||||
# Note: As of 2024, OpenRouter caps Anthropic reasoning tokens at 1024-32000
|
||||
# See: https://openrouter.ai/docs/guides/best-practices/reasoning-tokens
|
||||
# However, we don't enforce this limit here since providers may extend ranges
|
||||
# (e.g., Anthropic's latest models support up to 128k extended thinking)
|
||||
|
||||
# Auto-enable if effort or max_tokens are specified
|
||||
if self.effort is not None or self.max_tokens is not None:
|
||||
|
|
@ -61,21 +60,38 @@ class ReasoningConfig:
|
|||
"""Check if reasoning is active (enabled with any settings)."""
|
||||
return self.enabled
|
||||
|
||||
# Mapping from effort levels to approximate max_tokens values
|
||||
# Based on OpenRouter's effort-to-budget_tokens formula percentages:
|
||||
# https://openrouter.ai/docs/guides/best-practices/reasoning-tokens
|
||||
# Calculated as percentage of 32k base: none=min, minimal=10%, low=20%,
|
||||
# medium=50%, high=80%, xhigh=95%
|
||||
EFFORT_TO_MAX_TOKENS = {
|
||||
"none": 1024, # Minimum/disabled
|
||||
"minimal": 3200, # ~10% of 32k
|
||||
"low": 6400, # ~20% of 32k
|
||||
"medium": 16000, # ~50% of 32k
|
||||
"high": 25600, # ~80% of 32k
|
||||
"xhigh": 30400, # ~95% of 32k
|
||||
}
|
||||
|
||||
def build_extra_body(
|
||||
self, base_url: Optional[str] = None
|
||||
self, base_url: Optional[str] = None, use_max_tokens: bool = False
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Build the extra_body dict for API requests based on provider.
|
||||
|
||||
Args:
|
||||
base_url: The API base URL, used to detect OpenAI official endpoint.
|
||||
use_max_tokens: If True, convert effort levels to max_tokens values
|
||||
instead of passing effort strings. Useful for providers
|
||||
that only support token-based reasoning limits.
|
||||
|
||||
Returns:
|
||||
Dict to merge into extra_body, or None if reasoning not active.
|
||||
|
||||
Note:
|
||||
OpenRouter only allows ONE of effort or max_tokens, not both.
|
||||
When both are specified, effort takes priority.
|
||||
When both are specified, effort takes priority (unless use_max_tokens=True).
|
||||
"""
|
||||
if not self.is_active():
|
||||
return None
|
||||
|
|
@ -99,14 +115,20 @@ class ReasoningConfig:
|
|||
return {"reasoning_effort": openai_effort_map.get(effort, "medium")}
|
||||
else:
|
||||
# Standard format for OpenRouter, Nebius, Nous Portal, etc.
|
||||
# Note: OpenRouter only allows ONE of effort or max_tokens, not both.
|
||||
# When both are specified, effort takes priority.
|
||||
reasoning = {"enabled": True}
|
||||
if self.effort is not None:
|
||||
|
||||
# If use_max_tokens is True, convert effort to max_tokens
|
||||
if use_max_tokens and self.effort is not None:
|
||||
reasoning["max_tokens"] = self.EFFORT_TO_MAX_TOKENS.get(
|
||||
self.effort, 8192
|
||||
)
|
||||
elif self.effort is not None:
|
||||
# Pass effort string directly (provider may or may not support it)
|
||||
reasoning["effort"] = self.effort
|
||||
elif self.max_tokens is not None:
|
||||
# Only add max_tokens if effort is not specified
|
||||
# Use explicit max_tokens if provided
|
||||
reasoning["max_tokens"] = self.max_tokens
|
||||
|
||||
return {"reasoning": reasoning}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -263,8 +285,13 @@ class APIServer(ABC):
|
|||
Abstract class for API servers.
|
||||
"""
|
||||
|
||||
def __init__(self, config: APIServerConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: APIServerConfig,
|
||||
reasoning_config: Optional[ReasoningConfig] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.reasoning_config = reasoning_config
|
||||
self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
|
||||
self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
|
||||
self.server_healthy = True
|
||||
|
|
@ -276,6 +303,53 @@ class APIServer(ABC):
|
|||
self.check_task = None
|
||||
self.initialized = False
|
||||
|
||||
def _inject_reasoning_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Inject reasoning configuration into kwargs if reasoning is enabled.
|
||||
|
||||
This method can be overridden by subclasses to handle implementation-specific
|
||||
quirks for different server types (vLLM, SGLang, OpenAI, etc.).
|
||||
|
||||
The caller can pass `skip_reasoning=True` in kwargs to bypass injection.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict to potentially modify
|
||||
|
||||
Returns:
|
||||
Modified kwargs dict with reasoning config injected (if applicable)
|
||||
"""
|
||||
# Check if caller explicitly wants to skip reasoning injection
|
||||
skip_reasoning = kwargs.pop("skip_reasoning", False)
|
||||
if skip_reasoning:
|
||||
return kwargs
|
||||
|
||||
# Check if reasoning is configured and active
|
||||
if self.reasoning_config is None or not self.reasoning_config.is_active():
|
||||
return kwargs
|
||||
|
||||
# Get base_url to determine provider type
|
||||
base_url = getattr(self.config, "base_url", None)
|
||||
is_openai_official = base_url and "api.openai.com" in base_url
|
||||
|
||||
# Build the extra_body for reasoning
|
||||
reasoning_extra_body = self.reasoning_config.build_extra_body(base_url)
|
||||
|
||||
if reasoning_extra_body:
|
||||
# Merge with any existing extra_body in kwargs
|
||||
existing_extra_body = kwargs.get("extra_body", {}) or {}
|
||||
kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body}
|
||||
|
||||
# OpenAI reasoning models have specific requirements
|
||||
if is_openai_official:
|
||||
# OpenAI reasoning models require temperature=1.0 (or unset)
|
||||
kwargs["temperature"] = 1.0
|
||||
|
||||
# OpenAI reasoning models use max_completion_tokens instead of max_tokens
|
||||
if "max_tokens" in kwargs and kwargs["max_tokens"]:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
return kwargs
|
||||
|
||||
async def update_weight(self, weight: float) -> None:
|
||||
"""
|
||||
Update the weight of the semaphores
|
||||
|
|
@ -397,6 +471,9 @@ class APIServer(ABC):
|
|||
async def chat_completion(self, **kwargs) -> ChatCompletion:
|
||||
"""
|
||||
Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper.
|
||||
|
||||
Automatically injects reasoning config if configured. Pass `skip_reasoning=True`
|
||||
to bypass reasoning injection for this specific call.
|
||||
"""
|
||||
if not self.initialized:
|
||||
if self.config.health_check:
|
||||
|
|
@ -413,6 +490,10 @@ class APIServer(ABC):
|
|||
self.initialized = True
|
||||
kwargs["model"] = self.config.model_name
|
||||
split = kwargs.pop("split", "train")
|
||||
|
||||
# Inject reasoning config if enabled (can be skipped via skip_reasoning=True)
|
||||
kwargs = self._inject_reasoning_kwargs(kwargs)
|
||||
|
||||
stat_dict = {}
|
||||
stat_dict["attempts"] = 0
|
||||
if split == "train":
|
||||
|
|
@ -463,6 +544,9 @@ class APIServer(ABC):
|
|||
async def completion(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Completion handler, waits for the server to be healthy and then calls the completion wrapper.
|
||||
|
||||
Automatically injects reasoning config if configured. Pass `skip_reasoning=True`
|
||||
to bypass reasoning injection for this specific call.
|
||||
"""
|
||||
if not self.initialized:
|
||||
if self.config.health_check:
|
||||
|
|
@ -480,6 +564,10 @@ class APIServer(ABC):
|
|||
self.initialized = True
|
||||
kwargs["model"] = self.config.model_name
|
||||
split = kwargs.pop("split", "train")
|
||||
|
||||
# Inject reasoning config if enabled (can be skipped via skip_reasoning=True)
|
||||
kwargs = self._inject_reasoning_kwargs(kwargs)
|
||||
|
||||
stat_dict = {}
|
||||
stat_dict["attempts"] = 0
|
||||
if split == "train":
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import inspect
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
|
|
@ -119,9 +119,15 @@ class ServerManager:
|
|||
api_key="x",
|
||||
)
|
||||
)
|
||||
self.servers = [server_class(config) for config in openai_configs]
|
||||
self.servers = [
|
||||
server_class(config, reasoning_config=reasoning_config)
|
||||
for config in openai_configs
|
||||
]
|
||||
elif not slurm:
|
||||
self.servers = [server_class(config) for config in configs]
|
||||
self.servers = [
|
||||
server_class(config, reasoning_config=reasoning_config)
|
||||
for config in configs
|
||||
]
|
||||
else:
|
||||
nodelist = (
|
||||
os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
|
||||
|
|
@ -134,7 +140,10 @@ class ServerManager:
|
|||
"Not enough nodes to distribute to, assuming single node"
|
||||
" and you've setup your sglang appropriately."
|
||||
)
|
||||
self.servers = [server_class(config) for config in configs]
|
||||
self.servers = [
|
||||
server_class(config, reasoning_config=reasoning_config)
|
||||
for config in configs
|
||||
]
|
||||
return
|
||||
urls = []
|
||||
num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
|
||||
|
|
@ -149,67 +158,15 @@ class ServerManager:
|
|||
new_conf = configs[0].model_copy(deep=True)
|
||||
new_conf.base_url = urls[i]
|
||||
new_configs.append(new_conf)
|
||||
self.servers = [server_class(config) for config in new_configs]
|
||||
self.servers = [
|
||||
server_class(config, reasoning_config=reasoning_config)
|
||||
for config in new_configs
|
||||
]
|
||||
|
||||
async def update_weight(self, weight: float):
|
||||
for server in self.servers:
|
||||
await server.update_weight(weight)
|
||||
|
||||
def _get_server_base_url(self, server_idx: int = 0) -> Optional[str]:
|
||||
"""Get the base_url from a server's config."""
|
||||
if not self.servers:
|
||||
return None
|
||||
server = self.servers[server_idx]
|
||||
if hasattr(server, "config") and hasattr(server.config, "base_url"):
|
||||
return server.config.base_url
|
||||
return None
|
||||
|
||||
def _inject_reasoning_extra_body(
|
||||
self, kwargs: Dict[str, Any], server_idx: int = 0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Inject reasoning extra_body into kwargs if reasoning is configured.
|
||||
|
||||
This method handles the differences between OpenAI and other providers:
|
||||
- OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0,
|
||||
and uses max_completion_tokens instead of max_tokens
|
||||
- Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}}
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict to modify
|
||||
server_idx: Index of the server to use for base_url detection
|
||||
|
||||
Returns:
|
||||
Modified kwargs dict with extra_body injected if reasoning is active
|
||||
"""
|
||||
if self.reasoning_config is None or not self.reasoning_config.is_active():
|
||||
return kwargs
|
||||
|
||||
# Get the base_url to determine provider type
|
||||
base_url = self._get_server_base_url(server_idx)
|
||||
is_openai_official = base_url and "api.openai.com" in base_url
|
||||
|
||||
# Build the extra_body for reasoning
|
||||
reasoning_extra_body = self.reasoning_config.build_extra_body(base_url)
|
||||
|
||||
if reasoning_extra_body:
|
||||
# Merge with any existing extra_body in kwargs
|
||||
existing_extra_body = kwargs.get("extra_body", {}) or {}
|
||||
kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body}
|
||||
|
||||
# OpenAI reasoning models have specific requirements
|
||||
if is_openai_official:
|
||||
# OpenAI reasoning models require temperature=1.0 (or unset)
|
||||
# Override any temperature setting
|
||||
kwargs["temperature"] = 1.0
|
||||
|
||||
# OpenAI reasoning models use max_completion_tokens instead of max_tokens
|
||||
# Convert if max_tokens is set
|
||||
if "max_tokens" in kwargs and kwargs["max_tokens"]:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
return kwargs
|
||||
|
||||
async def wait_for_sem(self, is_training: bool):
|
||||
"""
|
||||
Wait for a server to be available. This is used to prevent the client from
|
||||
|
|
@ -246,7 +203,8 @@ class ServerManager:
|
|||
"""
|
||||
Route chat completion to the most available server.
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
Reasoning config injection is handled by the individual servers.
|
||||
Pass `skip_reasoning=True` to bypass reasoning injection for this call.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
|
|
@ -281,16 +239,14 @@ class ServerManager:
|
|||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].chat_completion(**kwargs)
|
||||
|
||||
async def completion(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Route completion to the most available server.
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
Reasoning config injection is handled by the individual servers.
|
||||
Pass `skip_reasoning=True` to bypass reasoning injection for this call.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
|
|
@ -323,9 +279,6 @@ class ServerManager:
|
|||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].completion(**kwargs)
|
||||
|
||||
async def tokens_and_logprobs_completion(
|
||||
|
|
@ -335,7 +288,8 @@ class ServerManager:
|
|||
Get tokens and logprobs from completion.
|
||||
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
Note: Reasoning config is NOT injected here - this method is for extracting
|
||||
raw token-level data for training, not for generating reasoned responses.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
|
|
@ -374,9 +328,6 @@ class ServerManager:
|
|||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].tokens_and_logprobs_completion(
|
||||
**kwargs
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,11 @@ from pydantic_cli import FailedExecutionException
|
|||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
)
|
||||
|
||||
|
||||
class SGLangServer(APIServer):
|
||||
|
|
@ -17,14 +21,18 @@ class SGLangServer(APIServer):
|
|||
SGLang server handling.
|
||||
"""
|
||||
|
||||
def __init__(self, config: APIServerConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: APIServerConfig,
|
||||
reasoning_config: ReasoningConfig = None,
|
||||
):
|
||||
self.openai = openai.AsyncClient(
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
super().__init__(config)
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,11 @@ from openai.types.chat.chat_completion import (
|
|||
from openai.types.completion import Completion, CompletionChoice
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
)
|
||||
|
||||
|
||||
class TrlVllmServer(APIServer):
|
||||
|
|
@ -24,10 +28,14 @@ class TrlVllmServer(APIServer):
|
|||
A server that interfaces with trl's vLLM server.
|
||||
"""
|
||||
|
||||
def __init__(self, config: APIServerConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: APIServerConfig,
|
||||
reasoning_config: ReasoningConfig = None,
|
||||
):
|
||||
self.config = config
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
super().__init__(config)
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -12,7 +12,11 @@ from pydantic_cli import FailedExecutionException
|
|||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
)
|
||||
|
||||
|
||||
class VLLMServer(APIServer):
|
||||
|
|
@ -20,14 +24,18 @@ class VLLMServer(APIServer):
|
|||
VLLM server handling.
|
||||
"""
|
||||
|
||||
def __init__(self, config: APIServerConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: APIServerConfig,
|
||||
reasoning_config: ReasoningConfig = None,
|
||||
):
|
||||
self.openai = openai.AsyncClient(
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
super().__init__(config)
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -200,22 +200,25 @@ def test_reasoning_config_invalid_effort():
|
|||
print("✓ Invalid effort raises ValueError")
|
||||
|
||||
|
||||
def test_reasoning_config_invalid_max_tokens():
|
||||
"""Test that invalid max_tokens raises ValueError."""
|
||||
# Too low
|
||||
try:
|
||||
ReasoningConfig(max_tokens=500) # Should raise
|
||||
assert False, "Should have raised ValueError for too low"
|
||||
except ValueError as e:
|
||||
assert "must be between 1024 and 32000" in str(e)
|
||||
|
||||
# Too high
|
||||
try:
|
||||
ReasoningConfig(max_tokens=50000) # Should raise
|
||||
assert False, "Should have raised ValueError for too high"
|
||||
except ValueError as e:
|
||||
assert "must be between 1024 and 32000" in str(e)
|
||||
print("✓ Invalid max_tokens raises ValueError")
|
||||
def test_reasoning_config_max_tokens_no_validation():
|
||||
"""Test that max_tokens accepts any value (no range validation).
|
||||
|
||||
Provider limits vary and may change over time:
|
||||
- OpenRouter currently caps Anthropic at 1024-32000
|
||||
- Native Anthropic API supports up to 128k extended thinking
|
||||
We don't enforce limits here to allow flexibility.
|
||||
"""
|
||||
# Low values should work
|
||||
config_low = ReasoningConfig(max_tokens=500)
|
||||
assert config_low.max_tokens == 500
|
||||
assert config_low.enabled # Auto-enabled
|
||||
|
||||
# High values should work (e.g., for native Anthropic 128k thinking)
|
||||
config_high = ReasoningConfig(max_tokens=128000)
|
||||
assert config_high.max_tokens == 128000
|
||||
assert config_high.enabled
|
||||
|
||||
print("✓ max_tokens accepts any value (no range validation)")
|
||||
|
||||
|
||||
def test_hermes_prompts_defined():
|
||||
|
|
@ -855,7 +858,7 @@ def run_unit_tests():
|
|||
test_reasoning_config_full()
|
||||
test_reasoning_config_effort_mapping()
|
||||
test_reasoning_config_invalid_effort()
|
||||
test_reasoning_config_invalid_max_tokens()
|
||||
test_reasoning_config_max_tokens_no_validation()
|
||||
test_hermes_prompts_defined()
|
||||
|
||||
# ServerManager integration tests (no API calls)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue