Merge branch 'add_reasoning_handling_draft' of https://github.com/NousResearch/atropos into add_reasoning_handling_draft

This commit is contained in:
teknium 2025-12-30 11:59:46 +00:00
commit 127a925471
2 changed files with 19 additions and 19 deletions

View file

@ -169,45 +169,45 @@ class ServerManager:
) -> Dict[str, Any]:
"""
Inject reasoning extra_body into kwargs if reasoning is configured.
This method handles the differences between OpenAI and other providers:
- OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0,
and uses max_completion_tokens instead of max_tokens
- Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}}
Args:
kwargs: The kwargs dict to modify
server_idx: Index of the server to use for base_url detection
Returns:
Modified kwargs dict with extra_body injected if reasoning is active
"""
if self.reasoning_config is None or not self.reasoning_config.is_active():
return kwargs
# Get the base_url to determine provider type
base_url = self._get_server_base_url(server_idx)
is_openai_official = base_url and "api.openai.com" in base_url
# Build the extra_body for reasoning
reasoning_extra_body = self.reasoning_config.build_extra_body(base_url)
if reasoning_extra_body:
# Merge with any existing extra_body in kwargs
existing_extra_body = kwargs.get("extra_body", {}) or {}
kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body}
# OpenAI reasoning models have specific requirements
if is_openai_official:
# OpenAI reasoning models require temperature=1.0 (or unset)
# Override any temperature setting
kwargs["temperature"] = 1.0
# OpenAI reasoning models use max_completion_tokens instead of max_tokens
# Convert if max_tokens is set
if "max_tokens" in kwargs and kwargs["max_tokens"]:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
return kwargs
async def wait_for_sem(self, is_training: bool):
@ -245,7 +245,7 @@ class ServerManager:
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""
Route chat completion to the most available server.
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
@ -280,16 +280,16 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].chat_completion(**kwargs)
async def completion(self, **kwargs) -> Completion:
"""
Route completion to the most available server.
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
@ -322,10 +322,10 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].completion(**kwargs)
async def tokens_and_logprobs_completion(
@ -334,7 +334,7 @@ class ServerManager:
"""
Get tokens and logprobs from completion.
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
Automatically injects reasoning extra_body if reasoning_config is active.
"""
n = kwargs.get("n", 1)
@ -373,10 +373,10 @@ class ServerManager:
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
# Inject reasoning extra_body if configured
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
return await self.servers[most_available_server].tokens_and_logprobs_completion(
**kwargs
)