Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
Add n kwarg being ignored workaround

parent 706097db21
commit 1aa72d7e7e

4 changed files with 208 additions and 4 deletions
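The commit works around OpenAI-compatible endpoints that silently ignore the `n` sampling kwarg and return a single choice: when the server knows (or discovers) that `n` is ignored, it strips the kwarg, fires `n` single-choice requests concurrently with `asyncio.gather`, and merges the extra `choices` into the first response object. Below is a minimal standalone sketch of that fan-out pattern, assuming the official openai>=1.x `AsyncOpenAI` client and leaving out the runtime discovery logic added in this commit.

import asyncio

from openai import AsyncOpenAI  # assumes the openai>=1.x async client


async def fan_out_chat_completion(client: AsyncOpenAI, n: int = 1, **kwargs):
    """Emulate n>1 for providers that ignore the n kwarg.

    Sends n single-choice requests concurrently and merges the choices
    of requests 2..n into the first response object.
    """
    responses = await asyncio.gather(
        *[client.chat.completions.create(**kwargs) for _ in range(n)]
    )
    merged = responses[0]
    for extra in responses[1:]:
        merged.choices.extend(extra.choices)
    return merged


# Hypothetical usage: four samples from a provider that ignores n.
# resp = asyncio.run(fan_out_chat_completion(client, n=4, model="...", messages=[...]))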
@@ -1 +1,2 @@
 OPENAI_API_KEY=
+OPENROUTER_API_KEY=
@@ -1,6 +1,7 @@
 import asyncio
 import collections
 import time
+import warnings
 from asyncio import exceptions
 from typing import Optional

@@ -45,6 +46,9 @@ class OpenaiConfig(BaseModel):
     rolling_buffer_length: int = Field(
         default=1000, description="Length of the rolling buffer to store metrics."
     )
+    n_kwarg_is_ignored: bool = Field(
+        default=False, description="Whether the n kwarg is ignored."
+    )


 class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
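The new flag is an ordinary `OpenaiConfig` field, so a caller that already knows its provider ignores `n` can set it up front instead of relying on runtime discovery. A hedged sketch using the field names that appear in the tests added below; the key and URL are placeholders, not requirements:

from atroposlib.envs.server_handling.openai_server import OpenaiConfig, OpenAIServer

config = OpenaiConfig(
    api_key="sk-...",                         # placeholder credential
    base_url="https://openrouter.ai/api/v1",  # example OpenAI-compatible endpoint
    model_name="openai/gpt-4.1-nano",
    n_kwarg_is_ignored=True,  # skip discovery; always fan out n single-choice requests
)
server = OpenAIServer(config=config)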
@@ -191,6 +195,42 @@ class OpenAIServer:
         )
         return metrics_dict

+    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
+        if self.config.n_kwarg_is_ignored:
+            n = kwargs.pop("n", 1)
+            completion_list = await asyncio.gather(
+                *[self.openai.chat.completions.create(**kwargs) for _ in range(n)]
+            )
+            completions = completion_list[0]
+            if n > 1:
+                for c in completion_list[1:]:
+                    completions.choices.extend(c.choices)
+            else:
+                completions = await self.openai.chat.completions.create(**kwargs)
+        else:
+            if "n" in kwargs:
+                n = kwargs["n"]
+            else:
+                n = 1
+            completions = await self.openai.chat.completions.create(**kwargs)
+            if len(completions.choices) != n:
+                if len(completions.choices) != 1:
+                    raise ValueError(
+                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
+                    )
+                else:
+                    warnings.warn("n kwarg is ignored by the API, setting to True")
+                    self.config.n_kwarg_is_ignored = True
+                    completion_list = await asyncio.gather(
+                        *[
+                            self.openai.chat.completions.create(**kwargs)
+                            for _ in range(1, n)
+                        ]
+                    )
+                    for c in completion_list:
+                        completions.choices.extend(c.choices)
+        return completions
+
     @retry(
         stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
     )
@@ -201,7 +241,7 @@ class OpenAIServer:
         if stat_dict.get("start", None) is None:
             stat_dict["start"] = time.time()
         stat_dict["attempts"] += 1
-        completions = await self.openai.chat.completions.create(**kwargs)
+        completions = await self._chat_completion_wrapper(**kwargs)
         stat_dict["end"] = time.time()
         return completions

@@ -215,7 +255,7 @@ class OpenAIServer:
         if stat_dict.get("start", None) is None:
             stat_dict["start"] = time.time()
         stat_dict["attempts"] += 1
-        completions = await self.openai.chat.completions.create(**kwargs)
+        completions = await self._chat_completion_wrapper(**kwargs)
         stat_dict["end"] = time.time()
         return completions

@@ -246,6 +286,36 @@ class OpenAIServer:
         self.eval_attempts_list.append(stat_dict["attempts"])
         return ret_data

+    async def _completion_wrapper(self, **kwargs) -> Completion:
+        if self.config.n_kwarg_is_ignored:
+            n = kwargs.pop("n", 1)
+            completion_list = await asyncio.gather(
+                *[self.openai.completions.create(**kwargs) for _ in range(n)]
+            )
+            completions = completion_list[0]
+            if n > 1:
+                for c in completion_list[1:]:
+                    completions.choices.extend(c.choices)
+        else:
+            if "n" in kwargs:
+                n = kwargs["n"]
+            else:
+                n = 1
+            completions = await self.openai.completions.create(**kwargs)
+            if len(completions.choices) != n:
+                if len(completions.choices) != 1:
+                    raise ValueError(
+                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
+                    )
+                else:
+                    warnings.warn("n kwarg is ignored by the API, setting to True")
+                    self.config.n_kwarg_is_ignored = True
+                    completion_list = await asyncio.gather(
+                        *[self.openai.completions.create(**kwargs) for _ in range(1, n)]
+                    )
+                    for c in completion_list:
+                        completions.choices.extend(c.choices)
+
     @retry(
         stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
     )
@@ -256,7 +326,7 @@ class OpenAIServer:
         if stat_dict.get("start", None) is None:
             stat_dict["start"] = time.time()
         stat_dict["attempts"] += 1
-        completions = await self.openai.completions.create(**kwargs)
+        completions = await self._completion_wrapper(**kwargs)
         stat_dict["end"] = time.time()
         return completions

@@ -270,7 +340,7 @@ class OpenAIServer:
         if stat_dict.get("start", None) is None:
             stat_dict["start"] = time.time()
         stat_dict["attempts"] += 1
-        completions = await self.openai.completions.create(**kwargs)
+        completions = await self._completion_wrapper(**kwargs)
         stat_dict["end"] = time.time()
         return completions

atroposlib/tests/conftest.py (new file, +23)
@@ -0,0 +1,23 @@
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--runproviders", action="store_true", default=False, help="run provider tests"
+    )
+
+
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "providers: mark test as requires providers api keys to run"
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--runproviders"):
+        # --runproviders given in cli: do not skip slow tests
+        return
+    skip_providers = pytest.mark.skip(reason="need --runproviders option to run")
+    for item in items:
+        if "providers" in item.keywords:
+            item.add_marker(skip_providers)
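With this conftest in place, any test marked `providers` is collected but skipped unless pytest is invoked with the new flag (for example `pytest --runproviders atroposlib/tests`), so runs without API keys are unaffected. A small sketch of how a provider-gated test is marked, with a hypothetical test name:

import pytest


@pytest.mark.providers  # skipped unless pytest is run with --runproviders
def test_some_provider_specific_behavior():
    ...  # hypothetical body; the real tests below build an OpenAIServer and call it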
atroposlib/tests/test_openai_api_workarounds.py (new file, +110)
@@ -0,0 +1,110 @@
+import asyncio
+import os
+
+import dotenv
+import pytest
+
+from atroposlib.envs.server_handling.openai_server import OpenaiConfig, OpenAIServer
+
+
+@pytest.mark.providers
+def test_openai_api_n_kwarg_ignore_discovery():
+    dotenv.load_dotenv()
+    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
+    if not openrouter_api_key:
+        pytest.skip("OPENROUTER_API_KEY not set")
+    config = OpenaiConfig(
+        api_key=openrouter_api_key,
+        base_url="https://openrouter.ai/api/v1",
+        model_name="openai/gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+    )
+    assert not config.n_kwarg_is_ignored, "n kwarg is not ignored by default"
+    n = 4
+    server = OpenAIServer(
+        config=config,
+    )
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert server.config.n_kwarg_is_ignored, "n kwarg is should be set after discovery"
+    print(len(response.choices), n)
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"
+
+
+@pytest.mark.providers
+def test_openai_api_n_kwarg_ignore_use():
+    dotenv.load_dotenv()
+    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
+    if not openrouter_api_key:
+        pytest.skip("OPENROUTER_API_KEY not set")
+    config = OpenaiConfig(
+        api_key=openrouter_api_key,
+        base_url="https://openrouter.ai/api/v1",
+        model_name="openai/gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+        n_kwarg_is_ignored=True,
+    )
+    server = OpenAIServer(
+        config=config,
+    )
+    n = 4
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert server.config.n_kwarg_is_ignored, "n kwarg is should be set after discovery"
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"
+
+
+@pytest.mark.providers
+def test_openai_api_n_kwarg_supported():
+    dotenv.load_dotenv()
+    openai_api_key = os.getenv("OPENAI_API_KEY")
+    if not openai_api_key:
+        pytest.skip("OPENAI_API_KEY not set")
+    config = OpenaiConfig(
+        model_name="gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+        n_kwarg_is_ignored=False,
+    )
+    server = OpenAIServer(
+        config=config,
+    )
+    n = 4
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert (
+        not server.config.n_kwarg_is_ignored
+    ), "n kwarg should be used with supported models"
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"