mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
switch eval to use managed server adapter impl. moved managed server
adapter
This commit is contained in:
parent
32d12c05c3
commit
5a20abdce7
4 changed files with 253 additions and 270 deletions
|
|
@ -46,65 +46,11 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ManagedServerAdapter:
|
||||
"""
|
||||
Adapter that makes ManagedServer look like AsyncOpenAI for verifiers.
|
||||
|
||||
Implements the subset of AsyncOpenAI interface that verifiers uses:
|
||||
- client.chat.completions.create()
|
||||
- client.completions.create()
|
||||
- client.base_url
|
||||
"""
|
||||
|
||||
def __init__(self, managed_server: ManagedServer, base_url: str):
|
||||
self._managed = managed_server
|
||||
self.base_url = base_url
|
||||
self.chat = self._ChatNamespace(self._managed)
|
||||
self.completions = self._CompletionsNamespace(self._managed)
|
||||
|
||||
class _ChatNamespace:
|
||||
def __init__(self, managed: ManagedServer):
|
||||
self._managed = managed
|
||||
self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)
|
||||
|
||||
class _ChatCompletionsNamespace:
|
||||
def __init__(self, managed: ManagedServer):
|
||||
self._managed = managed
|
||||
|
||||
async def create(self, **kwargs):
|
||||
logger.info(
|
||||
"ManagedServerAdapter.chat.completions.create called with model=%s",
|
||||
kwargs.get("model"),
|
||||
)
|
||||
result = await self._managed.chat_completion(**kwargs)
|
||||
logger.info("ManagedServerAdapter.chat.completions.create completed")
|
||||
return result
|
||||
|
||||
class _CompletionsNamespace:
|
||||
def __init__(self, managed: ManagedServer):
|
||||
self._managed = managed
|
||||
|
||||
async def create(self, **kwargs):
|
||||
return await self._managed.completion(**kwargs)
|
||||
|
||||
async def post(self, path: str, body: dict, cast_to: type):
|
||||
raise NotImplementedError(
|
||||
f"ManagedServerAdapter does not support post() for path '{path}'. "
|
||||
"This is used for vLLM interleaved rollouts. Use standard chat completions."
|
||||
)
|
||||
|
||||
def copy(self, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"ManagedServerAdapter does not support copy(). "
|
||||
"This is used for vLLM tokenization endpoints."
|
||||
)
|
||||
|
||||
|
||||
class VfEnvConfig(BaseEnvConfig):
|
||||
vf_env_name: str = ""
|
||||
env_args: str = "{}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue