Add support for reasoning models and their variety of providers/endpoints

This commit is contained in:
teknium 2025-12-30 00:23:00 +00:00
parent 1c306d3b17
commit 62fa51240c
6 changed files with 1551 additions and 16 deletions

View file

@ -2,7 +2,7 @@ import asyncio
import inspect
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
@ -13,6 +13,7 @@ from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
ReasoningConfig,
ServerBaseline,
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
@ -46,8 +47,10 @@ class ServerManager:
slurm=False,
testing=False,
max_n_completions=8,
reasoning_config: Optional[ReasoningConfig] = None,
):
self.max_n_completions = max_n_completions
self.reasoning_config = reasoning_config
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
# an ABCMeta, not what you're expecting.
@ -152,6 +155,61 @@ class ServerManager:
for server in self.servers:
await server.update_weight(weight)
def _get_server_base_url(self, server_idx: int = 0) -> Optional[str]:
    """
    Look up the ``base_url`` configured on one of the managed servers.

    Args:
        server_idx: Position of the server in ``self.servers`` to inspect.
            The index is used as-is; it is not bounds-checked here.

    Returns:
        The server's ``config.base_url`` value, or None when there are no
        servers, the server has no ``config`` attribute, or that config has
        no ``base_url`` attribute.
    """
    if self.servers:
        # A missing attribute at either level collapses to None rather
        # than raising, so callers can treat "unknown" uniformly.
        server_config = getattr(self.servers[server_idx], "config", None)
        return getattr(server_config, "base_url", None)
    return None
def _inject_reasoning_extra_body(
    self, kwargs: Dict[str, Any], server_idx: int = 0
) -> Dict[str, Any]:
    """
    Add provider-appropriate reasoning parameters to a request's kwargs.

    Nothing happens unless ``self.reasoning_config`` is set and reports
    itself active. When it is, the payload produced by
    ``reasoning_config.build_extra_body(base_url)`` is merged over any
    ``extra_body`` already present in ``kwargs`` (reasoning keys win on
    conflict; the caller's original ``extra_body`` dict is not mutated).

    When the selected server's base_url contains "api.openai.com", two
    further adjustments are made:
    - ``temperature`` is forced to 1.0, overriding any caller value.
    - a truthy ``max_tokens`` is moved to ``max_completion_tokens``.

    Args:
        kwargs: Request kwargs; updated in place.
        server_idx: Index of the server whose base_url selects the
            provider-specific behaviour.

    Returns:
        The same ``kwargs`` dict that was passed in, possibly updated.
    """
    reasoning = self.reasoning_config
    if reasoning is None or not reasoning.is_active():
        return kwargs

    # The provider is inferred purely from the target server's base_url.
    base_url = self._get_server_base_url(server_idx)
    targets_openai = base_url and "api.openai.com" in base_url

    reasoning_payload = reasoning.build_extra_body(base_url)
    if reasoning_payload:
        # Build a new dict so the caller's extra_body object is left alone;
        # a None/empty extra_body is treated as "nothing to merge".
        caller_payload = kwargs.get("extra_body") or {}
        kwargs["extra_body"] = {**caller_payload, **reasoning_payload}

    if not targets_openai:
        return kwargs

    # OpenAI reasoning models only accept temperature=1.0 (or unset), so
    # any caller-provided temperature is overridden.
    kwargs["temperature"] = 1.0

    # OpenAI reasoning models take max_completion_tokens rather than
    # max_tokens; a falsy max_tokens (None/0) is left where it is.
    if kwargs.get("max_tokens"):
        kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

    return kwargs
async def wait_for_sem(self, is_training: bool):
"""
Wait for a server to be available. This is used to prevent the client from
@ -185,6 +243,11 @@ class ServerManager:
sem_vals = get_available_slots()
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""
Route chat completion to the most available server.
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
@ -217,9 +280,18 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].chat_completion(**kwargs)
async def completion(self, **kwargs) -> Completion:
"""
Route completion to the most available server.
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
@ -250,6 +322,10 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].completion(**kwargs)
async def tokens_and_logprobs_completion(
@ -258,6 +334,8 @@ class ServerManager:
"""
Get tokens and logprobs from completion.
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
@ -295,6 +373,10 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].tokens_and_logprobs_completion(
**kwargs
)