Add support for reasoning models and their variety of providers/endpoints

teknium 2025-12-30 00:23:00 +00:00
parent 1c306d3b17
commit 62fa51240c
6 changed files with 1551 additions and 16 deletions

View file

@@ -3,7 +3,8 @@ import collections
import time
from abc import ABC, abstractmethod
from asyncio import exceptions
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
import numpy as np
from openai.types.chat.chat_completion import ChatCompletion
@@ -11,6 +12,128 @@ from openai.types.completion import Completion
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
# Valid reasoning effort levels
VALID_REASONING_EFFORTS = {"none", "minimal", "low", "medium", "high", "xhigh"}
@dataclass
class ReasoningConfig:
"""
Configuration for reasoning/thinking model support.
This config is used by ServerManager to automatically inject the appropriate
extra_body parameters into API requests based on the provider (OpenAI vs others).
Attributes:
enabled: Whether reasoning mode is enabled. Auto-set to True if effort or
max_tokens are specified.
effort: Reasoning effort level. One of: "none", "minimal", "low", "medium",
"high", "xhigh". Default None (not specified).
max_tokens: Maximum tokens for reasoning (1024-32000). Default None.
"""
enabled: bool = False
effort: Optional[str] = None
max_tokens: Optional[int] = None
def __post_init__(self):
"""Validate and auto-enable if effort or max_tokens are set."""
# Validate effort if provided
if self.effort is not None and self.effort not in VALID_REASONING_EFFORTS:
raise ValueError(
f"Invalid reasoning_effort: {self.effort}. "
f"Must be one of: {VALID_REASONING_EFFORTS}"
)
# Validate max_tokens range if provided
if self.max_tokens is not None:
if self.max_tokens < 1024 or self.max_tokens > 32000:
raise ValueError(
f"max_reasoning_tokens must be between 1024 and 32000, "
f"got {self.max_tokens}"
)
# Auto-enable if effort or max_tokens are specified
if self.effort is not None or self.max_tokens is not None:
self.enabled = True
def is_active(self) -> bool:
"""Check if reasoning is active (enabled with any settings)."""
return self.enabled
def build_extra_body(self, base_url: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""
Build the extra_body dict for API requests based on provider.
Args:
base_url: The API base URL, used to detect OpenAI official endpoint.
Returns:
Dict to merge into extra_body, or None if reasoning not active.
Note:
OpenRouter only allows ONE of effort or max_tokens, not both.
When both are specified, effort takes priority.
"""
if not self.is_active():
return None
# Detect if using official OpenAI endpoint
is_openai_official = base_url and "api.openai.com" in base_url
if is_openai_official:
# OpenAI only accepts reasoning_effort at top level, not nested reasoning object
# They also don't support max_tokens for reasoning
effort = self.effort if self.effort else "medium"
# Map our extended effort levels to OpenAI's supported values
openai_effort_map = {
"none": "low", # OpenAI doesn't have "none", use low
"minimal": "low", # OpenAI doesn't have "minimal", use low
"low": "low",
"medium": "medium",
"high": "high",
"xhigh": "high", # OpenAI doesn't have "xhigh", use high
}
return {"reasoning_effort": openai_effort_map.get(effort, "medium")}
else:
# Standard format for OpenRouter, Nebius, Nous Portal, etc.
# Note: OpenRouter only allows ONE of effort or max_tokens, not both.
# When both are specified, effort takes priority.
reasoning = {"enabled": True}
if self.effort is not None:
reasoning["effort"] = self.effort
elif self.max_tokens is not None:
# Only add max_tokens if effort is not specified
reasoning["max_tokens"] = self.max_tokens
return {"reasoning": reasoning}
@classmethod
def from_env_config(cls, env_config) -> "ReasoningConfig":
"""
Create a ReasoningConfig from a BaseEnvConfig.
This is used by BaseEnv to convert environment config settings
into the reasoning configuration used by ServerManager.
Args:
env_config: A BaseEnvConfig (or subclass) instance with reasoning fields.
Returns:
A ReasoningConfig instance configured based on the env_config.
"""
# Get reasoning settings from env config
thinking_mode = getattr(env_config, "thinking_mode", False)
reasoning_effort = getattr(env_config, "reasoning_effort", None)
max_reasoning_tokens = getattr(env_config, "max_reasoning_tokens", None)
# Determine if enabled: explicitly True, or implied by effort/max_tokens
enabled = thinking_mode or reasoning_effort is not None or max_reasoning_tokens is not None
return cls(
enabled=enabled,
effort=reasoning_effort,
max_tokens=max_reasoning_tokens,
)
class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
def __init__(self, value: int):

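For orientation, a minimal usage sketch of the ReasoningConfig added above (the import path mirrors the one used in the second file of this commit; the URLs and values are illustrative, not part of the change):

from atroposlib.envs.server_handling.server_baseline import ReasoningConfig

# Setting effort (or max_tokens) auto-enables reasoning in __post_init__.
cfg = ReasoningConfig(effort="high")
assert cfg.enabled and cfg.is_active()

# Official OpenAI endpoint: flat reasoning_effort, mapped onto OpenAI's low/medium/high.
print(cfg.build_extra_body("https://api.openai.com/v1"))
# -> {"reasoning_effort": "high"}

# Other providers (OpenRouter, Nebius, Nous Portal, ...): nested reasoning object.
print(cfg.build_extra_body("https://openrouter.ai/api/v1"))
# -> {"reasoning": {"enabled": True, "effort": "high"}}

# Validation: effort must be one of VALID_REASONING_EFFORTS; max_tokens must be 1024-32000.
try:
    ReasoningConfig(max_tokens=512)
except ValueError as err:
    print(err)  # max_reasoning_tokens must be between 1024 and 32000, got 512
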
View file

@@ -2,7 +2,7 @@ import asyncio
import inspect
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
@@ -13,6 +13,7 @@ from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
ReasoningConfig,
ServerBaseline,
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
@@ -46,8 +47,10 @@ class ServerManager:
slurm=False,
testing=False,
max_n_completions=8,
reasoning_config: Optional[ReasoningConfig] = None,
):
self.max_n_completions = max_n_completions
self.reasoning_config = reasoning_config
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
# an ABCMeta, not what you're expecting.
@@ -152,6 +155,61 @@ class ServerManager:
for server in self.servers:
await server.update_weight(weight)
def _get_server_base_url(self, server_idx: int = 0) -> Optional[str]:
"""Get the base_url from a server's config."""
if not self.servers:
return None
server = self.servers[server_idx]
if hasattr(server, "config") and hasattr(server.config, "base_url"):
return server.config.base_url
return None
def _inject_reasoning_extra_body(
self, kwargs: Dict[str, Any], server_idx: int = 0
) -> Dict[str, Any]:
"""
Inject reasoning extra_body into kwargs if reasoning is configured.
This method handles the differences between OpenAI and other providers:
- OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0,
and uses max_completion_tokens instead of max_tokens
- Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}}
Args:
kwargs: The kwargs dict to modify
server_idx: Index of the server to use for base_url detection
Returns:
Modified kwargs dict with extra_body injected if reasoning is active
"""
if self.reasoning_config is None or not self.reasoning_config.is_active():
return kwargs
# Get the base_url to determine provider type
base_url = self._get_server_base_url(server_idx)
is_openai_official = base_url and "api.openai.com" in base_url
# Build the extra_body for reasoning
reasoning_extra_body = self.reasoning_config.build_extra_body(base_url)
if reasoning_extra_body:
# Merge with any existing extra_body in kwargs
existing_extra_body = kwargs.get("extra_body", {}) or {}
kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body}
# OpenAI reasoning models have specific requirements
if is_openai_official:
# OpenAI reasoning models require temperature=1.0 (or unset)
# Override any temperature setting
kwargs["temperature"] = 1.0
# OpenAI reasoning models use max_completion_tokens instead of max_tokens
# Convert if max_tokens is set
if "max_tokens" in kwargs and kwargs["max_tokens"]:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
return kwargs
async def wait_for_sem(self, is_training: bool):
"""
Wait for a server to be available. This is used to prevent the client from
@@ -185,6 +243,11 @@ class ServerManager:
sem_vals = get_available_slots()
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""
Route chat completion to the most available server.
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
@@ -217,9 +280,18 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].chat_completion(**kwargs)
async def completion(self, **kwargs) -> Completion:
"""
Route completion to the most available server.
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
@@ -250,6 +322,10 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].completion(**kwargs)
async def tokens_and_logprobs_completion(
@@ -258,6 +334,8 @@ class ServerManager:
"""
Get tokens and logprobs from completion.
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
@@ -295,6 +373,10 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].tokens_and_logprobs_completion(
**kwargs
)