mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-02 17:45:50 +00:00
Add support for reasoning models and their variety of providers/endpoints
This commit is contained in:
parent
1c306d3b17
commit
62fa51240c
6 changed files with 1551 additions and 16 deletions
|
|
@@ -3,7 +3,8 @@ import collections
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import exceptions
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
|
@@ -11,6 +12,128 @@ from openai.types.completion import Completion
|
|||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
# Valid reasoning effort levels accepted by ReasoningConfig.effort.
VALID_REASONING_EFFORTS = {"none", "minimal", "low", "medium", "high", "xhigh"}


@dataclass
class ReasoningConfig:
    """
    Configuration for reasoning/thinking model support.

    This config is used by ServerManager to automatically inject the appropriate
    extra_body parameters into API requests based on the provider (OpenAI vs others).

    Attributes:
        enabled: Whether reasoning mode is enabled. Auto-set to True if effort or
            max_tokens are specified.
        effort: Reasoning effort level. One of: "none", "minimal", "low", "medium",
            "high", "xhigh". Default None (not specified).
        max_tokens: Maximum tokens for reasoning (1024-32000). Default None.
    """

    enabled: bool = False
    effort: Optional[str] = None
    max_tokens: Optional[int] = None

    def __post_init__(self) -> None:
        """Validate settings and auto-enable if effort or max_tokens are set.

        Raises:
            ValueError: If ``effort`` is not a recognized level, or if
                ``max_tokens`` falls outside the 1024-32000 range.
        """
        if self.effort is not None and self.effort not in VALID_REASONING_EFFORTS:
            # Sort the valid values so the error message is deterministic;
            # a raw set repr varies with string-hash randomization.
            raise ValueError(
                f"Invalid reasoning_effort: {self.effort}. "
                f"Must be one of: {sorted(VALID_REASONING_EFFORTS)}"
            )

        # Validate max_tokens range if provided.
        if self.max_tokens is not None and not (1024 <= self.max_tokens <= 32000):
            raise ValueError(
                f"max_reasoning_tokens must be between 1024 and 32000, "
                f"got {self.max_tokens}"
            )

        # Specifying either knob implies the caller wants reasoning on.
        if self.effort is not None or self.max_tokens is not None:
            self.enabled = True

    def is_active(self) -> bool:
        """Check if reasoning is active (enabled with any settings)."""
        return self.enabled

    def build_extra_body(self, base_url: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Build the extra_body dict for API requests based on provider.

        Args:
            base_url: The API base URL, used to detect OpenAI official endpoint.

        Returns:
            Dict to merge into extra_body, or None if reasoning not active.

        Note:
            OpenRouter only allows ONE of effort or max_tokens, not both.
            When both are specified, effort takes priority.
        """
        if not self.is_active():
            return None

        # Detect if using official OpenAI endpoint (needs a different shape).
        is_openai_official = bool(base_url and "api.openai.com" in base_url)

        if is_openai_official:
            # OpenAI only accepts reasoning_effort at top level, not a nested
            # reasoning object, and does not support max_tokens for reasoning.
            effort = self.effort if self.effort else "medium"
            # Map our extended effort levels to OpenAI's supported values.
            openai_effort_map = {
                "none": "low",  # OpenAI doesn't have "none", use low
                "minimal": "low",  # OpenAI doesn't have "minimal", use low
                "low": "low",
                "medium": "medium",
                "high": "high",
                "xhigh": "high",  # OpenAI doesn't have "xhigh", use high
            }
            return {"reasoning_effort": openai_effort_map.get(effort, "medium")}

        # Standard format for OpenRouter, Nebius, Nous Portal, etc.
        # Note: OpenRouter only allows ONE of effort or max_tokens, not both;
        # effort takes priority when both are set.
        reasoning: Dict[str, Any] = {"enabled": True}
        if self.effort is not None:
            reasoning["effort"] = self.effort
        elif self.max_tokens is not None:
            # Only add max_tokens if effort is not specified.
            reasoning["max_tokens"] = self.max_tokens
        return {"reasoning": reasoning}

    @classmethod
    def from_env_config(cls, env_config) -> "ReasoningConfig":
        """
        Create a ReasoningConfig from a BaseEnvConfig.

        This is used by BaseEnv to convert environment config settings
        into the reasoning configuration used by ServerManager.

        Args:
            env_config: A BaseEnvConfig (or subclass) instance with reasoning fields.

        Returns:
            A ReasoningConfig instance configured based on the env_config.
        """
        # Missing attributes are treated as "not set" via getattr defaults.
        thinking_mode = getattr(env_config, "thinking_mode", False)
        reasoning_effort = getattr(env_config, "reasoning_effort", None)
        max_reasoning_tokens = getattr(env_config, "max_reasoning_tokens", None)

        # Enabled when explicitly requested, or implied by effort/max_tokens.
        enabled = thinking_mode or reasoning_effort is not None or max_reasoning_tokens is not None

        return cls(
            enabled=enabled,
            effort=reasoning_effort,
            max_tokens=max_reasoning_tokens,
        )
|
||||
|
||||
|
||||
class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
|
||||
def __init__(self, value: int):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import inspect
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, List, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
|
|
@ -13,6 +13,7 @@ from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
|||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
ServerBaseline,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_harness import ServerHarness
|
||||
|
|
@@ -46,8 +47,10 @@ class ServerManager:
|
|||
slurm=False,
|
||||
testing=False,
|
||||
max_n_completions=8,
|
||||
reasoning_config: Optional[ReasoningConfig] = None,
|
||||
):
|
||||
self.max_n_completions = max_n_completions
|
||||
self.reasoning_config = reasoning_config
|
||||
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
|
||||
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
|
||||
# an ABCMeta, not what you're expecting.
|
||||
|
|
@ -152,6 +155,61 @@ class ServerManager:
|
|||
for server in self.servers:
|
||||
await server.update_weight(weight)
|
||||
|
||||
def _get_server_base_url(self, server_idx: int = 0) -> Optional[str]:
    """Return the configured ``base_url`` of the selected server, if any.

    Args:
        server_idx: Index into ``self.servers`` (defaults to the first server).

    Returns:
        The server config's ``base_url``, or None when there are no servers
        or the server/config does not expose one.
    """
    if not self.servers:
        return None
    server_config = getattr(self.servers[server_idx], "config", None)
    if server_config is None:
        return None
    return getattr(server_config, "base_url", None)
|
||||
|
||||
def _inject_reasoning_extra_body(
    self, kwargs: Dict[str, Any], server_idx: int = 0
) -> Dict[str, Any]:
    """
    Inject reasoning extra_body into kwargs if reasoning is configured.

    This method handles the differences between OpenAI and other providers:
    - OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0,
      and uses max_completion_tokens instead of max_tokens
    - Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}}

    Args:
        kwargs: The kwargs dict to modify
        server_idx: Index of the server to use for base_url detection

    Returns:
        Modified kwargs dict with extra_body injected if reasoning is active
    """
    config = self.reasoning_config
    # Nothing to do when reasoning was never configured or is disabled.
    if config is None or not config.is_active():
        return kwargs

    # Provider detection: the official OpenAI endpoint needs special handling.
    base_url = self._get_server_base_url(server_idx)
    targets_openai = bool(base_url and "api.openai.com" in base_url)

    # Provider-appropriate reasoning payload for this request.
    reasoning_payload = config.build_extra_body(base_url)

    if reasoning_payload:
        # Merge on top of any caller-supplied extra_body (which may be None).
        merged = dict(kwargs.get("extra_body") or {})
        merged.update(reasoning_payload)
        kwargs["extra_body"] = merged

        # OpenAI reasoning models have specific requirements.
        if targets_openai:
            # They require temperature=1.0 (or unset) — override any setting.
            kwargs["temperature"] = 1.0

            # They use max_completion_tokens instead of max_tokens; convert
            # only when a truthy max_tokens is present.
            if kwargs.get("max_tokens"):
                kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

    return kwargs
|
||||
|
||||
async def wait_for_sem(self, is_training: bool):
|
||||
"""
|
||||
Wait for a server to be available. This is used to prevent the client from
|
||||
|
|
@ -185,6 +243,11 @@ class ServerManager:
|
|||
sem_vals = get_available_slots()
|
||||
|
||||
async def chat_completion(self, **kwargs) -> ChatCompletion:
|
||||
"""
|
||||
Route chat completion to the most available server.
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
# Split into multiple completions
|
||||
|
|
@ -217,9 +280,18 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].chat_completion(**kwargs)
|
||||
|
||||
async def completion(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Route completion to the most available server.
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
# Split into multiple completions
|
||||
|
|
@ -250,6 +322,10 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].completion(**kwargs)
|
||||
|
||||
async def tokens_and_logprobs_completion(
|
||||
|
|
@ -258,6 +334,8 @@ class ServerManager:
|
|||
"""
|
||||
Get tokens and logprobs from completion.
|
||||
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
|
|
@ -295,6 +373,10 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
return await self.servers[most_available_server].tokens_and_logprobs_completion(
|
||||
**kwargs
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue