mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Add support for reasoning models and their variety of providers/endpoints
This commit is contained in:
parent
1c306d3b17
commit
62fa51240c
6 changed files with 1551 additions and 16 deletions
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import inspect
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, List, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
|
|
@ -13,6 +13,7 @@ from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
|||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
ServerBaseline,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_harness import ServerHarness
|
||||
|
|
@ -46,8 +47,10 @@ class ServerManager:
|
|||
slurm=False,
|
||||
testing=False,
|
||||
max_n_completions=8,
|
||||
reasoning_config: Optional[ReasoningConfig] = None,
|
||||
):
|
||||
self.max_n_completions = max_n_completions
|
||||
self.reasoning_config = reasoning_config
|
||||
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
|
||||
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
|
||||
# an ABCMeta, not what you're expecting.
|
||||
|
|
@ -152,6 +155,61 @@ class ServerManager:
|
|||
for server in self.servers:
|
||||
await server.update_weight(weight)
|
||||
|
||||
def _get_server_base_url(self, server_idx: int = 0) -> Optional[str]:
|
||||
"""Get the base_url from a server's config."""
|
||||
if not self.servers:
|
||||
return None
|
||||
server = self.servers[server_idx]
|
||||
if hasattr(server, "config") and hasattr(server.config, "base_url"):
|
||||
return server.config.base_url
|
||||
return None
|
||||
|
||||
def _inject_reasoning_extra_body(
|
||||
self, kwargs: Dict[str, Any], server_idx: int = 0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Inject reasoning extra_body into kwargs if reasoning is configured.
|
||||
|
||||
This method handles the differences between OpenAI and other providers:
|
||||
- OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0,
|
||||
and uses max_completion_tokens instead of max_tokens
|
||||
- Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}}
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict to modify
|
||||
server_idx: Index of the server to use for base_url detection
|
||||
|
||||
Returns:
|
||||
Modified kwargs dict with extra_body injected if reasoning is active
|
||||
"""
|
||||
if self.reasoning_config is None or not self.reasoning_config.is_active():
|
||||
return kwargs
|
||||
|
||||
# Get the base_url to determine provider type
|
||||
base_url = self._get_server_base_url(server_idx)
|
||||
is_openai_official = base_url and "api.openai.com" in base_url
|
||||
|
||||
# Build the extra_body for reasoning
|
||||
reasoning_extra_body = self.reasoning_config.build_extra_body(base_url)
|
||||
|
||||
if reasoning_extra_body:
|
||||
# Merge with any existing extra_body in kwargs
|
||||
existing_extra_body = kwargs.get("extra_body", {}) or {}
|
||||
kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body}
|
||||
|
||||
# OpenAI reasoning models have specific requirements
|
||||
if is_openai_official:
|
||||
# OpenAI reasoning models require temperature=1.0 (or unset)
|
||||
# Override any temperature setting
|
||||
kwargs["temperature"] = 1.0
|
||||
|
||||
# OpenAI reasoning models use max_completion_tokens instead of max_tokens
|
||||
# Convert if max_tokens is set
|
||||
if "max_tokens" in kwargs and kwargs["max_tokens"]:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
return kwargs
|
||||
|
||||
async def wait_for_sem(self, is_training: bool):
|
||||
"""
|
||||
Wait for a server to be available. This is used to prevent the client from
|
||||
|
|
@ -185,6 +243,11 @@ class ServerManager:
|
|||
sem_vals = get_available_slots()
|
||||
|
||||
async def chat_completion(self, **kwargs) -> ChatCompletion:
|
||||
"""
|
||||
Route chat completion to the most available server.
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
# Split into multiple completions
|
||||
|
|
@ -217,9 +280,18 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].chat_completion(**kwargs)
|
||||
|
||||
async def completion(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Route completion to the most available server.
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
# Split into multiple completions
|
||||
|
|
@ -250,6 +322,10 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].completion(**kwargs)
|
||||
|
||||
async def tokens_and_logprobs_completion(
|
||||
|
|
@ -258,6 +334,8 @@ class ServerManager:
|
|||
"""
|
||||
Get tokens and logprobs from completion.
|
||||
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
|
|
@ -295,6 +373,10 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].tokens_and_logprobs_completion(
|
||||
**kwargs
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue