Add max_n_completions parameter to ServerManager for handling multiple completions

- Introduced max_n_completions configuration to limit the number of completions requested per server call.
- Updated chat_completion and completion methods to split requests exceeding max_n_completions into multiple calls, merging results accordingly.
- Enhanced documentation for max_n_completions in ServerManagerConfig.
This commit is contained in:
dmahan93 2025-06-02 11:11:55 -05:00
parent 134a9713ce
commit 44b96c7b6c

View file

@ -25,6 +25,14 @@ class ServerManagerConfig(BaseModel):
testing: bool = Field(
default=False, description="If set to True, environment uses mock OpenAI data."
)
max_n_completions: int = Field(
default=8,
description=(
"The maximum number of completions to request at once per server call. "
"Will split any n larger than this into multiple calls. "
"This is to help load balance servers."
),
)
class ServerManager:
@ -34,7 +42,9 @@ class ServerManager:
server_class: APIServer = APIServer,
slurm=False,
testing=False,
max_n_completions=8,
):
self.max_n_completions = max_n_completions
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
# an ABCMeta, not what you're expecting.
@ -159,6 +169,24 @@ class ServerManager:
await asyncio.sleep(1)
async def chat_completion(self, **kwargs) -> ChatCompletion:
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
completions = []
total_n = n
while total_n > 0:
n_to_use = min(total_n, self.max_n_completions)
kwargs["n"] = n_to_use
completions.append(self.chat_completion(**kwargs))
total_n -= n_to_use
completions = await asyncio.gather(
*completions
) # type: List[ChatCompletion]
# merge choices into one
out = completions[0]
for completion in completions[1:]:
out.choices.extend(completion.choices)
return out
is_train = kwargs.get("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
@ -176,6 +204,22 @@ class ServerManager:
return await self.servers[most_available_server].chat_completion(**kwargs)
async def completion(self, **kwargs) -> Completion:
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
completions = []
total_n = n
while total_n > 0:
n_to_use = min(total_n, self.max_n_completions)
kwargs["n"] = n_to_use
completions.append(self.completion(**kwargs))
total_n -= n_to_use
completions = await asyncio.gather(*completions) # type: List[Completion]
# merge choices into one
out = completions[0]
for completion in completions[1:]:
out.choices.extend(completion.choices)
return out
is_train = kwargs.get("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1