mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
Add max_n_completions parameter to ServerManager for handling multiple completions
- Introduced max_n_completions configuration to limit the number of completions requested per server call.
- Updated chat_completion and completion methods to split requests exceeding max_n_completions into multiple calls, merging results accordingly.
- Enhanced documentation for max_n_completions in ServerManagerConfig.
This commit is contained in:
parent
134a9713ce
commit
44b96c7b6c
1 changed file with 44 additions and 0 deletions
|
|
@ -25,6 +25,14 @@ class ServerManagerConfig(BaseModel):
|
|||
testing: bool = Field(
|
||||
default=False, description="If set to True, environment uses mock OpenAI data."
|
||||
)
|
||||
max_n_completions: int = Field(
|
||||
default=8,
|
||||
description=(
|
||||
"The maximum number of completions to request at once per server call. "
|
||||
"Will split any n larger than this into multiple calls. "
|
||||
"This is to help load balance servers."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ServerManager:
|
||||
|
|
@ -34,7 +42,9 @@ class ServerManager:
|
|||
server_class: APIServer = APIServer,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
max_n_completions=8,
|
||||
):
|
||||
self.max_n_completions = max_n_completions
|
||||
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
|
||||
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
|
||||
# an ABCMeta, not what you're expecting.
|
||||
|
|
@ -159,6 +169,24 @@ class ServerManager:
|
|||
await asyncio.sleep(1)
|
||||
|
||||
async def chat_completion(self, **kwargs) -> ChatCompletion:
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
# Split into multiple completions
|
||||
completions = []
|
||||
total_n = n
|
||||
while total_n > 0:
|
||||
n_to_use = min(total_n, self.max_n_completions)
|
||||
kwargs["n"] = n_to_use
|
||||
completions.append(self.chat_completion(**kwargs))
|
||||
total_n -= n_to_use
|
||||
completions = await asyncio.gather(
|
||||
*completions
|
||||
) # type: List[ChatCompletion]
|
||||
# merge choices into one
|
||||
out = completions[0]
|
||||
for completion in completions[1:]:
|
||||
out.choices.extend(completion.choices)
|
||||
return out
|
||||
is_train = kwargs.get("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
|
|
@ -176,6 +204,22 @@ class ServerManager:
|
|||
return await self.servers[most_available_server].chat_completion(**kwargs)
|
||||
|
||||
async def completion(self, **kwargs) -> Completion:
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
# Split into multiple completions
|
||||
completions = []
|
||||
total_n = n
|
||||
while total_n > 0:
|
||||
n_to_use = min(total_n, self.max_n_completions)
|
||||
kwargs["n"] = n_to_use
|
||||
completions.append(self.completion(**kwargs))
|
||||
total_n -= n_to_use
|
||||
completions = await asyncio.gather(*completions) # type: List[Completion]
|
||||
# merge choices into one
|
||||
out = completions[0]
|
||||
for completion in completions[1:]:
|
||||
out.choices.extend(completion.choices)
|
||||
return out
|
||||
is_train = kwargs.get("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue