Add dummy openai managed server

This commit is contained in:
Dakota 2026-02-04 15:16:36 -06:00
parent 462abbebf7
commit 10f651289c
4 changed files with 235 additions and 11 deletions

View file

@ -9,7 +9,10 @@ from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from atroposlib.envs.server_handling.managed_server import ManagedServer
from atroposlib.envs.server_handling.managed_server import (
DummyManagedServer,
ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
@ -361,19 +364,28 @@ class ServerManager:
@asynccontextmanager
async def managed_server(
self, tokenizer=None
) -> AsyncGenerator[ManagedServer, None]:
) -> AsyncGenerator[Union[ManagedServer, DummyManagedServer], None]:
"""
Context manager that provides a ManagedServer instance.
The ManagedServer wraps the most available server and tracks text sequences
with aligned tokens and logprobs. State is automatically cleared on exit.
For OpenAI endpoints (which don't support token IDs/logprobs), a
DummyManagedServer is returned if the ATROPOS_ALLOW_DUMMY_MANAGED_SERVER
environment variable is set. Otherwise, a NotImplementedError is raised.
Args:
tokenizer: Optional tokenizer to use. If not provided, will attempt to
extract from server or create from model name.
Yields:
ManagedServer instance wrapping the selected server
ManagedServer (or DummyManagedServer for OpenAI) instance wrapping
the selected server
Raises:
NotImplementedError: If using OpenAI server without the
ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set.
Example:
async with server_manager.managed_server() as managed:
@ -394,16 +406,41 @@ class ServerManager:
most_available_server = i
most_available_server_num_slots = server.sem._value
# Create ManagedServer wrapping the selected server
if isinstance(self.servers[most_available_server], OpenAIServer):
selected_server = self.servers[most_available_server]
# Handle OpenAI servers separately - they don't support token IDs/logprobs
if isinstance(selected_server, OpenAIServer):
allow_dummy = os.environ.get(
"ATROPOS_ALLOW_DUMMY_MANAGED_SERVER", ""
).lower() in (
"1",
"true",
"yes",
)
if not allow_dummy:
raise NotImplementedError(
"OpenAI endpoints do not support token IDs or logprobs required for "
"ManagedServer. If you don't need actual token-level training data and "
"are okay with dummy placeholder values, set the environment variable:\n\n"
" export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1\n\n"
"WARNING: The DummyManagedServer will return placeholder token IDs and "
"logprobs (all zeros) that are NOT suitable for training. Use only for "
"evaluation or testing workflows."
)
warnings.warn(
"Using OpenAIServer with managed_server does not allow for state tracking"
"Using DummyManagedServer with OpenAI endpoint. Token IDs and logprobs "
"will be placeholder values and are NOT suitable for training."
)
yield self.servers[most_available_server]
managed = DummyManagedServer(server=selected_server, tokenizer=tokenizer)
try:
yield managed
finally:
managed.reset()
else:
managed = ManagedServer(
server=self.servers[most_available_server], tokenizer=tokenizer
)
managed = ManagedServer(server=selected_server, tokenizer=tokenizer)
try:
yield managed