mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Merge pull request #41 from NousResearch/workaround-provider-ignoring-n-kwarg-openai-api
Add n kwarg being ignored workaround
This commit is contained in:
commit
71e7a5ca27
5 changed files with 209 additions and 4 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import time
|
||||
import warnings
|
||||
from asyncio import exceptions
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -45,6 +46,9 @@ class OpenaiConfig(BaseModel):
|
|||
rolling_buffer_length: int = Field(
|
||||
default=1000, description="Length of the rolling buffer to store metrics."
|
||||
)
|
||||
n_kwarg_is_ignored: bool = Field(
|
||||
default=False, description="Whether the n kwarg is ignored."
|
||||
)
|
||||
|
||||
|
||||
class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
|
||||
|
|
@ -191,6 +195,42 @@ class OpenAIServer:
|
|||
)
|
||||
return metrics_dict
|
||||
|
||||
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
|
||||
if self.config.n_kwarg_is_ignored:
|
||||
n = kwargs.pop("n", 1)
|
||||
completion_list = await asyncio.gather(
|
||||
*[self.openai.chat.completions.create(**kwargs) for _ in range(n)]
|
||||
)
|
||||
completions = completion_list[0]
|
||||
if n > 1:
|
||||
for c in completion_list[1:]:
|
||||
completions.choices.extend(c.choices)
|
||||
else:
|
||||
completions = await self.openai.chat.completions.create(**kwargs)
|
||||
else:
|
||||
if "n" in kwargs:
|
||||
n = kwargs["n"]
|
||||
else:
|
||||
n = 1
|
||||
completions = await self.openai.chat.completions.create(**kwargs)
|
||||
if len(completions.choices) != n:
|
||||
if len(completions.choices) != 1:
|
||||
raise ValueError(
|
||||
f"Expected 1 or {n} completions, got {len(completions.choices)}!"
|
||||
)
|
||||
else:
|
||||
warnings.warn("n kwarg is ignored by the API, setting to True")
|
||||
self.config.n_kwarg_is_ignored = True
|
||||
completion_list = await asyncio.gather(
|
||||
*[
|
||||
self.openai.chat.completions.create(**kwargs)
|
||||
for _ in range(1, n)
|
||||
]
|
||||
)
|
||||
for c in completion_list:
|
||||
completions.choices.extend(c.choices)
|
||||
return completions
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
|
|
@ -201,7 +241,7 @@ class OpenAIServer:
|
|||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
completions = await self.openai.chat.completions.create(**kwargs)
|
||||
completions = await self._chat_completion_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return completions
|
||||
|
||||
|
|
@ -215,7 +255,7 @@ class OpenAIServer:
|
|||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
completions = await self.openai.chat.completions.create(**kwargs)
|
||||
completions = await self._chat_completion_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return completions
|
||||
|
||||
|
|
@ -246,6 +286,36 @@ class OpenAIServer:
|
|||
self.eval_attempts_list.append(stat_dict["attempts"])
|
||||
return ret_data
|
||||
|
||||
async def _completion_wrapper(self, **kwargs) -> Completion:
|
||||
if self.config.n_kwarg_is_ignored:
|
||||
n = kwargs.pop("n", 1)
|
||||
completion_list = await asyncio.gather(
|
||||
*[self.openai.completions.create(**kwargs) for _ in range(n)]
|
||||
)
|
||||
completions = completion_list[0]
|
||||
if n > 1:
|
||||
for c in completion_list[1:]:
|
||||
completions.choices.extend(c.choices)
|
||||
else:
|
||||
if "n" in kwargs:
|
||||
n = kwargs["n"]
|
||||
else:
|
||||
n = 1
|
||||
completions = await self.openai.completions.create(**kwargs)
|
||||
if len(completions.choices) != n:
|
||||
if len(completions.choices) != 1:
|
||||
raise ValueError(
|
||||
f"Expected 1 or {n} completions, got {len(completions.choices)}!"
|
||||
)
|
||||
else:
|
||||
warnings.warn("n kwarg is ignored by the API, setting to True")
|
||||
self.config.n_kwarg_is_ignored = True
|
||||
completion_list = await asyncio.gather(
|
||||
*[self.openai.completions.create(**kwargs) for _ in range(1, n)]
|
||||
)
|
||||
for c in completion_list:
|
||||
completions.choices.extend(c.choices)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
|
|
@ -256,7 +326,7 @@ class OpenAIServer:
|
|||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
completions = await self.openai.completions.create(**kwargs)
|
||||
completions = await self._completion_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return completions
|
||||
|
||||
|
|
@ -270,7 +340,7 @@ class OpenAIServer:
|
|||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
completions = await self.openai.completions.create(**kwargs)
|
||||
completions = await self._completion_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return completions
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue