mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00
Merge branch 'add_reasoning_handling_draft' of https://github.com/NousResearch/atropos into add_reasoning_handling_draft
This commit is contained in:
commit
127a925471
2 changed files with 19 additions and 19 deletions
|
|
@ -169,45 +169,45 @@ class ServerManager:
|
|||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Inject reasoning extra_body into kwargs if reasoning is configured.
|
||||
|
||||
|
||||
This method handles the differences between OpenAI and other providers:
|
||||
- OpenAI: Uses {"reasoning_effort": "..."} at top level, requires temperature=1.0,
|
||||
and uses max_completion_tokens instead of max_tokens
|
||||
- Others: Uses {"reasoning": {"enabled": True, "effort": "...", "max_tokens": ...}}
|
||||
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict to modify
|
||||
server_idx: Index of the server to use for base_url detection
|
||||
|
||||
|
||||
Returns:
|
||||
Modified kwargs dict with extra_body injected if reasoning is active
|
||||
"""
|
||||
if self.reasoning_config is None or not self.reasoning_config.is_active():
|
||||
return kwargs
|
||||
|
||||
|
||||
# Get the base_url to determine provider type
|
||||
base_url = self._get_server_base_url(server_idx)
|
||||
is_openai_official = base_url and "api.openai.com" in base_url
|
||||
|
||||
|
||||
# Build the extra_body for reasoning
|
||||
reasoning_extra_body = self.reasoning_config.build_extra_body(base_url)
|
||||
|
||||
|
||||
if reasoning_extra_body:
|
||||
# Merge with any existing extra_body in kwargs
|
||||
existing_extra_body = kwargs.get("extra_body", {}) or {}
|
||||
kwargs["extra_body"] = {**existing_extra_body, **reasoning_extra_body}
|
||||
|
||||
|
||||
# OpenAI reasoning models have specific requirements
|
||||
if is_openai_official:
|
||||
# OpenAI reasoning models require temperature=1.0 (or unset)
|
||||
# Override any temperature setting
|
||||
kwargs["temperature"] = 1.0
|
||||
|
||||
|
||||
# OpenAI reasoning models use max_completion_tokens instead of max_tokens
|
||||
# Convert if max_tokens is set
|
||||
if "max_tokens" in kwargs and kwargs["max_tokens"]:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
|
||||
return kwargs
|
||||
|
||||
async def wait_for_sem(self, is_training: bool):
|
||||
|
|
@ -245,7 +245,7 @@ class ServerManager:
|
|||
async def chat_completion(self, **kwargs) -> ChatCompletion:
|
||||
"""
|
||||
Route chat completion to the most available server.
|
||||
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
|
|
@ -280,16 +280,16 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
|
||||
return await self.servers[most_available_server].chat_completion(**kwargs)
|
||||
|
||||
async def completion(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Route completion to the most available server.
|
||||
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
|
|
@ -322,10 +322,10 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
|
||||
return await self.servers[most_available_server].completion(**kwargs)
|
||||
|
||||
async def tokens_and_logprobs_completion(
|
||||
|
|
@ -334,7 +334,7 @@ class ServerManager:
|
|||
"""
|
||||
Get tokens and logprobs from completion.
|
||||
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
|
||||
|
||||
|
||||
Automatically injects reasoning extra_body if reasoning_config is active.
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
|
|
@ -373,10 +373,10 @@ class ServerManager:
|
|||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
|
||||
# Inject reasoning extra_body if configured
|
||||
kwargs = self._inject_reasoning_extra_body(kwargs, most_available_server)
|
||||
|
||||
|
||||
return await self.servers[most_available_server].tokens_and_logprobs_completion(
|
||||
**kwargs
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue