Merge pull request #406 from NousResearch/logprobsfn

Unified get_logprobs interface across the server stack
This commit is contained in:
J-SUPHA 2026-03-05 17:36:22 -05:00 committed by GitHub
commit 1f676f2185
9 changed files with 570 additions and 8 deletions

View file

@@ -2,7 +2,7 @@
## Overview
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It also exposes a normalized `get_logprobs(...)` API for backend-agnostic logprob access. This eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
**Server Compatibility:** ManagedServer works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
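For example, a minimal config sketch (hedged; the model name and URL are illustrative, and `server_type` is what selects the backend):

```python
config = APIServerConfig(
    model_name="my-model",                # illustrative
    base_url="http://localhost:9001/v1",  # illustrative
    server_type="vllm",                   # or "sglang" / "trl"
)
```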
@@ -36,7 +36,8 @@ async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
- ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions
- ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0
- ✅ **Perfect Alignment**: Tokens and logprobs are guaranteed to align correctly
- ✅ **Perfect Alignment**: Tokens and logprobs align positionally for tracked sequences
- ✅ **Normalized Alignment Contract**: Tokens/logprobs are shape-normalized for downstream consumers
- ✅ **Multi-turn Support**: Automatically handles conversation extensions
- ✅ **Branching Support**: Handles n>1 completions naturally
- ✅ **Clean API**: Simple context manager pattern
@@ -609,6 +610,28 @@ Intercept completion call and track sequences.
**Side Effects:**
- Tracks sequences in internal storage
#### `async def get_logprobs(**kwargs) -> Dict[str, Any]`
Fetch prompt logprobs with a normalized, backend-agnostic schema.
**Args (common):**
- `messages`, `prompt`, or `input_ids`: the input to score (provide one)
- `n`: Number of sampled sequences
- `max_tokens`
- Optional backend kwargs such as `top_k` / `top_logprobs`, `temperature`, `stop`
**Returns (normalized):**
```python
{
"prompt_tokens": List[int],
"prompt_topk_token_ids": List[List[int]], # [pos][k]
"prompt_topk_logprobs": List[List[float]], # [pos][k]
}
```
**Notes:**
- Strict mode: backend must provide real prompt top-k arrays.
- Missing keys should be treated as backend contract violations.
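A minimal usage sketch (hedged; assumes a backend with real prompt-logprob support and a `managed` handle from the context manager shown earlier):

```python
# Illustrative: top_k=2 requests two candidates per prompt position.
payload = await managed.get_logprobs(prompt="Hello world", top_k=2)

tokens = payload["prompt_tokens"]            # List[int], one per position
topk_ids = payload["prompt_topk_token_ids"]  # [pos][k] token ids
topk_lps = payload["prompt_topk_logprobs"]   # [pos][k] logprobs

# All three arrays are aligned by prompt position.
assert len(tokens) == len(topk_ids) == len(topk_lps)
```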
#### `def get_state() -> Dict[str, Any]`
Get the current state of tracked sequences.
@@ -712,11 +735,9 @@ export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1
With this flag set, `managed_server()` will return a `DummyManagedServer` that:
- Provides the same interface as `ManagedServer`
- Returns **fixed placeholder values** for tokens and logprobs:
- `tokens`: `[1, 2, 3]`
- `masked_tokens`: `[-100, 2, 3]`
- `logprobs`: `[-0.5, -0.5, -0.5]`
- Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays)
- Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
- Raises for `get_logprobs(...)` in strict mode (no fake prompt-logprob payload)
### When to Use DummyManagedServer
@@ -746,8 +767,11 @@ async with self.server.managed_server() as managed:
# nodes contain placeholder token data - DO NOT use for training
for node in nodes:
print(node.full_text) # Real completion text
print(node.tokens) # [1, 2, 3] - placeholder!
print(node.logprobs) # [-0.5, -0.5, -0.5] - placeholder!
print(node.tokens[:5]) # placeholder values
print(node.logprobs[:5]) # placeholder values
# Strict mode: get_logprobs is not available on DummyManagedServer
# and will raise NotImplementedError.
```
### Recommendation

View file

@@ -8,6 +8,16 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_
> **Note:** OpenAI endpoints do not support token IDs/logprobs required for ManagedServer. Set `ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1` to use a placeholder implementation for testing/evaluation. See [OpenAI Endpoint Limitations](MANAGED_SERVER.md#openai-endpoint-limitations) for details.
### Normalized `get_logprobs` API
`ManagedServer` and supported server backends expose a normalized `get_logprobs(...)` interface so callers can consume a single schema:
- `prompt_tokens`
- `prompt_topk_token_ids`
- `prompt_topk_logprobs`
Backends are expected to return real prompt top-k arrays (`[pos][k]`) matching this schema.
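A consumption sketch (hedged; assumes each position's top-k arrays include the actual prompt token, as the contract expects, and tolerates positions with empty arrays, such as a first position without logprobs):

```python
from typing import Dict, List, Optional


def chosen_token_logprobs(payload: Dict) -> List[Optional[float]]:
    """Recover each prompt token's own logprob from the top-k arrays.

    Yields None at positions whose arrays omit the token (e.g., an
    empty first position) instead of raising.
    """
    out: List[Optional[float]] = []
    for tok, ids, lps in zip(
        payload["prompt_tokens"],
        payload["prompt_topk_token_ids"],
        payload["prompt_topk_logprobs"],
    ):
        out.append(lps[ids.index(tok)] if tok in ids else None)
    return out
```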
## Reasoning Model Support
The `ReasoningConfig` class enables support for reasoning/thinking models across different providers.

View file

@@ -529,6 +529,39 @@ class ManagedServer:
else:
self.current_nodes.clear()
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
"""
Fetch prompt logprobs via wrapped server with a normalized schema.
Supported inputs:
- prompt
- messages (converted to prompt)
- input_ids
Returns:
Dict with:
- prompt_tokens
- prompt_topk_token_ids
- prompt_topk_logprobs
"""
request_kwargs = kwargs.copy()
messages = request_kwargs.pop("messages", None)
if messages is not None:
prompt = self._convert_messages_to_prompt(messages)
request_kwargs["prompt"] = prompt
else:
prompt = request_kwargs.get("prompt")
if not hasattr(self.server, "get_logprobs"):
raise NotImplementedError(
f"{self.server.__class__.__name__} does not implement get_logprobs. "
"Strict mode requires backend prompt logprobs."
)
payload = await self.server.get_logprobs(**request_kwargs)
return payload
class DummyManagedServer:
"""
@@ -640,6 +673,15 @@ class DummyManagedServer:
else:
self.current_nodes.clear()
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
"""
Dummy managed server does not provide real prompt logprobs.
"""
raise NotImplementedError(
"DummyManagedServer does not support get_logprobs in strict mode. "
"Use a backend with real prompt logprob support."
)
class ManagedServerAdapter:
"""

View file

@@ -421,6 +421,15 @@ class APIServer(ABC):
"""
pass
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
"""
Wrapper for prompt logprobs. Can be overridden by child classes.
Returns a dict containing prompt_tokens, prompt_topk_token_ids, prompt_topk_logprobs.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement _get_logprobs_wrapper."
)
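# Subclasses supply the backend-specific implementation (e.g., the
# VLLMServer override in this PR); the retry/stat wrappers below funnel
# every get_logprobs call through this hook.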
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
@@ -638,3 +647,72 @@ class APIServer(ABC):
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _logprobs(self, stat_dict, **kwargs) -> Dict[str, Any]:
"""
Simple retry and stat collection wrapper for get_logprobs.
"""
while not self.server_healthy:
await asyncio.sleep(1)
async with self.sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
payload = await self._get_logprobs_wrapper(**kwargs)
stat_dict["end"] = time.time()
return payload
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _logprobs_eval(self, stat_dict, **kwargs) -> Dict[str, Any]:
"""
Simple retry and stat collection wrapper for eval-split get_logprobs.
"""
while not self.server_healthy:
await asyncio.sleep(1)
async with self.eval_sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
payload = await self._get_logprobs_wrapper(**kwargs)
stat_dict["end"] = time.time()
return payload
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
"""
Prompt-logprob API with strict normalized output schema.
Returns:
Dict with:
- prompt_tokens: List[int]
- prompt_topk_token_ids: List[List[int]]
- prompt_topk_logprobs: List[List[float]]
"""
if not self.initialized:
if self.config.health_check:
if self.config.base_url is not None:
self.check_task = asyncio.create_task(
self.check_server_status_task(chat_completion=False)
)
else:
self.server_healthy = True
else:
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
stat_dict = {"attempts": 0}
if split == "train":
payload = await self._logprobs(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
payload = await self._logprobs_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return payload
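# Illustrative usage (mirrors the tests): split="train" records into
# request_timings/attempts_list, while split="eval" records into the
# eval_* buckets instead:
#     payload = await server.get_logprobs(prompt="Hello", split="eval")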

View file

@@ -346,6 +346,32 @@ class ServerManager:
**kwargs
)
async def get_logprobs(self, **kwargs) -> dict:
"""
Route normalized prompt-logprob requests to the most available server.
Returns a normalized dict with:
- prompt_tokens
- prompt_topk_token_ids
- prompt_topk_logprobs
"""
is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if (
server.sem._value if is_train else server.eval_sem._value
) > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
return await self.servers[most_available_server].get_logprobs(**kwargs)
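# Illustrative behavior (assumed names): a call such as
#     payload = await server_manager.get_logprobs(prompt="Hi", split="eval")
# goes to the healthy server whose eval semaphore has the most free slots;
# on ties, the earliest such server in self.servers wins.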
@asynccontextmanager
async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
most_available_server = 0

View file

@@ -3,6 +3,7 @@
import asyncio
import warnings
from typing import Any, Dict, List, Tuple
import aiohttp
import openai
@@ -231,6 +232,129 @@ class VLLMServer(APIServer):
finish_reasons_list,
)
@staticmethod
def _normalize_topk_entry(
token_logprobs_entry: Any,
) -> Tuple[List[int], List[float]]:
"""
Normalize a single token-position logprob payload into parallel top-k arrays.
Supports common structures from vLLM responses:
- dict: {token_id: logprob, ...}
- list[dict]: [{token_id: logprob}, ...]
"""
if isinstance(token_logprobs_entry, dict):
items = list(token_logprobs_entry.items())
return [int(k) for k, _ in items], [float(v) for _, v in items]
if isinstance(token_logprobs_entry, list):
token_ids: List[int] = []
logprobs: List[float] = []
for item in token_logprobs_entry:
if not isinstance(item, dict):
continue
for key, value in item.items():
token_ids.append(int(key))
logprobs.append(float(value))
return token_ids, logprobs
return [], []
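# Illustrative normalizations (not from the source):
#   {11: -0.3, 7: -1.2}          -> ([11, 7], [-0.3, -1.2])
#   [{"11": -0.3}, {"7": -1.2}]  -> ([11, 7], [-0.3, -1.2])
#   None or any other shape      -> ([], [])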
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
"""
Fetch normalized prompt logprobs from vLLM /generate with optional top-k.
Args:
top_k / top_logprobs: Optional number of logprobs per position.
Defaults to 1.
prompt or input_ids: Input text or token IDs.
Returns:
Normalized dict:
- prompt_tokens
- prompt_topk_token_ids
- prompt_topk_logprobs
"""
assert (
kwargs.get("prompt", None) is not None
or kwargs.get("input_ids", None) is not None
), "Prompt or input_ids is required for get_logprobs!"
# Accept either alias; the inner pop evaluates eagerly, so both keys are
# removed from kwargs before the rest is forwarded to the backend.
top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
top_k = max(1, top_k)
# Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
from_prompt_text = False
if "input_ids" in kwargs:
prompt_tokens = kwargs.pop("input_ids")
kwargs.pop("prompt", None)
else:
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
from_prompt_text = True
# Only normalize BOS for tokenizer-encoded prompt text.
if (
from_prompt_text
and len(prompt_tokens) >= 2
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
):
prompt_tokens = prompt_tokens[1:]
if "max_new_tokens" in kwargs:
kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
if "max_completion_tokens" in kwargs:
kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
kwargs.pop("model", None)
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}}
request_data["prompt_logprobs"] = top_k
request_data.update(kwargs)
# This API is prompt-logprobs focused, not generation-focused.
request_data["n"] = 1
request_data["temperature"] = 0.0
request_data["top_p"] = 1.0
request_data.setdefault("max_tokens", 1)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.config.base_url.replace('/v1', '')}/generate",
json=request_data,
headers=(
{"Authorization": f"Bearer {self.config.api_key}"}
if self.config.api_key
else {}
),
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
) as response:
response.raise_for_status()
results = await response.json()
raw_prompt_logprobs = results.get("prompt_logprobs")
if raw_prompt_logprobs is None:
raise ValueError(
"vLLM /generate response missing 'prompt_logprobs'. "
"Ensure backend supports prompt logprobs."
)
# Handle either direct [position] payloads or [sequence][position] payloads.
if raw_prompt_logprobs and isinstance(raw_prompt_logprobs[0], list):
prompt_entries = raw_prompt_logprobs[0]
else:
prompt_entries = raw_prompt_logprobs
prompt_topk_token_ids: List[List[int]] = []
prompt_topk_logprobs: List[List[float]] = []
for entry in prompt_entries:
topk_ids, topk_lps = self._normalize_topk_entry(entry)
prompt_topk_token_ids.append(topk_ids)
prompt_topk_logprobs.append(topk_lps)
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": prompt_topk_token_ids,
"prompt_topk_logprobs": prompt_topk_logprobs,
}
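# Wire-format sketch derived from the request built above (illustrative):
#   POST {base_url with /v1 stripped}/generate
#   {"prompt": {"prompt_token_ids": [1, 2, 3]}, "prompt_logprobs": 2,
#    "n": 1, "temperature": 0.0, "top_p": 1.0, "max_tokens": 1}
# The response must carry "prompt_logprobs", either per-position or nested
# per-sequence; both shapes are handled above.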
def resolve_openai_configs(
default_server_configs,

View file

@@ -265,6 +265,91 @@ async def test_bos_token_handling(mock_server):
assert mock_server.tokenizer.bos_token_id not in node.tokens[1:]
@pytest.mark.asyncio
async def test_get_logprobs_normalized_schema(mock_server):
"""ManagedServer.get_logprobs returns normalized prompt schema."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
prompt = "Hello"
prompt_tokens = mock_server.tokenizer.encode(prompt)
prompt_topk_token_ids = [[t, t + 1] for t in prompt_tokens]
prompt_topk_logprobs = [[-0.1, -0.2] for _ in prompt_tokens]
async def _mock_get_logprobs(**kwargs):
assert kwargs.get("prompt") == prompt
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": prompt_topk_token_ids,
"prompt_topk_logprobs": prompt_topk_logprobs,
}
mock_server.get_logprobs = _mock_get_logprobs
payload = await managed.get_logprobs(prompt=prompt, n=1)
assert payload["prompt_tokens"] == prompt_tokens
assert payload["prompt_topk_token_ids"] == prompt_topk_token_ids
assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
@pytest.mark.asyncio
async def test_get_logprobs_messages_passthrough(mock_server):
"""ManagedServer.get_logprobs converts messages and passes prompt through."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
messages = [{"role": "user", "content": "Hello"}]
expected_prompt = managed._convert_messages_to_prompt(messages)
prompt_tokens = mock_server.tokenizer.encode(expected_prompt)
async def _mock_get_logprobs(**kwargs):
assert kwargs.get("prompt") == expected_prompt
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": [[t] for t in prompt_tokens],
"prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
}
mock_server.get_logprobs = _mock_get_logprobs
payload = await managed.get_logprobs(messages=messages, top_k=1)
assert payload["prompt_tokens"] == prompt_tokens
assert len(payload["prompt_topk_token_ids"]) == len(prompt_tokens)
assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens)
@pytest.mark.asyncio
async def test_get_logprobs_input_ids_only_passthrough(mock_server):
"""ManagedServer.get_logprobs supports input_ids-only without requiring prompt."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
input_ids = [10, 20, 30]
async def _mock_get_logprobs(**kwargs):
assert "input_ids" in kwargs
assert kwargs["input_ids"] == input_ids
assert kwargs.get("prompt") is None
return {
"prompt_tokens": input_ids,
"prompt_topk_token_ids": [[t] for t in input_ids],
"prompt_topk_logprobs": [[-0.1] for _ in input_ids],
}
mock_server.get_logprobs = _mock_get_logprobs
payload = await managed.get_logprobs(input_ids=input_ids, top_k=1)
assert payload["prompt_tokens"] == input_ids
assert payload["prompt_topk_token_ids"] == [[10], [20], [30]]
assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]]
@pytest.mark.asyncio
async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server):
"""ManagedServer.get_logprobs requires backend get_logprobs in strict mode."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
prompt = "Hello"
with pytest.raises(NotImplementedError, match="does not implement get_logprobs"):
await managed.get_logprobs(prompt=prompt, n=1)
@pytest.mark.asyncio
async def test_reset_clears_sequences(mock_server):
"""Test that reset() clears all tracked sequences."""

View file

@@ -0,0 +1,105 @@
"""Tests for get_logprobs wrappers and server-manager routing."""
import pytest
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
AsyncSemWithAdaptiveWeight,
)
from atroposlib.envs.server_handling.server_manager import ServerManager
class _FakeAPIServer(APIServer):
def __init__(self, config: APIServerConfig):
super().__init__(config=config, reasoning_config=None)
self.calls = 0
self.last_kwargs = None
async def check_server_status_task(self, chat_completion: bool = True):
self.server_healthy = True
async def _chat_completion_wrapper(self, **kwargs):
raise NotImplementedError
async def _completion_wrapper(self, **kwargs):
raise NotImplementedError
async def _tokens_and_logprobs_completion_wrapper(self, **kwargs):
raise NotImplementedError
async def _get_logprobs_wrapper(self, **kwargs):
self.calls += 1
self.last_kwargs = kwargs
prompt = kwargs.get("prompt", "")
prompt_tokens = [ord(c) for c in prompt]
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": [[t] for t in prompt_tokens],
"prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
}
class _FakeRoutedServer:
def __init__(
self, name: str, train_slots: int, eval_slots: int, healthy: bool = True
):
self.name = name
self.server_healthy = healthy
self.sem = AsyncSemWithAdaptiveWeight(4)
self.eval_sem = AsyncSemWithAdaptiveWeight(4)
self.sem._value = train_slots
self.eval_sem._value = eval_slots
self.calls = 0
async def get_logprobs(self, **kwargs):
self.calls += 1
return {
"server": self.name,
"prompt_tokens": [1],
"prompt_topk_token_ids": [[1]],
"prompt_topk_logprobs": [[-0.1]],
}
@pytest.mark.asyncio
async def test_apiserver_get_logprobs_train_eval_wrappers():
cfg = APIServerConfig(
model_name="test-model",
base_url="",
health_check=False,
)
server = _FakeAPIServer(cfg)
train_out = await server.get_logprobs(prompt="hi", split="train")
assert train_out["prompt_tokens"] == [ord("h"), ord("i")]
assert server.calls == 1
assert server.last_kwargs["model"] == "test-model"
assert len(server.request_timings) == 1
assert len(server.attempts_list) == 1
assert len(server.eval_request_timings) == 0
assert len(server.eval_attempts_list) == 0
eval_out = await server.get_logprobs(prompt="ok", split="eval")
assert eval_out["prompt_tokens"] == [ord("o"), ord("k")]
assert server.calls == 2
assert len(server.eval_request_timings) == 1
assert len(server.eval_attempts_list) == 1
@pytest.mark.asyncio
async def test_server_manager_get_logprobs_routes_to_most_available_server():
s1 = _FakeRoutedServer("s1", train_slots=1, eval_slots=4, healthy=True)
s2 = _FakeRoutedServer("s2", train_slots=3, eval_slots=1, healthy=True)
s3 = _FakeRoutedServer("s3", train_slots=4, eval_slots=4, healthy=False)
manager = ServerManager.__new__(ServerManager)
manager.servers = [s1, s2, s3]
out_train = await ServerManager.get_logprobs(manager, prompt="x", split="train")
assert out_train["server"] == "s2"
assert s2.calls == 1
out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
assert out_eval["server"] == "s1"
assert s1.calls == 1

View file

@@ -0,0 +1,68 @@
"""Optional integration test for example_trainer.vllm_api_server /generate."""
from importlib import import_module
import pytest
from fastapi.testclient import TestClient
@pytest.mark.asyncio
async def test_vllm_api_server_generate_endpoint_optional():
"""
Validate /generate contract on the custom vLLM API server.
This test only runs when vLLM is installed.
"""
pytest.importorskip("vllm")
module = import_module("example_trainer.vllm_api_server")
class _FakeLogprob:
def __init__(self, value: float):
self.logprob = value
class _FakeOutput:
def __init__(self):
self.text = " world"
self.finish_reason = "stop"
self.logprobs = [{11: _FakeLogprob(-0.3)}]
self.token_ids = [11]
class _FakeRequestOutput:
def __init__(self):
self.prompt = "hello"
self.prompt_token_ids = [1, 2]
self.outputs = [_FakeOutput()]
class _FakeEngine:
tokenizer = type("Tok", (), {"decode": staticmethod(lambda _: "hello")})()
def generate(self, *_args, **_kwargs):
async def _gen():
yield _FakeRequestOutput()
return _gen()
old_engine = module.engine
module.engine = _FakeEngine()
try:
client = TestClient(module.app)
resp = client.post(
"/generate",
json={
"prompt": "hello",
"max_tokens": 1,
"temperature": 0.0,
"logprobs": 1,
},
)
assert resp.status_code == 200
body = resp.json()
assert "text" in body and body["text"] == [" world"]
assert body["prompt"] == "hello"
assert body["finish_reasons"] == ["stop"]
assert "logprobs" in body
assert "token_ids" in body
assert "prompt_token_ids" in body
finally:
module.engine = old_engine