diff --git a/.env.example b/.env.example
index e570b8b5..545ad9fa 100644
--- a/.env.example
+++ b/.env.example
@@ -1 +1,2 @@
 OPENAI_API_KEY=
+OPENROUTER_API_KEY=
diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py
index f161869e..ba2972cf 100644
--- a/atroposlib/envs/server_handling/openai_server.py
+++ b/atroposlib/envs/server_handling/openai_server.py
@@ -1,6 +1,7 @@
 import asyncio
 import collections
 import time
+import warnings
 from asyncio import exceptions
 from typing import Optional
 
@@ -45,6 +46,9 @@ class OpenaiConfig(BaseModel):
     rolling_buffer_length: int = Field(
         default=1000, description="Length of the rolling buffer to store metrics."
     )
+    n_kwarg_is_ignored: bool = Field(
+        default=False, description="Whether the API ignores the n kwarg."
+    )
 
 
 class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
@@ -191,6 +195,43 @@ class OpenAIServer:
         )
         return metrics_dict
 
+    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
+        """Create a chat completion, working around providers that silently
+        ignore the n kwarg by fanning out n single-choice requests and
+        merging their choices into one ChatCompletion."""
+        if self.config.n_kwarg_is_ignored:
+            n = kwargs.pop("n", 1)
+            completion_list = await asyncio.gather(
+                *[self.openai.chat.completions.create(**kwargs) for _ in range(n)]
+            )
+            completions = completion_list[0]
+            if n > 1:
+                for c in completion_list[1:]:
+                    completions.choices.extend(c.choices)
+        else:
+            if "n" in kwargs:
+                n = kwargs["n"]
+            else:
+                n = 1
+            completions = await self.openai.chat.completions.create(**kwargs)
+            if len(completions.choices) != n:
+                if len(completions.choices) != 1:
+                    raise ValueError(
+                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
+                    )
+                else:
+                    warnings.warn("API ignores n kwarg; setting n_kwarg_is_ignored")
+                    self.config.n_kwarg_is_ignored = True
+                    completion_list = await asyncio.gather(
+                        *[
+                            self.openai.chat.completions.create(**kwargs)
+                            for _ in range(1, n)
+                        ]
+                    )
+                    for c in completion_list:
+                        completions.choices.extend(c.choices)
+        return completions
+
     @retry(
         stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
     )
@@ -201,7 +242,7 @@ class OpenAIServer:
         if stat_dict.get("start", None) is None:
             stat_dict["start"] = time.time()
         stat_dict["attempts"] += 1
-        completions = await self.openai.chat.completions.create(**kwargs)
+        completions = await self._chat_completion_wrapper(**kwargs)
         stat_dict["end"] = time.time()
         return completions
 
@@ -215,7 +256,7 @@ class OpenAIServer:
         if stat_dict.get("start", None) is None:
             stat_dict["start"] = time.time()
         stat_dict["attempts"] += 1
-        completions = await self.openai.chat.completions.create(**kwargs)
+        completions = await self._chat_completion_wrapper(**kwargs)
         stat_dict["end"] = time.time()
         return completions
 
@@ -246,6 +287,39 @@ class OpenAIServer:
         self.eval_attempts_list.append(stat_dict["attempts"])
         return ret_data
 
+    async def _completion_wrapper(self, **kwargs) -> Completion:
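+        """Completion-endpoint counterpart of _chat_completion_wrapper: fans
+        out n single-choice requests when the API ignores the n kwarg."""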
+        if self.config.n_kwarg_is_ignored:
+            n = kwargs.pop("n", 1)
+            completion_list = await asyncio.gather(
+                *[self.openai.completions.create(**kwargs) for _ in range(n)]
+            )
+            completions = completion_list[0]
+            if n > 1:
+                for c in completion_list[1:]:
+                    completions.choices.extend(c.choices)
+        else:
+            if "n" in kwargs:
+                n = kwargs["n"]
+            else:
+                n = 1
+            completions = await self.openai.completions.create(**kwargs)
+            if len(completions.choices) != n:
+                if len(completions.choices) != 1:
+                    raise ValueError(
+                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
+                    )
+                else:
+                    warnings.warn("API ignores n kwarg; setting n_kwarg_is_ignored")
+                    self.config.n_kwarg_is_ignored = True
+                    completion_list = await asyncio.gather(
+                        *[self.openai.completions.create(**kwargs) for _ in range(1, n)]
+                    )
+                    for c in completion_list:
+                        completions.choices.extend(c.choices)
+        return completions
+
     @retry(
         stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
     )
@@ -256,7 +330,7 @@ class OpenAIServer:
         if stat_dict.get("start", None) is None:
             stat_dict["start"] = time.time()
         stat_dict["attempts"] += 1
-        completions = await self.openai.completions.create(**kwargs)
+        completions = await self._completion_wrapper(**kwargs)
         stat_dict["end"] = time.time()
         return completions
 
@@ -270,7 +344,7 @@ class OpenAIServer:
         if stat_dict.get("start", None) is None:
             stat_dict["start"] = time.time()
         stat_dict["attempts"] += 1
-        completions = await self.openai.completions.create(**kwargs)
+        completions = await self._completion_wrapper(**kwargs)
         stat_dict["end"] = time.time()
         return completions
 
diff --git a/atroposlib/tests/conftest.py b/atroposlib/tests/conftest.py
new file mode 100644
index 00000000..d122e39d
--- /dev/null
+++ b/atroposlib/tests/conftest.py
@@ -0,0 +1,23 @@
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--runproviders", action="store_true", default=False, help="run provider tests"
+    )
+
+
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "providers: mark test as requiring provider API keys to run"
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--runproviders"):
+        # --runproviders given on the CLI: do not skip provider tests
+        return
+    skip_providers = pytest.mark.skip(reason="need --runproviders option to run")
+    for item in items:
+        if "providers" in item.keywords:
+            item.add_marker(skip_providers)
diff --git a/atroposlib/tests/test_openai_api_workarounds.py b/atroposlib/tests/test_openai_api_workarounds.py
new file mode 100644
index 00000000..50e5c4f3
--- /dev/null
+++ b/atroposlib/tests/test_openai_api_workarounds.py
@@ -0,0 +1,112 @@
+import asyncio
+import os
+
+import dotenv
+import pytest
+
+from atroposlib.envs.server_handling.openai_server import OpenaiConfig, OpenAIServer
+
+
+@pytest.mark.providers
+def test_openai_api_n_kwarg_ignore_discovery():
+    dotenv.load_dotenv()
+    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
+    if not openrouter_api_key:
+        pytest.skip("OPENROUTER_API_KEY not set")
+    config = OpenaiConfig(
+        api_key=openrouter_api_key,
+        base_url="https://openrouter.ai/api/v1",
+        model_name="openai/gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+    )
+    assert not config.n_kwarg_is_ignored, "n_kwarg_is_ignored should default to False"
+    n = 4
+    server = OpenAIServer(
+        config=config,
+    )
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert server.config.n_kwarg_is_ignored, "n_kwarg_is_ignored should be set"
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"
+
+
+@pytest.mark.providers
+def test_openai_api_n_kwarg_ignore_use():
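+    # With n_kwarg_is_ignored preset, the server should fan out n parallel
+    # requests immediately, without a discovery round-trip first.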
+    dotenv.load_dotenv()
+    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
+    if not openrouter_api_key:
+        pytest.skip("OPENROUTER_API_KEY not set")
+    config = OpenaiConfig(
+        api_key=openrouter_api_key,
+        base_url="https://openrouter.ai/api/v1",
+        model_name="openai/gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+        n_kwarg_is_ignored=True,
+    )
+    server = OpenAIServer(
+        config=config,
+    )
+    n = 4
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert server.config.n_kwarg_is_ignored, "n_kwarg_is_ignored should remain set"
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"
+
+
+@pytest.mark.providers
+def test_openai_api_n_kwarg_supported():
+    dotenv.load_dotenv()
+    openai_api_key = os.getenv("OPENAI_API_KEY")
+    if not openai_api_key:
+        pytest.skip("OPENAI_API_KEY not set")
+    config = OpenaiConfig(
+        api_key=openai_api_key,
+        model_name="gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+        n_kwarg_is_ignored=False,
+    )
+    server = OpenAIServer(
+        config=config,
+    )
+    n = 4
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert (
+        not server.config.n_kwarg_is_ignored
+    ), "n kwarg should be honored by APIs that support it"
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"