mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
Add dummy openai managed server
This commit is contained in:
parent
462abbebf7
commit
10f651289c
4 changed files with 235 additions and 11 deletions
|
|
@ -9,7 +9,10 @@ from openai.types.chat.chat_completion import ChatCompletion
|
|||
from openai.types.completion import Completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
from atroposlib.envs.server_handling.managed_server import (
|
||||
DummyManagedServer,
|
||||
ManagedServer,
|
||||
)
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
|
|
@ -361,19 +364,28 @@ class ServerManager:
|
|||
@asynccontextmanager
|
||||
async def managed_server(
|
||||
self, tokenizer=None
|
||||
) -> AsyncGenerator[ManagedServer, None]:
|
||||
) -> AsyncGenerator[Union[ManagedServer, DummyManagedServer], None]:
|
||||
"""
|
||||
Context manager that provides a ManagedServer instance.
|
||||
|
||||
The ManagedServer wraps the most available server and tracks text sequences
|
||||
with aligned tokens and logprobs. State is automatically cleared on exit.
|
||||
|
||||
For OpenAI endpoints (which don't support token IDs/logprobs), a
|
||||
DummyManagedServer is returned if the ATROPOS_ALLOW_DUMMY_MANAGED_SERVER
|
||||
environment variable is set. Otherwise, a NotImplementedError is raised.
|
||||
|
||||
Args:
|
||||
tokenizer: Optional tokenizer to use. If not provided, will attempt to
|
||||
extract from server or create from model name.
|
||||
|
||||
Yields:
|
||||
ManagedServer instance wrapping the selected server
|
||||
ManagedServer (or DummyManagedServer for OpenAI) instance wrapping
|
||||
the selected server
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If using OpenAI server without the
|
||||
ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set.
|
||||
|
||||
Example:
|
||||
async with server_manager.managed_server() as managed:
|
||||
|
|
@ -394,16 +406,41 @@ class ServerManager:
|
|||
most_available_server = i
|
||||
most_available_server_num_slots = server.sem._value
|
||||
|
||||
# Create ManagedServer wrapping the selected server
|
||||
if isinstance(self.servers[most_available_server], OpenAIServer):
|
||||
selected_server = self.servers[most_available_server]
|
||||
|
||||
# Handle OpenAI servers separately - they don't support token IDs/logprobs
|
||||
if isinstance(selected_server, OpenAIServer):
|
||||
allow_dummy = os.environ.get(
|
||||
"ATROPOS_ALLOW_DUMMY_MANAGED_SERVER", ""
|
||||
).lower() in (
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
)
|
||||
|
||||
if not allow_dummy:
|
||||
raise NotImplementedError(
|
||||
"OpenAI endpoints do not support token IDs or logprobs required for "
|
||||
"ManagedServer. If you don't need actual token-level training data and "
|
||||
"are okay with dummy placeholder values, set the environment variable:\n\n"
|
||||
" export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1\n\n"
|
||||
"WARNING: The DummyManagedServer will return placeholder token IDs and "
|
||||
"logprobs (all zeros) that are NOT suitable for training. Use only for "
|
||||
"evaluation or testing workflows."
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
"Using OpenAIServer with managed_server does not allow for state tracking"
|
||||
"Using DummyManagedServer with OpenAI endpoint. Token IDs and logprobs "
|
||||
"will be placeholder values and are NOT suitable for training."
|
||||
)
|
||||
yield self.servers[most_available_server]
|
||||
managed = DummyManagedServer(server=selected_server, tokenizer=tokenizer)
|
||||
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
managed.reset()
|
||||
else:
|
||||
managed = ManagedServer(
|
||||
server=self.servers[most_available_server], tokenizer=tokenizer
|
||||
)
|
||||
managed = ManagedServer(server=selected_server, tokenizer=tokenizer)
|
||||
|
||||
try:
|
||||
yield managed
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue