Add dummy openai managed server

This commit is contained in:
Dakota 2026-02-04 15:16:36 -06:00
parent 462abbebf7
commit 10f651289c
4 changed files with 235 additions and 11 deletions

View file

@ -4,7 +4,9 @@
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
**Server Compatibility:** ManagedServer works with all Atropos server types - `OpenAIServer`, `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
**Server Compatibility:** ManagedServer works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
> **⚠️ OpenAI Endpoints:** OpenAI's API does not expose token IDs or detailed logprobs required for full ManagedServer functionality. See [OpenAI Endpoint Limitations](#openai-endpoint-limitations) for details and workarounds.
### Why Use ManagedServer?
@ -684,6 +686,74 @@ The context manager:
# Turn 2 prompt must be: "Hello World..." (exact prefix match)
```
## OpenAI Endpoint Limitations
OpenAI's API does not expose token IDs or detailed logprobs in the same way that vLLM, SGLang, and other self-hosted inference servers do. This means **ManagedServer cannot provide accurate token-level training data** when using OpenAI endpoints.
### Default Behavior
By default, attempting to use `managed_server()` with an `OpenAIServer` will raise a `NotImplementedError`:
```python
async with self.server.managed_server() as managed:
# Raises NotImplementedError if server is OpenAIServer
...
```
The error message explains the limitation and how to opt in if you don't need real token data.
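If you want an environment to degrade gracefully rather than crash, the error can be caught. Below is a sketch using a hypothetical stand-in context manager (the real check lives inside `ServerManager.managed_server()`):

```python
import asyncio
from contextlib import asynccontextmanager

# Hypothetical stand-in mimicking the default behavior with an
# OpenAIServer: entering the context raises NotImplementedError.
@asynccontextmanager
async def managed_server():
    raise NotImplementedError(
        "OpenAI endpoints do not support token IDs or logprobs required for ManagedServer."
    )
    yield  # never reached

async def run_episode():
    try:
        async with managed_server() as managed:
            ...  # normal token-tracked rollout would go here
        return "tracked"
    except NotImplementedError:
        # Fall back to plain completions (text only, no token data)
        return "text-only fallback"

result = asyncio.run(run_episode())
print(result)
```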
### DummyManagedServer (Opt-in)
If you're using OpenAI endpoints for **evaluation or testing** (not training) and don't need actual token IDs or logprobs, you can opt in to `DummyManagedServer` by setting an environment variable:
```bash
export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1
```
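Per the check added in this commit, the flag is lowercased and compared against a small set of truthy strings (`"1"`, `"true"`, `"yes"`); any other value, including `"0"`, leaves the opt-in disabled:

```python
import os

def allow_dummy() -> bool:
    # Mirrors the opt-in check in ServerManager.managed_server()
    return os.environ.get(
        "ATROPOS_ALLOW_DUMMY_MANAGED_SERVER", ""
    ).lower() in ("1", "true", "yes")

os.environ["ATROPOS_ALLOW_DUMMY_MANAGED_SERVER"] = "TRUE"
print(allow_dummy())  # True (comparison is case-insensitive)

os.environ["ATROPOS_ALLOW_DUMMY_MANAGED_SERVER"] = "0"
print(allow_dummy())  # False -- "0" is not an accepted value
```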
With this flag set, `managed_server()` will return a `DummyManagedServer` that:
- Provides the same interface as `ManagedServer`
- Returns **fixed placeholder values** for tokens and logprobs:
- `tokens`: `[1, 2, 3]`
- `masked_tokens`: `[-100, 2, 3]`
- `logprobs`: `[-0.5, -0.5, -0.5]`
- Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
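The placeholder behavior can be modeled in a few lines of plain Python (a simplified sketch of the values and formatting described above, not the actual implementation):

```python
# Fixed placeholder values used by DummyManagedServer (per this commit)
DUMMY_TOKENS = [1, 2, 3]
DUMMY_MASKED_TOKENS = [-100, 2, 3]
DUMMY_LOGPROBS = [-0.5, -0.5, -0.5]

def messages_to_text(messages):
    # "role:content" entries joined by blank lines, matching _messages_to_text
    return "\n\n".join(f"{m['role']}:{m['content']}" for m in messages)

msgs = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
print(messages_to_text(msgs))  # user:Hello\n\nassistant:Hi there
```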
### When to Use DummyManagedServer
✅ **Appropriate uses:**
- Testing environment logic without needing real token data
- Evaluation workflows where you only need completion text
- Prototyping before switching to a self-hosted inference server
❌ **Not appropriate for:**
- Training (tokens and logprobs are meaningless placeholders)
- Any workflow that depends on accurate token-level information
### Example
```python
import os
# Opt-in to dummy managed server for OpenAI
os.environ["ATROPOS_ALLOW_DUMMY_MANAGED_SERVER"] = "1"
# Now this works with OpenAI endpoints
async with self.server.managed_server() as managed:
response = await managed.chat_completion(messages=messages, n=4)
state = managed.get_state()
nodes = state["nodes"]
# nodes contain placeholder token data - DO NOT use for training
for node in nodes:
print(node.full_text) # Real completion text
print(node.tokens) # [1, 2, 3] - placeholder!
print(node.logprobs) # [-0.5, -0.5, -0.5] - placeholder!
```
### Recommendation
For training workloads, use a self-hosted inference server (`VLLMServer`, `SGLangServer`, or `TrlVllmServer`) that provides full token and logprob access. OpenAI endpoints are best suited for evaluation, testing, or workflows that only need completion text.
## Additional Resources
- [ManagedServer Source Code](managed_server.py)

View file

@ -2,6 +2,12 @@
This module provides server abstraction layers for different LLM inference backends.
## ManagedServer
For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_SERVER.md).
> **Note:** OpenAI endpoints do not support token IDs/logprobs required for ManagedServer. Set `ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1` to use a placeholder implementation for testing/evaluation. See [OpenAI Endpoint Limitations](MANAGED_SERVER.md#openai-endpoint-limitations) for details.
## Reasoning Model Support
The `ReasoningConfig` class enables support for reasoning/thinking models across different providers.

View file

@ -511,6 +511,117 @@ class ManagedServer:
self.current_nodes.clear()
class DummyManagedServer:
"""
A simple managed server wrapper for OpenAI endpoints that don't support token IDs/logprobs.
Uses fixed placeholder values for tokens and logprobs. NOT suitable for training.
"""
# Fixed dummy values
DUMMY_TOKENS = [1, 2, 3]
DUMMY_MASKED_TOKENS = [-100, 2, 3]
DUMMY_LOGPROBS = [-0.5, -0.5, -0.5]
def __init__(
self,
server: APIServer,
tokenizer: Optional[Any] = None,
track_tree: bool = False,
):
self.server = server
self.track_tree = track_tree
# tokenizer is accepted but ignored - we don't tokenize anything
if track_tree:
self.sequences: Dict[str, SequenceNode] = {}
else:
self.current_nodes: List[SequenceNode] = []
def _messages_to_text(self, messages: List[Dict[str, str]]) -> str:
"""Convert messages to simple text format."""
return "\n\n".join([f"{m['role']}:{m['content']}" for m in messages])
def _create_dummy_node(
self,
full_text: str,
finish_reason: str = "stop",
) -> SequenceNode:
"""Create a sequence node with fixed dummy values."""
return SequenceNode(
full_text=full_text,
tokens=self.DUMMY_TOKENS,
masked_tokens=self.DUMMY_MASKED_TOKENS,
logprobs=self.DUMMY_LOGPROBS,
metadata={"finish_reason": finish_reason, "dummy_tokens": True},
)
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""Make a chat completion call and track with dummy tokens."""
messages = kwargs.get("messages", [])
response = await self.server.chat_completion(**kwargs)
for choice in response.choices:
completion_content = choice.message.content or ""
# Append assistant response to messages for full_text
all_messages = messages + [
{"role": "assistant", "content": completion_content}
]
full_text = self._messages_to_text(all_messages)
node = self._create_dummy_node(
full_text=full_text,
finish_reason=choice.finish_reason or "stop",
)
if self.track_tree:
self.sequences[node.full_text] = node
else:
self.current_nodes.append(node)
return response
async def completion(self, **kwargs) -> Completion:
"""Make a completion call and track with dummy tokens."""
prompt = kwargs.get("prompt", "")
response = await self.server.completion(**kwargs)
for choice in response.choices:
completion_text = choice.text or ""
full_text = f"{prompt}{completion_text}"
node = self._create_dummy_node(
full_text=full_text,
finish_reason=choice.finish_reason or "stop",
)
if self.track_tree:
self.sequences[node.full_text] = node
else:
self.current_nodes.append(node)
return response
def get_state(self) -> Dict[str, Any]:
"""Get the current state of tracked sequences."""
if self.track_tree:
return {
"sequences": self.sequences.copy(),
"tree": self.sequences.copy(),
}
else:
return {"nodes": self.current_nodes.copy()}
def reset(self):
"""Clear all tracked sequences."""
if self.track_tree:
self.sequences.clear()
else:
self.current_nodes.clear()
class ManagedServerAdapter:
"""
Adapter that makes ManagedServer look like AsyncOpenAI for external libraries.

View file

@ -9,7 +9,10 @@ from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from atroposlib.envs.server_handling.managed_server import ManagedServer
from atroposlib.envs.server_handling.managed_server import (
DummyManagedServer,
ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
@ -361,19 +364,28 @@ class ServerManager:
@asynccontextmanager
async def managed_server(
self, tokenizer=None
) -> AsyncGenerator[ManagedServer, None]:
) -> AsyncGenerator[Union[ManagedServer, DummyManagedServer], None]:
"""
Context manager that provides a ManagedServer instance.
The ManagedServer wraps the server with the most free request slots and tracks
text sequences with aligned tokens and logprobs. State is automatically cleared on exit.
For OpenAI endpoints (which don't support token IDs/logprobs), a
DummyManagedServer is returned if the ATROPOS_ALLOW_DUMMY_MANAGED_SERVER
environment variable is set. Otherwise, a NotImplementedError is raised.
Args:
tokenizer: Optional tokenizer to use. If not provided, will attempt to
extract from server or create from model name.
Yields:
ManagedServer instance wrapping the selected server
ManagedServer (or DummyManagedServer for OpenAI) instance wrapping
the selected server
Raises:
NotImplementedError: If using OpenAI server without the
ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set.
Example:
async with server_manager.managed_server() as managed:
@ -394,16 +406,41 @@ class ServerManager:
most_available_server = i
most_available_server_num_slots = server.sem._value
# Create ManagedServer wrapping the selected server
if isinstance(self.servers[most_available_server], OpenAIServer):
selected_server = self.servers[most_available_server]
# Handle OpenAI servers separately - they don't support token IDs/logprobs
if isinstance(selected_server, OpenAIServer):
allow_dummy = os.environ.get(
"ATROPOS_ALLOW_DUMMY_MANAGED_SERVER", ""
).lower() in (
"1",
"true",
"yes",
)
if not allow_dummy:
raise NotImplementedError(
"OpenAI endpoints do not support token IDs or logprobs required for "
"ManagedServer. If you don't need actual token-level training data and "
"are okay with dummy placeholder values, set the environment variable:\n\n"
" export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1\n\n"
"WARNING: The DummyManagedServer will return placeholder token IDs and "
"logprobs (fixed dummy values) that are NOT suitable for training. Use only for "
"evaluation or testing workflows."
)
warnings.warn(
"Using OpenAIServer with managed_server does not allow for state tracking"
"Using DummyManagedServer with OpenAI endpoint. Token IDs and logprobs "
"will be placeholder values and are NOT suitable for training."
)
yield self.servers[most_available_server]
managed = DummyManagedServer(server=selected_server, tokenizer=tokenizer)
try:
yield managed
finally:
managed.reset()
else:
managed = ManagedServer(
server=self.servers[most_available_server], tokenizer=tokenizer
)
managed = ManagedServer(server=selected_server, tokenizer=tokenizer)
try:
yield managed