Add n kwarg being ignored workaround

This commit is contained in:
dmahan93 2025-05-12 12:06:03 -05:00
parent 706097db21
commit 1aa72d7e7e
4 changed files with 208 additions and 4 deletions

View file

@ -1,6 +1,7 @@
import asyncio
import collections
import time
import warnings
from asyncio import exceptions
from typing import Optional
@ -45,6 +46,9 @@ class OpenaiConfig(BaseModel):
rolling_buffer_length: int = Field(
default=1000, description="Length of the rolling buffer to store metrics."
)
n_kwarg_is_ignored: bool = Field(
default=False, description="Whether the n kwarg is ignored."
)
class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
@ -191,6 +195,42 @@ class OpenAIServer:
)
return metrics_dict
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
    """Request chat completions, emulating ``n`` when the backend ignores it.

    When ``self.config.n_kwarg_is_ignored`` is set, ``n`` is stripped from the
    request and exactly ``n`` single requests are issued concurrently; the
    choices of the later responses are appended to the first response.

    Otherwise a single request is sent with ``kwargs`` untouched. If it comes
    back with one choice although more were requested, the backend is assumed
    to ignore ``n``: the config flag is switched on for subsequent calls and
    the remaining ``n - 1`` requests are issued to fill up the choices.

    Args:
        **kwargs: Forwarded to ``self.openai.chat.completions.create``. A
            missing or ``None`` value for ``n`` is treated as 1.

    Returns:
        The first response object, with ``choices`` extended by the choices of
        any additional responses. Only ``choices`` is merged; every other
        field is left as the first response provided it.

    Raises:
        ValueError: If the backend returns a number of choices that is
            neither 1 nor the requested ``n``.
    """
    create = self.openai.chat.completions.create

    if self.config.n_kwarg_is_ignored:
        n = kwargs.pop("n", None)
        if n is None:
            n = 1
        # Fan out: one request per wanted choice, no extra request afterwards.
        completion_list = await asyncio.gather(
            *[create(**kwargs) for _ in range(n)]
        )
        completions = completion_list[0]
        for extra in completion_list[1:]:
            completions.choices.extend(extra.choices)
        return completions

    n = kwargs.get("n")
    if n is None:
        # An explicit n=None means "one choice"; it must not be mistaken for
        # the backend ignoring n.
        n = 1
    completions = await create(**kwargs)
    received = len(completions.choices)
    if received == n:
        return completions
    if received != 1:
        raise ValueError(f"Expected 1 or {n} completions, got {received}!")

    # Exactly one choice although more were requested: remember that for later
    # calls and fetch the missing n - 1 choices now.
    warnings.warn("n kwarg is ignored by the API, setting to True")
    self.config.n_kwarg_is_ignored = True
    completion_list = await asyncio.gather(
        *[create(**kwargs) for _ in range(1, n)]
    )
    for extra in completion_list:
        completions.choices.extend(extra.choices)
    return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
@ -201,7 +241,7 @@ class OpenAIServer:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.chat.completions.create(**kwargs)
completions = await self._chat_completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
@ -215,7 +255,7 @@ class OpenAIServer:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.chat.completions.create(**kwargs)
completions = await self._chat_completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
@ -246,6 +286,36 @@ class OpenAIServer:
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
async def _completion_wrapper(self, **kwargs) -> Completion:
    """Request text completions, emulating ``n`` when the backend ignores it.

    When ``self.config.n_kwarg_is_ignored`` is set, ``n`` is stripped from the
    request and exactly ``n`` single requests are issued concurrently; the
    choices of the later responses are appended to the first response.

    Otherwise a single request is sent with ``kwargs`` untouched. If it comes
    back with one choice although more were requested, the backend is assumed
    to ignore ``n``: the config flag is switched on for subsequent calls and
    the remaining ``n - 1`` requests are issued to fill up the choices.

    Args:
        **kwargs: Forwarded to ``self.openai.completions.create``. A missing
            or ``None`` value for ``n`` is treated as 1.

    Returns:
        The first response object, with ``choices`` extended by the choices of
        any additional responses. Only ``choices`` is merged; every other
        field is left as the first response provided it.

    Raises:
        ValueError: If the backend returns a number of choices that is
            neither 1 nor the requested ``n``.
    """
    create = self.openai.completions.create

    if self.config.n_kwarg_is_ignored:
        n = kwargs.pop("n", None)
        if n is None:
            n = 1
        # Fan out: one request per wanted choice.
        completion_list = await asyncio.gather(
            *[create(**kwargs) for _ in range(n)]
        )
        completions = completion_list[0]
        for extra in completion_list[1:]:
            completions.choices.extend(extra.choices)
        return completions

    n = kwargs.get("n")
    if n is None:
        # An explicit n=None means "one choice"; it must not be mistaken for
        # the backend ignoring n.
        n = 1
    completions = await create(**kwargs)
    received = len(completions.choices)
    if received == n:
        return completions
    if received != 1:
        raise ValueError(f"Expected 1 or {n} completions, got {received}!")

    # Exactly one choice although more were requested: remember that for later
    # calls and fetch the missing n - 1 choices now.
    warnings.warn("n kwarg is ignored by the API, setting to True")
    self.config.n_kwarg_is_ignored = True
    completion_list = await asyncio.gather(
        *[create(**kwargs) for _ in range(1, n)]
    )
    for extra in completion_list:
        completions.choices.extend(extra.choices)
    # The merged response has to be handed back to the caller explicitly.
    return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
@ -256,7 +326,7 @@ class OpenAIServer:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.completions.create(**kwargs)
completions = await self._completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
@ -270,7 +340,7 @@ class OpenAIServer:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.completions.create(**kwargs)
completions = await self._completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions