mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Add dummy openai managed server
This commit is contained in:
parent
462abbebf7
commit
10f651289c
4 changed files with 235 additions and 11 deletions
|
|
@ -4,7 +4,9 @@
|
|||
|
||||
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
|
||||
|
||||
**Server Compatibility:** ManagedServer works with all Atropos server types - `OpenAIServer`, `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
|
||||
**Server Compatibility:** ManagedServer works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
|
||||
|
||||
> **⚠️ OpenAI Endpoints:** OpenAI's API does not expose token IDs or detailed logprobs required for full ManagedServer functionality. See [OpenAI Endpoint Limitations](#openai-endpoint-limitations) for details and workarounds.
|
||||
|
||||
### Why Use ManagedServer?
|
||||
|
||||
|
|
@ -684,6 +686,74 @@ The context manager:
|
|||
# Turn 2 prompt must be: "Hello World..." (exact prefix match)
|
||||
```
|
||||
|
||||
## OpenAI Endpoint Limitations
|
||||
|
||||
OpenAI's API does not expose token IDs or detailed logprobs in the same way that vLLM, SGLang, and other self-hosted inference servers do. This means **ManagedServer cannot provide accurate token-level training data** when using OpenAI endpoints.
|
||||
|
||||
### Default Behavior
|
||||
|
||||
By default, attempting to use `managed_server()` with an `OpenAIServer` will raise a `NotImplementedError`:
|
||||
|
||||
```python
|
||||
async with self.server.managed_server() as managed:
|
||||
# Raises NotImplementedError if server is OpenAIServer
|
||||
...
|
||||
```
|
||||
|
||||
The error message will explain the limitation and how to opt in if you don't need real token data.
|
||||
|
||||
### DummyManagedServer (Opt-in)
|
||||
|
||||
If you're using OpenAI endpoints for **evaluation or testing** (not training) and don't need actual token IDs or logprobs, you can opt in to using `DummyManagedServer` by setting an environment variable:
|
||||
|
||||
```bash
|
||||
export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1
|
||||
```
|
||||
|
||||
With this flag set, `managed_server()` will return a `DummyManagedServer` that:
|
||||
- Provides the same interface as `ManagedServer`
|
||||
- Returns **fixed placeholder values** for tokens and logprobs:
|
||||
- `tokens`: `[1, 2, 3]`
|
||||
- `masked_tokens`: `[-100, 2, 3]`
|
||||
- `logprobs`: `[-0.5, -0.5, -0.5]`
|
||||
- Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
|
||||
|
||||
### When to Use DummyManagedServer
|
||||
|
||||
✅ **Appropriate uses:**
|
||||
- Testing environment logic without needing real token data
|
||||
- Evaluation workflows where you only need completion text
|
||||
- Prototyping before switching to a self-hosted inference server
|
||||
|
||||
❌ **Not appropriate for:**
|
||||
- Training (tokens and logprobs are meaningless placeholders)
|
||||
- Any workflow that depends on accurate token-level information
|
||||
|
||||
### Example
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Opt-in to dummy managed server for OpenAI
|
||||
os.environ["ATROPOS_ALLOW_DUMMY_MANAGED_SERVER"] = "1"
|
||||
|
||||
# Now this works with OpenAI endpoints
|
||||
async with self.server.managed_server() as managed:
|
||||
response = await managed.chat_completion(messages=messages, n=4)
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
# nodes contain placeholder token data - DO NOT use for training
|
||||
for node in nodes:
|
||||
print(node.full_text) # Real completion text
|
||||
print(node.tokens) # [1, 2, 3] - placeholder!
|
||||
print(node.logprobs) # [-0.5, -0.5, -0.5] - placeholder!
|
||||
```
|
||||
|
||||
### Recommendation
|
||||
|
||||
For training workloads, use a self-hosted inference server (`VLLMServer`, `SGLangServer`, or `TrlVllmServer`) that provides full token and logprob access. OpenAI endpoints are best suited for evaluation, testing, or workflows that only need completion text.
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [ManagedServer Source Code](managed_server.py)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,12 @@
|
|||
|
||||
This module provides server abstraction layers for different LLM inference backends.
|
||||
|
||||
## ManagedServer
|
||||
|
||||
For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_SERVER.md).
|
||||
|
||||
> **Note:** OpenAI endpoints do not support token IDs/logprobs required for ManagedServer. Set `ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1` to use a placeholder implementation for testing/evaluation. See [OpenAI Endpoint Limitations](MANAGED_SERVER.md#openai-endpoint-limitations) for details.
|
||||
|
||||
## Reasoning Model Support
|
||||
|
||||
The `ReasoningConfig` class enables support for reasoning/thinking models across different providers.
|
||||
|
|
|
|||
|
|
@ -511,6 +511,117 @@ class ManagedServer:
|
|||
self.current_nodes.clear()
|
||||
|
||||
|
||||
class DummyManagedServer:
    """
    A simple managed server wrapper for OpenAI endpoints that don't support token IDs/logprobs.

    Uses fixed placeholder values for tokens and logprobs. NOT suitable for training.

    Args:
        server: The underlying APIServer to forward completion calls to.
        tokenizer: Accepted for interface parity with ManagedServer but ignored —
            nothing is tokenized here.
        track_tree: When True, nodes are stored in a dict keyed by full_text
            (``self.sequences``); when False, they are appended to
            ``self.current_nodes``.
    """

    # Fixed dummy values. Each created node receives a *copy* of these lists
    # (see _create_dummy_node) so callers mutating a node cannot corrupt the
    # class-level constants or sibling nodes.
    DUMMY_TOKENS = [1, 2, 3]
    DUMMY_MASKED_TOKENS = [-100, 2, 3]
    DUMMY_LOGPROBS = [-0.5, -0.5, -0.5]

    def __init__(
        self,
        server: APIServer,
        tokenizer: Optional[Any] = None,
        track_tree: bool = False,
    ):
        self.server = server
        self.track_tree = track_tree
        # tokenizer is accepted but ignored - we don't tokenize anything

        if track_tree:
            self.sequences: Dict[str, SequenceNode] = {}
        else:
            self.current_nodes: List[SequenceNode] = []

    def _messages_to_text(self, messages: List[Dict[str, str]]) -> str:
        """Convert chat messages to a simple ``role:content`` text format."""
        return "\n\n".join([f"{m['role']}:{m['content']}" for m in messages])

    def _create_dummy_node(
        self,
        full_text: str,
        finish_reason: str = "stop",
    ) -> SequenceNode:
        """Create a sequence node with fixed dummy values.

        Copies of the placeholder lists are made per node so that mutating one
        node's ``tokens``/``masked_tokens``/``logprobs`` does not alias the
        class constants or any other node's data.
        """
        return SequenceNode(
            full_text=full_text,
            tokens=list(self.DUMMY_TOKENS),
            masked_tokens=list(self.DUMMY_MASKED_TOKENS),
            logprobs=list(self.DUMMY_LOGPROBS),
            metadata={"finish_reason": finish_reason, "dummy_tokens": True},
        )

    def _track_node(self, node: SequenceNode) -> None:
        """Record a node in the tree (keyed by full_text) or the flat list."""
        if self.track_tree:
            self.sequences[node.full_text] = node
        else:
            self.current_nodes.append(node)

    async def chat_completion(self, **kwargs) -> ChatCompletion:
        """Make a chat completion call and track each choice with dummy tokens."""
        messages = kwargs.get("messages", [])

        response = await self.server.chat_completion(**kwargs)

        for choice in response.choices:
            completion_content = choice.message.content or ""
            # Append the assistant response so full_text covers the whole
            # conversation including this completion.
            all_messages = messages + [
                {"role": "assistant", "content": completion_content}
            ]
            node = self._create_dummy_node(
                full_text=self._messages_to_text(all_messages),
                finish_reason=choice.finish_reason or "stop",
            )
            self._track_node(node)

        return response

    async def completion(self, **kwargs) -> Completion:
        """Make a text completion call and track each choice with dummy tokens."""
        prompt = kwargs.get("prompt", "")

        response = await self.server.completion(**kwargs)

        for choice in response.choices:
            completion_text = choice.text or ""
            node = self._create_dummy_node(
                full_text=f"{prompt}{completion_text}",
                finish_reason=choice.finish_reason or "stop",
            )
            self._track_node(node)

        return response

    def get_state(self) -> Dict[str, Any]:
        """Get the current state of tracked sequences.

        Returns:
            In tree mode, ``{"sequences": ..., "tree": ...}`` where both keys
            hold shallow copies of the tracking dict; otherwise
            ``{"nodes": [...]}`` with a shallow copy of the node list.
        """
        if self.track_tree:
            return {
                "sequences": self.sequences.copy(),
                "tree": self.sequences.copy(),
            }
        else:
            return {"nodes": self.current_nodes.copy()}

    def reset(self):
        """Clear all tracked sequences."""
        if self.track_tree:
            self.sequences.clear()
        else:
            self.current_nodes.clear()
|
||||
|
||||
|
||||
class ManagedServerAdapter:
|
||||
"""
|
||||
Adapter that makes ManagedServer look like AsyncOpenAI for external libraries.
|
||||
|
|
|
|||
|
|
@ -9,7 +9,10 @@ from openai.types.chat.chat_completion import ChatCompletion
|
|||
from openai.types.completion import Completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
from atroposlib.envs.server_handling.managed_server import (
|
||||
DummyManagedServer,
|
||||
ManagedServer,
|
||||
)
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
|
|
@ -361,19 +364,28 @@ class ServerManager:
|
|||
@asynccontextmanager
|
||||
async def managed_server(
|
||||
self, tokenizer=None
|
||||
) -> AsyncGenerator[ManagedServer, None]:
|
||||
) -> AsyncGenerator[Union[ManagedServer, DummyManagedServer], None]:
|
||||
"""
|
||||
Context manager that provides a ManagedServer instance.
|
||||
|
||||
The ManagedServer wraps the most available server and tracks text sequences
|
||||
with aligned tokens and logprobs. State is automatically cleared on exit.
|
||||
|
||||
For OpenAI endpoints (which don't support token IDs/logprobs), a
|
||||
DummyManagedServer is returned if the ATROPOS_ALLOW_DUMMY_MANAGED_SERVER
|
||||
environment variable is set. Otherwise, a NotImplementedError is raised.
|
||||
|
||||
Args:
|
||||
tokenizer: Optional tokenizer to use. If not provided, will attempt to
|
||||
extract from server or create from model name.
|
||||
|
||||
Yields:
|
||||
ManagedServer instance wrapping the selected server
|
||||
ManagedServer (or DummyManagedServer for OpenAI) instance wrapping
|
||||
the selected server
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If using OpenAI server without the
|
||||
ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set.
|
||||
|
||||
Example:
|
||||
async with server_manager.managed_server() as managed:
|
||||
|
|
@ -394,16 +406,41 @@ class ServerManager:
|
|||
most_available_server = i
|
||||
most_available_server_num_slots = server.sem._value
|
||||
|
||||
# Create ManagedServer wrapping the selected server
|
||||
if isinstance(self.servers[most_available_server], OpenAIServer):
|
||||
selected_server = self.servers[most_available_server]
|
||||
|
||||
# Handle OpenAI servers separately - they don't support token IDs/logprobs
|
||||
if isinstance(selected_server, OpenAIServer):
|
||||
allow_dummy = os.environ.get(
|
||||
"ATROPOS_ALLOW_DUMMY_MANAGED_SERVER", ""
|
||||
).lower() in (
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
)
|
||||
|
||||
if not allow_dummy:
|
||||
raise NotImplementedError(
|
||||
"OpenAI endpoints do not support token IDs or logprobs required for "
|
||||
"ManagedServer. If you don't need actual token-level training data and "
|
||||
"are okay with dummy placeholder values, set the environment variable:\n\n"
|
||||
" export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1\n\n"
|
||||
"WARNING: The DummyManagedServer will return placeholder token IDs and "
|
||||
"logprobs (all zeros) that are NOT suitable for training. Use only for "
|
||||
"evaluation or testing workflows."
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
"Using OpenAIServer with managed_server does not allow for state tracking"
|
||||
"Using DummyManagedServer with OpenAI endpoint. Token IDs and logprobs "
|
||||
"will be placeholder values and are NOT suitable for training."
|
||||
)
|
||||
yield self.servers[most_available_server]
|
||||
managed = DummyManagedServer(server=selected_server, tokenizer=tokenizer)
|
||||
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
managed.reset()
|
||||
else:
|
||||
managed = ManagedServer(
|
||||
server=self.servers[most_available_server], tokenizer=tokenizer
|
||||
)
|
||||
managed = ManagedServer(server=selected_server, tokenizer=tokenizer)
|
||||
|
||||
try:
|
||||
yield managed
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue