diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 8f6e7545..fce40f80 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -8,7 +8,11 @@ from openai.types.completion import Completion from pydantic_cli import FailedExecutionException from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE -from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig +from atroposlib.envs.server_handling.server_baseline import ( + APIServer, + APIServerConfig, + ReasoningConfig, +) class OpenAIServer(APIServer): @@ -16,13 +20,17 @@ class OpenAIServer(APIServer): OpenAI server handling. """ - def __init__(self, config: APIServerConfig): + def __init__( + self, + config: APIServerConfig, + reasoning_config: ReasoningConfig = None, + ): self.openai = openai.AsyncClient( api_key=config.api_key, base_url=config.base_url, timeout=config.timeout, ) - super().__init__(config) + super().__init__(config, reasoning_config=reasoning_config) async def check_server_status_task(self, chat_completion: bool = True): while True: diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index a25222fe..ed2e73ef 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -29,7 +29,9 @@ class ReasoningConfig: max_tokens are specified. effort: Reasoning effort level. One of: "none", "minimal", "low", "medium", "high", "xhigh". Default None (not specified). - max_tokens: Maximum tokens for reasoning (1024-32000). Default None. + max_tokens: Maximum tokens for reasoning. No validation enforced - provider + limits vary (e.g., OpenRouter currently caps Anthropic at 1024-32000, + but native Anthropic supports up to 128k). Default None. """ enabled: bool = False @@ -45,13 +47,10 @@ class ReasoningConfig: f"Must be one of: {VALID_REASONING_EFFORTS}" ) - # Validate max_tokens range if provided - if self.max_tokens is not None: - if self.max_tokens < 1024 or self.max_tokens > 32000: - raise ValueError( - f"max_reasoning_tokens must be between 1024 and 32000, " - f"got {self.max_tokens}" - ) + # Note: As of 2024, OpenRouter caps Anthropic reasoning tokens at 1024-32000 + # See: https://openrouter.ai/docs/guides/best-practices/reasoning-tokens + # However, we don't enforce this limit here since providers may extend ranges + # (e.g., Anthropic's latest models support up to 128k extended thinking) # Auto-enable if effort or max_tokens are specified if self.effort is not None or self.max_tokens is not None: @@ -61,21 +60,38 @@ class ReasoningConfig: """Check if reasoning is active (enabled with any settings).""" return self.enabled + # Mapping from effort levels to approximate max_tokens values + # Based on OpenRouter's effort-to-budget_tokens formula percentages: + # https://openrouter.ai/docs/guides/best-practices/reasoning-tokens + # Calculated as percentage of 32k base: none=min, minimal=10%, low=20%, + # medium=50%, high=80%, xhigh=95% + EFFORT_TO_MAX_TOKENS = { + "none": 1024, # Minimum/disabled + "minimal": 3200, # ~10% of 32k + "low": 6400, # ~20% of 32k + "medium": 16000, # ~50% of 32k + "high": 25600, # ~80% of 32k + "xhigh": 30400, # ~95% of 32k + } + def build_extra_body( - self, base_url: Optional[str] = None + self, base_url: Optional[str] = None, use_max_tokens: bool = False ) -> Optional[Dict[str, Any]]: """ Build the extra_body dict for API requests based on provider. Args: base_url: The API base URL, used to detect OpenAI official endpoint. + use_max_tokens: If True, convert effort levels to max_tokens values + instead of passing effort strings. Useful for providers + that only support token-based reasoning limits. Returns: Dict to merge into extra_body, or None if reasoning not active. Note: OpenRouter only allows ONE of effort or max_tokens, not both. - When both are specified, effort takes priority. + When both are specified, effort takes priority (unless use_max_tokens=True). """ if not self.is_active(): return None @@ -99,14 +115,20 @@ class ReasoningConfig: return {"reasoning_effort": openai_effort_map.get(effort, "medium")} else: # Standard format for OpenRouter, Nebius, Nous Portal, etc. - # Note: OpenRouter only allows ONE of effort or max_tokens, not both. - # When both are specified, effort takes priority. reasoning = {"enabled": True} - if self.effort is not None: + + # If use_max_tokens is True, convert effort to max_tokens + if use_max_tokens and self.effort is not None: + reasoning["max_tokens"] = self.EFFORT_TO_MAX_TOKENS.get( + self.effort, 8192 + ) + elif self.effort is not None: + # Pass effort string directly (provider may or may not support it) reasoning["effort"] = self.effort elif self.max_tokens is not None: - # Only add max_tokens if effort is not specified + # Use explicit max_tokens if provided reasoning["max_tokens"] = self.max_tokens + return {"reasoning": reasoning} @classmethod @@ -263,8 +285,13 @@ class APIServer(ABC): Abstract class for API servers. """ - def __init__(self, config: APIServerConfig): + def __init__( + self, + config: APIServerConfig, + reasoning_config: Optional[ReasoningConfig] = None, + ): self.config = config + self.reasoning_config = reasoning_config self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once) self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval) self.server_healthy = True @@ -276,6 +303,53 @@ class APIServer(ABC): self.check_task = None self.initialized = False + def _inject_reasoning_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """ + Inject reasoning configuration into kwargs if reasoning is enabled. + + This method can be overridden by subclasses to handle implementation-specific + quirks for different server types (vLLM, SGLang, OpenAI, etc.). + + The caller can pass `skip_reasoning=True` in kwargs to bypass injection. + + Args: + kwargs: The kwargs dict to potentially modify + + Returns: + Modified kwargs dict with reasoning config injected (if applicable) + """ + # Check if caller explicitly wants to skip reasoning injection + skip_reasoning = kwargs.pop("skip_reasoning", False) + if skip_reasoning: + return kwargs + + # Check if reasoning is configured and active + if self.reasoning_config is None or not self.reasoning_config.is_active(): + return kwargs + + # Get base_url to determine provider type + base_url = getattr(self.config, "base_url", None) + is_openai_official = base_url and "api.openai.com" in base_url + + # Build the extra_body for reasoning + reasoning_extra_body = self.reasoning_config.build_extra_body(base_url) + + if reasoning_extra_body: + # Merge with any existing extra_body in kwargs + existing_extra_body = kwargs.get("extra_body", {}) or {} + kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body} + + # OpenAI reasoning models have specific requirements + if is_openai_official: + # OpenAI reasoning models require temperature=1.0 (or unset) + kwargs["temperature"] = 1.0 + + # OpenAI reasoning models use max_completion_tokens instead of max_tokens + if "max_tokens" in kwargs and kwargs["max_tokens"]: + kwargs["max_completion_tokens"] = kwargs.pop("max_tokens") + + return kwargs + async def update_weight(self, weight: float) -> None: """ Update the weight of the semaphores @@ -397,6 +471,9 @@ class APIServer(ABC): async def chat_completion(self, **kwargs) -> ChatCompletion: """ Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper. + + Automatically injects reasoning config if configured. Pass `skip_reasoning=True` + to bypass reasoning injection for this specific call. """ if not self.initialized: if self.config.health_check: @@ -413,6 +490,10 @@ class APIServer(ABC): self.initialized = True kwargs["model"] = self.config.model_name split = kwargs.pop("split", "train") + + # Inject reasoning config if enabled (can be skipped via skip_reasoning=True) + kwargs = self._inject_reasoning_kwargs(kwargs) + stat_dict = {} stat_dict["attempts"] = 0 if split == "train": @@ -463,6 +544,9 @@ class APIServer(ABC): async def completion(self, **kwargs) -> Completion: """ Completion handler, waits for the server to be healthy and then calls the completion wrapper. + + Automatically injects reasoning config if configured. Pass `skip_reasoning=True` + to bypass reasoning injection for this specific call. """ if not self.initialized: if self.config.health_check: @@ -480,6 +564,10 @@ class APIServer(ABC): self.initialized = True kwargs["model"] = self.config.model_name split = kwargs.pop("split", "train") + + # Inject reasoning config if enabled (can be skipped via skip_reasoning=True) + kwargs = self._inject_reasoning_kwargs(kwargs) + stat_dict = {} stat_dict["attempts"] = 0 if split == "train": diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index c330514b..cc8454d8 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -2,7 +2,7 @@ import asyncio import inspect import os from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import AsyncGenerator, List, Optional, Union from openai.types.chat.chat_completion import ChatCompletion from openai.types.completion import Completion @@ -119,9 +119,15 @@ class ServerManager: api_key="x", ) ) - self.servers = [server_class(config) for config in openai_configs] + self.servers = [ + server_class(config, reasoning_config=reasoning_config) + for config in openai_configs + ] elif not slurm: - self.servers = [server_class(config) for config in configs] + self.servers = [ + server_class(config, reasoning_config=reasoning_config) + for config in configs + ] else: nodelist = ( os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}') @@ -134,7 +140,10 @@ class ServerManager: "Not enough nodes to distribute to, assuming single node" " and you've setup your sglang appropriately." ) - self.servers = [server_class(config) for config in configs] + self.servers = [ + server_class(config, reasoning_config=reasoning_config) + for config in configs + ] return urls = [] num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES")) @@ -149,67 +158,15 @@ class ServerManager: new_conf = configs[0].model_copy(deep=True) new_conf.base_url = urls[i] new_configs.append(new_conf) - self.servers = [server_class(config) for config in new_configs] + self.servers = [ + server_class(config, reasoning_config=reasoning_config) + for config in new_configs + ] async def update_weight(self, weight: float): for server in self.servers: await server.update_weight(weight) - def _get_server_base_url(self, server_idx: int = 0) -> Optional[str]: - """Get the base_url from a server's config.""" - if not self.servers: - return None - server = self.servers[server_idx] - if hasattr(server, "config") and hasattr(server.config, "base_url"): - return server.config.base_url - return None - - def _inject_reasoning_extra_body( - self, kwargs: Dict[str, Any], server_idx: int = 0 - ) -> Dict[str, Any]: - """ - Inject reasoning extra_body into kwargs if reasoning is configured. - - This method handles the differences between OpenAI and other providers: - - OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0, - and uses max_completion_tokens instead of max_tokens - - Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}} - - Args: - kwargs: The kwargs dict to modify - server_idx: Index of the server to use for base_url detection - - Returns: - Modified kwargs dict with extra_body injected if reasoning is active - """ - if self.reasoning_config is None or not self.reasoning_config.is_active(): - return kwargs - - # Get the base_url to determine provider type - base_url = self._get_server_base_url(server_idx) - is_openai_official = base_url and "api.openai.com" in base_url - - # Build the extra_body for reasoning - reasoning_extra_body = self.reasoning_config.build_extra_body(base_url) - - if reasoning_extra_body: - # Merge with any existing extra_body in kwargs - existing_extra_body = kwargs.get("extra_body", {}) or {} - kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body} - - # OpenAI reasoning models have specific requirements - if is_openai_official: - # OpenAI reasoning models require temperature=1.0 (or unset) - # Override any temperature setting - kwargs["temperature"] = 1.0 - - # OpenAI reasoning models use max_completion_tokens instead of max_tokens - # Convert if max_tokens is set - if "max_tokens" in kwargs and kwargs["max_tokens"]: - kwargs["max_completion_tokens"] = kwargs.pop("max_tokens") - - return kwargs - async def wait_for_sem(self, is_training: bool): """ Wait for a server to be available. This is used to prevent the client from @@ -246,7 +203,8 @@ class ServerManager: """ Route chat completion to the most available server. - Automatically injects reasoning extra_body if reasoning_config is active. + Reasoning config injection is handled by the individual servers. + Pass `skip_reasoning=True` to bypass reasoning injection for this call. """ n = kwargs.get("n", 1) if n > self.max_n_completions: @@ -281,16 +239,14 @@ class ServerManager: server.sem._value if is_train else server.eval_sem._value ) - # Inject reasoning extra_body if configured - kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server) - return await self.servers[most_available_server].chat_completion(**kwargs) async def completion(self, **kwargs) -> Completion: """ Route completion to the most available server. - Automatically injects reasoning extra_body if reasoning_config is active. + Reasoning config injection is handled by the individual servers. + Pass `skip_reasoning=True` to bypass reasoning injection for this call. """ n = kwargs.get("n", 1) if n > self.max_n_completions: @@ -323,9 +279,6 @@ class ServerManager: server.sem._value if is_train else server.eval_sem._value ) - # Inject reasoning extra_body if configured - kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server) - return await self.servers[most_available_server].completion(**kwargs) async def tokens_and_logprobs_completion( @@ -335,7 +288,8 @@ class ServerManager: Get tokens and logprobs from completion. Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons). - Automatically injects reasoning extra_body if reasoning_config is active. + Note: Reasoning config is NOT injected here - this method is for extracting + raw token-level data for training, not for generating reasoned responses. """ n = kwargs.get("n", 1) if n > self.max_n_completions: @@ -374,9 +328,6 @@ class ServerManager: server.sem._value if is_train else server.eval_sem._value ) - # Inject reasoning extra_body if configured - kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server) - return await self.servers[most_available_server].tokens_and_logprobs_completion( **kwargs ) diff --git a/atroposlib/envs/server_handling/sglang_server.py b/atroposlib/envs/server_handling/sglang_server.py index cbc3ffd6..19ac3d6e 100644 --- a/atroposlib/envs/server_handling/sglang_server.py +++ b/atroposlib/envs/server_handling/sglang_server.py @@ -9,7 +9,11 @@ from pydantic_cli import FailedExecutionException from transformers import AutoTokenizer from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE -from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig +from atroposlib.envs.server_handling.server_baseline import ( + APIServer, + APIServerConfig, + ReasoningConfig, +) class SGLangServer(APIServer): @@ -17,14 +21,18 @@ class SGLangServer(APIServer): SGLang server handling. """ - def __init__(self, config: APIServerConfig): + def __init__( + self, + config: APIServerConfig, + reasoning_config: ReasoningConfig = None, + ): self.openai = openai.AsyncClient( api_key=config.api_key, base_url=config.base_url, timeout=config.timeout, ) self.tokenizer = AutoTokenizer.from_pretrained(config.model_name) - super().__init__(config) + super().__init__(config, reasoning_config=reasoning_config) async def check_server_status_task(self, chat_completion: bool = True): while True: diff --git a/atroposlib/envs/server_handling/trl_vllm_server.py b/atroposlib/envs/server_handling/trl_vllm_server.py index 81917713..6d42c31b 100644 --- a/atroposlib/envs/server_handling/trl_vllm_server.py +++ b/atroposlib/envs/server_handling/trl_vllm_server.py @@ -16,7 +16,11 @@ from openai.types.chat.chat_completion import ( from openai.types.completion import Completion, CompletionChoice from transformers import AutoTokenizer -from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig +from atroposlib.envs.server_handling.server_baseline import ( + APIServer, + APIServerConfig, + ReasoningConfig, +) class TrlVllmServer(APIServer): @@ -24,10 +28,14 @@ class TrlVllmServer(APIServer): A server that interfaces with trl's vLLM server. """ - def __init__(self, config: APIServerConfig): + def __init__( + self, + config: APIServerConfig, + reasoning_config: ReasoningConfig = None, + ): self.config = config self.tokenizer = AutoTokenizer.from_pretrained(config.model_name) - super().__init__(config) + super().__init__(config, reasoning_config=reasoning_config) async def check_server_status_task(self, chat_completion: bool = True): """ diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 8e043cf9..e9a75766 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -12,7 +12,11 @@ from pydantic_cli import FailedExecutionException from transformers import AutoTokenizer from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE -from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig +from atroposlib.envs.server_handling.server_baseline import ( + APIServer, + APIServerConfig, + ReasoningConfig, +) class VLLMServer(APIServer): @@ -20,14 +24,18 @@ class VLLMServer(APIServer): VLLM server handling. """ - def __init__(self, config: APIServerConfig): + def __init__( + self, + config: APIServerConfig, + reasoning_config: ReasoningConfig = None, + ): self.openai = openai.AsyncClient( api_key=config.api_key, base_url=config.base_url, timeout=config.timeout, ) self.tokenizer = AutoTokenizer.from_pretrained(config.model_name) - super().__init__(config) + super().__init__(config, reasoning_config=reasoning_config) async def check_server_status_task(self, chat_completion: bool = True): while True: diff --git a/atroposlib/tests/test_reasoning_models.py b/atroposlib/tests/test_reasoning_models.py index 290a47bb..5b808eef 100644 --- a/atroposlib/tests/test_reasoning_models.py +++ b/atroposlib/tests/test_reasoning_models.py @@ -200,22 +200,25 @@ def test_reasoning_config_invalid_effort(): print("✓ Invalid effort raises ValueError") -def test_reasoning_config_invalid_max_tokens(): - """Test that invalid max_tokens raises ValueError.""" - # Too low - try: - ReasoningConfig(max_tokens=500) # Should raise - assert False, "Should have raised ValueError for too low" - except ValueError as e: - assert "must be between 1024 and 32000" in str(e) - - # Too high - try: - ReasoningConfig(max_tokens=50000) # Should raise - assert False, "Should have raised ValueError for too high" - except ValueError as e: - assert "must be between 1024 and 32000" in str(e) - print("✓ Invalid max_tokens raises ValueError") +def test_reasoning_config_max_tokens_no_validation(): + """Test that max_tokens accepts any value (no range validation). + + Provider limits vary and may change over time: + - OpenRouter currently caps Anthropic at 1024-32000 + - Native Anthropic API supports up to 128k extended thinking + We don't enforce limits here to allow flexibility. + """ + # Low values should work + config_low = ReasoningConfig(max_tokens=500) + assert config_low.max_tokens == 500 + assert config_low.enabled # Auto-enabled + + # High values should work (e.g., for native Anthropic 128k thinking) + config_high = ReasoningConfig(max_tokens=128000) + assert config_high.max_tokens == 128000 + assert config_high.enabled + + print("✓ max_tokens accepts any value (no range validation)") def test_hermes_prompts_defined(): @@ -855,7 +858,7 @@ def run_unit_tests(): test_reasoning_config_full() test_reasoning_config_effort_mapping() test_reasoning_config_invalid_effort() - test_reasoning_config_invalid_max_tokens() + test_reasoning_config_max_tokens_no_validation() test_hermes_prompts_defined() # ServerManager integration tests (no API calls)