Add a managed server to make grabbing logprobs easier with tokenized items

This commit is contained in:
dmahan93 2025-10-24 13:09:46 -07:00
parent 312f8859e3
commit 7bf4cfbf80
6 changed files with 1138 additions and 29 deletions

View file

@ -8,6 +8,7 @@ from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from atroposlib.envs.server_handling.managed_server import ManagedServer
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
@ -308,3 +309,50 @@ class ServerManager:
yield self.servers[most_available_server]
finally:
pass
@asynccontextmanager
async def managed_server(
    self, tokenizer=None
) -> AsyncGenerator[ManagedServer, None]:
    """
    Yield a ManagedServer bound to the least-loaded healthy server.

    Every entry of ``self.servers`` is inspected; unhealthy entries are
    ignored and, among the rest, the first one with the highest number of
    free semaphore slots is chosen. When no server is healthy the entry at
    index 0 is used as a fallback. The chosen server is wrapped in a new
    ManagedServer, and ``reset()`` is called on that wrapper when the
    ``async with`` block exits, whether normally or through an exception.

    Args:
        tokenizer: Optional tokenizer, passed through unchanged to the
            ManagedServer constructor. NOTE(review): how ManagedServer
            behaves when this is None is decided in ManagedServer itself,
            not here -- confirm there.

    Yields:
        The ManagedServer instance wrapping the selected server.

    Example:
        async with server_manager.managed_server() as managed:
            response = await managed.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                n=2
            )
            state = managed.get_state()
            # Process state...
        # managed.reset() has been called once the block is left
    """
    best_index = 0
    best_slots = -1
    for index, candidate in enumerate(self.servers):
        # `sem._value` is the semaphore's private counter; presumably the
        # number of currently free request slots -- TODO confirm.
        if candidate.server_healthy and candidate.sem._value > best_slots:
            best_index = index
            best_slots = candidate.sem._value

    wrapper = ManagedServer(
        server=self.servers[best_index], tokenizer=tokenizer
    )
    try:
        yield wrapper
    finally:
        # Always reset the wrapper on exit so its tracked state does not
        # outlive the context.
        wrapper.reset()