mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
265 lines
10 KiB
Python
265 lines
10 KiB
Python
import asyncio
|
|
import warnings
|
|
|
|
import aiohttp
|
|
import openai
|
|
from openai.types.chat.chat_completion import ChatCompletion
|
|
from openai.types.completion import Completion
|
|
from pydantic_cli import FailedExecutionException
|
|
|
|
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
|
|
from atroposlib.envs.server_handling.server_baseline import (
|
|
APIServer,
|
|
APIServerConfig,
|
|
ReasoningConfig,
|
|
)
|
|
|
|
|
|
class OpenAIServer(APIServer):
    """
    OpenAI server handling.

    Wraps an ``openai.AsyncClient`` and implements the APIServer hooks:
    a periodic health probe plus chat/completion wrappers that work
    around backends which silently ignore the ``n`` kwarg.
    """

    def __init__(
        self,
        config: APIServerConfig,
        reasoning_config: ReasoningConfig = None,
    ):
        # Create the client first: the base class may start background
        # tasks (e.g. the health check) that use it immediately.
        self.openai = openai.AsyncClient(
            api_key=config.api_key,
            base_url=config.base_url,
            timeout=config.timeout,
        )
        super().__init__(config, reasoning_config=reasoning_config)

    async def check_server_status_task(self, chat_completion: bool = True):
        """
        Background loop: issue a 1-token probe request roughly once per
        second and record the outcome in ``self.server_healthy``.

        :param chat_completion: probe the chat endpoint when True,
            otherwise the legacy completions endpoint.
        """
        while True:
            try:
                if chat_completion:
                    await self.openai.chat.completions.create(
                        model=self.config.model_name,
                        messages=[{"role": "user", "content": "hi"}],
                        max_tokens=1,
                    )
                else:
                    await self.openai.completions.create(
                        model=self.config.model_name,
                        prompt="hi",
                        max_tokens=1,
                    )
                self.server_healthy = True
            except Exception:
                # The previous tuple (aiohttp.ClientError, openai.OpenAIError,
                # openai.APITimeoutError, Exception) was equivalent to bare
                # Exception, which subsumes all of them: any failure at all
                # marks the server unhealthy.
                self.server_healthy = False
            await asyncio.sleep(1)

    async def _n_aware_create(self, create_fn, kwargs):
        """
        Call ``create_fn(**kwargs)`` honoring the ``n`` kwarg even against
        APIs that silently ignore it.

        When ``config.n_kwarg_is_ignored`` is set, ``n`` is popped and ``n``
        parallel single-shot requests are issued, with all choices merged
        into the first response. Otherwise a single request is made; if the
        API returned exactly one choice despite ``n > 1``, the ignore flag
        is set and the missing ``n - 1`` requests are issued in parallel.

        :param create_fn: async request factory, e.g.
            ``self.openai.completions.create``.
        :param kwargs: request keyword arguments (may include ``n``).
        :raises ValueError: if the API returns neither 1 nor ``n`` choices.
        """
        if self.config.n_kwarg_is_ignored:
            n = kwargs.pop("n", 1)
            results = await asyncio.gather(
                *[create_fn(**kwargs) for _ in range(n)]
            )
            # Fold the remaining responses' choices into the first one.
            # (The previous chat path re-issued a redundant request when
            # n == 1 instead of using the gathered result.)
            completions = results[0]
            for extra in results[1:]:
                completions.choices.extend(extra.choices)
            return completions
        n = kwargs.get("n", 1)
        completions = await create_fn(**kwargs)
        if len(completions.choices) != n:
            if len(completions.choices) != 1:
                raise ValueError(
                    f"Expected 1 or {n} completions, got {len(completions.choices)}!"
                )
            # Exactly one choice came back for n > 1: the backend ignores n.
            # Remember that and top up with the missing n - 1 requests.
            warnings.warn("n kwarg is ignored by the API, setting to True")
            self.config.n_kwarg_is_ignored = True
            extras = await asyncio.gather(
                *[create_fn(**kwargs) for _ in range(1, n)]
            )
            for extra in extras:
                completions.choices.extend(extra.choices)
        return completions

    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for the chat completion using the openai client.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for chat completion!"
        assert (
            kwargs.get("messages", None) is not None
        ), "Messages are required for chat completion!"
        return await self._n_aware_create(self.openai.chat.completions.create, kwargs)

    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for the completion using the openai client.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for completion!"
        assert (
            kwargs.get("prompt", None) is not None
        ), "Prompt is required for completion!"
        return await self._n_aware_create(self.openai.completions.create, kwargs)

    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Wrapper for the tokens and logprobs completion using the openai client.

        :raises NotImplementedError: always; per-token logprobs need a
            backend-specific server implementation.
        """
        raise NotImplementedError(
            "Tokens and logprobs not supported by base OpenAI API, use specific API servers."
        )
|
|
|
|
|
|
def resolve_openai_configs(
    default_server_configs,
    openai_config_dict,
    yaml_config,
    cli_passed_flags,
    logger,
):
    """
    Helper to resolve the final server_configs, handling single, multiple servers, and overrides.

    Resolution precedence:
      1. Multi-server list under the OpenAI namespace in YAML.
      2. A single merged APIServerConfig (defaults overridden by YAML/CLI).
      3. A plain ServerBaseline default, returned unchanged.
      4. A default multi-server list (length >= 2), returned unchanged.
      5. Fallback: a single config built from the merged settings.

    CLI overrides in the OpenAI namespace are rejected whenever multiple
    servers are configured, since a flat flag cannot target one server.

    :param default_server_configs: environment default(s); may be a
        ServerBaseline, an APIServerConfig, or a list of configs.
    :param openai_config_dict: merged default/YAML/CLI settings dict.
    :param yaml_config: parsed YAML config mapping.
    :param cli_passed_flags: flags the user actually passed on the CLI.
    :param logger: logger for progress/diagnostic messages.
    :raises FailedExecutionException: on invalid CLI/YAML combinations or
        when a configuration dict fails APIServerConfig validation.
    """
    # Imported locally to avoid a circular import at module load time.
    from atroposlib.envs.server_handling.server_manager import ServerBaseline

    # Diagnostics were previously bare print() calls tagged [RESOLVE DEBUG];
    # they now go through the logger at DEBUG level with lazy formatting.
    logger.debug(
        "resolve_openai_configs: default_server_configs type = %s",
        type(default_server_configs),
    )
    logger.debug("resolve_openai_configs: openai_config_dict = %s", openai_config_dict)

    openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
    openai_yaml_config = yaml_config.get(OPENAI_NAMESPACE, None)
    openai_cli_config = {
        k: v for k, v in cli_passed_flags.items() if k.startswith(openai_full_prefix)
    }

    logger.debug("resolve_openai_configs: openai_cli_config = %s", openai_cli_config)

    is_multi_server_yaml = (
        isinstance(openai_yaml_config, list) and len(openai_yaml_config) >= 2
    )
    is_multi_server_default = (
        (not is_multi_server_yaml)
        and isinstance(default_server_configs, list)
        and len(default_server_configs) >= 2
    )

    logger.debug(
        "resolve_openai_configs: is_multi_server_yaml=%s, is_multi_server_default=%s",
        is_multi_server_yaml,
        is_multi_server_default,
    )
    logger.debug(
        "resolve_openai_configs: isinstance(default_server_configs, ServerBaseline) = %s",
        isinstance(default_server_configs, ServerBaseline),
    )

    if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config:
        raise FailedExecutionException(
            message=f"CLI overrides for OpenAI settings (--{openai_full_prefix}*) are not supported "
            f"when multiple servers are defined (either via YAML list under '{OPENAI_NAMESPACE}' "
            "or a default list with length >= 2).",
            exit_code=2,
        )

    if is_multi_server_yaml:
        logger.debug("resolve_openai_configs: taking multi-server YAML path")
        logger.info(
            f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'."
        )
        try:
            server_configs = [APIServerConfig(**cfg) for cfg in openai_yaml_config]
        except Exception as e:
            raise FailedExecutionException(
                f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
            ) from e
    elif isinstance(default_server_configs, APIServerConfig):
        # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
        logger.debug("resolve_openai_configs: taking APIServerConfig merged path")
        logger.info(
            "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
        )
        try:
            final_openai_config = APIServerConfig(**openai_config_dict)
        except Exception as e:
            raise FailedExecutionException(
                f"Error creating final OpenAI configuration from merged settings: {e}\n"
                f"Merged Dict: {openai_config_dict}"
            ) from e
        server_configs = final_openai_config
    elif isinstance(default_server_configs, ServerBaseline):
        # Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible
        logger.debug("resolve_openai_configs: taking ServerBaseline path")
        logger.info("Using ServerBaseline configuration.")
        server_configs = default_server_configs
    elif is_multi_server_default:
        logger.debug("resolve_openai_configs: taking multi-server default path")
        logger.info("Using default multi-server configuration (length >= 2).")
        server_configs = default_server_configs
    else:
        logger.debug("resolve_openai_configs: taking single-server merged fallback path")
        logger.info(
            "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
        )
        try:
            final_openai_config = APIServerConfig(**openai_config_dict)
        except Exception as e:
            raise FailedExecutionException(
                f"Error creating final OpenAI configuration from merged settings: {e}\n"
                f"Merged Dict: {openai_config_dict}"
            ) from e

        logger.debug(
            "resolve_openai_configs: final_openai_config = %s", final_openai_config
        )
        if isinstance(default_server_configs, APIServerConfig):
            # NOTE(review): defensive only — this case is already handled by
            # the APIServerConfig branch above, so it looks unreachable here.
            server_configs = final_openai_config
        elif isinstance(default_server_configs, list):
            # A default list of length < 2 collapses to a single-entry list.
            server_configs = [final_openai_config]
        else:
            logger.warning(
                f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
                f"Proceeding with single OpenAI server configuration based on merged settings."
            )
            server_configs = [final_openai_config]

    logger.debug("resolve_openai_configs: returning server_configs = %s", server_configs)
    return server_configs