mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00

Merge conflict commit
commit f198c1738e
13 changed files with 579 additions and 14 deletions

@@ -2,7 +2,7 @@

## Overview

`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It also exposes a normalized `get_logprobs(...)` API for backend-agnostic logprob access. This eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.

**Server Compatibility:** `ManagedServer` works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"`, and the appropriate server class is selected automatically.
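
For instance, backend selection is a one-field change in the config. A minimal sketch, assuming the config fields referenced throughout this guide (the import path and example values are illustrative):

```python
from atroposlib.envs.server_handling import APIServerConfig  # import path is an assumption

config = APIServerConfig(
    server_type="vllm",  # or "sglang" / "trl"
    model_name="Qwen/Qwen2.5-7B-Instruct",  # example value
    base_url="http://localhost:9001/v1",
    api_key="x",
)
```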

@@ -36,7 +36,8 @@ async with self.server.managed_server(tokenizer=self.tokenizer) as managed:

- ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions
- ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0
- ✅ **Perfect Alignment**: Tokens and logprobs align positionally for tracked sequences
- ✅ **Normalized Alignment Contract**: Tokens/logprobs are shape-normalized for downstream consumers
- ✅ **Multi-turn Support**: Automatically handles conversation extensions
- ✅ **Branching Support**: Handles n>1 completions naturally
- ✅ **Clean API**: Simple context manager pattern (see the sketch below)
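
A minimal sketch of the pattern (the `chat_completion` method name and node fields follow the examples elsewhere in this guide; exact signatures are assumptions):

```python
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
    # Completion calls made through `managed` are intercepted and tracked.
    nodes = await managed.chat_completion(
        messages=[{"role": "user", "content": "Hello!"}],
        n=2,  # branching: two completions, two tracked sequences
    )
    for node in nodes:
        print(node.full_text)      # full prompt + completion text
        print(node.tokens)         # token IDs aligned with logprobs
        print(node.masked_tokens)  # prompt positions masked with -100
        print(node.logprobs)       # prompt positions filled with 1.0
```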

@@ -609,6 +610,28 @@ Intercept completion call and track sequences.

**Side Effects:**
- Tracks sequences in internal storage

#### `async def get_logprobs(**kwargs) -> Dict[str, Any]`

Fetch logprobs with a normalized, backend-agnostic schema.

**Args (common):**
- `messages`, `prompt`, or `input_ids`
- `n`: Number of sampled sequences
- `max_tokens`
- Optional backend kwargs such as `top_k` / `top_logprobs`, `temperature`, `stop`

**Returns (normalized):**

```python
{
    "prompt_tokens": List[int],
    "prompt_topk_token_ids": List[List[int]],   # [pos][k]
    "prompt_topk_logprobs": List[List[float]],  # [pos][k]
}
```

**Notes:**
- Strict mode: the backend must provide real prompt top-k arrays.
- Missing keys should be treated as backend contract violations.
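
For example, a caller can fetch prompt top-k logprobs and consume the normalized schema like this (the prompt and `top_k` values are illustrative):

```python
payload = await managed.get_logprobs(
    prompt="The capital of France is",
    top_k=5,  # five candidate logprobs per prompt position
)

prompt_tokens = payload["prompt_tokens"]     # List[int]
topk_ids = payload["prompt_topk_token_ids"]  # [pos][k]
topk_lps = payload["prompt_topk_logprobs"]   # [pos][k]
assert len(topk_ids) == len(topk_lps)  # parallel arrays per the contract
```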

#### `def get_state() -> Dict[str, Any]`

Get the current state of tracked sequences.

@@ -712,11 +735,9 @@ export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1

With this flag set, `managed_server()` will return a `DummyManagedServer` that:
- Provides the same interface as `ManagedServer`
- Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays)
- Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
- Raises for `get_logprobs(...)` in strict mode (no fake prompt-logprob payload)

### When to Use DummyManagedServer

@@ -746,8 +767,11 @@ async with self.server.managed_server() as managed:

    # nodes contain placeholder token data - DO NOT use for training
    for node in nodes:
        print(node.full_text)     # Real completion text
        print(node.tokens[:5])    # placeholder values
        print(node.logprobs[:5])  # placeholder values

    # Strict mode: get_logprobs is not available on DummyManagedServer
    # and will raise NotImplementedError.
```

### Recommendation

@@ -8,6 +8,16 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_

> **Note:** OpenAI endpoints do not support the token IDs/logprobs required for ManagedServer. Set `ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1` to use a placeholder implementation for testing/evaluation. See [OpenAI Endpoint Limitations](MANAGED_SERVER.md#openai-endpoint-limitations) for details.

### Normalized `get_logprobs` API

`ManagedServer` and supported server backends expose a normalized `get_logprobs(...)` interface so callers can consume a single schema:

- `prompt_tokens`
- `prompt_topk_token_ids`
- `prompt_topk_logprobs`

Backends are expected to return real prompt top-k arrays (`[pos][k]`) matching this schema, as the sketch below checks.
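
As a sketch, a consumer can verify the contract before use (the helper name is illustrative; the key names come from the schema above):

```python
def check_logprobs_payload(payload: dict) -> None:
    # Missing keys are backend contract violations in strict mode.
    for key in ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs"):
        if key not in payload:
            raise ValueError(f"Backend contract violation: missing '{key}'")
    # Top-k arrays must be parallel [pos][k] structures.
    if len(payload["prompt_topk_token_ids"]) != len(payload["prompt_topk_logprobs"]):
        raise ValueError("Top-k token IDs and logprobs are not aligned")
```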

## Tool Call Support

ManagedServer supports OpenAI-style tool calling via vLLM's tool parsers. Pass `tool_parser` at init:
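
A hypothetical sketch, assuming a `tool_parser` keyword and vLLM's `"hermes"` parser name:

```python
# Hypothetical sketch; the actual init signature may differ.
async with self.server.managed_server(
    tokenizer=self.tokenizer,
    tool_parser="hermes",  # assumption: one of vLLM's tool parser names
) as managed:
    ...
```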

@@ -704,6 +704,39 @@ class ManagedServer:
        else:
            self.current_nodes.clear()

    async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
        """
        Fetch prompt logprobs via the wrapped server with a normalized schema.

        Supported inputs:
        - prompt
        - messages (converted to prompt)
        - input_ids

        Returns:
            Dict with:
            - prompt_tokens
            - prompt_topk_token_ids
            - prompt_topk_logprobs
        """
        request_kwargs = kwargs.copy()
        messages = request_kwargs.pop("messages", None)

        if messages is not None:
            # Render chat messages into a single prompt string before dispatch.
            prompt = self._convert_messages_to_prompt(messages)
            request_kwargs["prompt"] = prompt
        else:
            prompt = request_kwargs.get("prompt")

        if not hasattr(self.server, "get_logprobs"):
            raise NotImplementedError(
                f"{self.server.__class__.__name__} does not implement get_logprobs. "
                "Strict mode requires backend prompt logprobs."
            )

        payload = await self.server.get_logprobs(**request_kwargs)
        return payload


class DummyManagedServer:
    """

@@ -815,6 +848,15 @@ class DummyManagedServer:
        else:
            self.current_nodes.clear()

    async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
        """
        Dummy managed server does not provide real prompt logprobs.
        """
        raise NotImplementedError(
            "DummyManagedServer does not support get_logprobs in strict mode. "
            "Use a backend with real prompt logprob support."
        )


class ManagedServerAdapter:
    """

@@ -421,6 +421,15 @@ class APIServer(ABC):
        """
        pass

    async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
        """
        Wrapper for prompt logprobs. Can be overridden by child classes.
        Returns a dict containing prompt_tokens, prompt_topk_token_ids,
        and prompt_topk_logprobs.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement _get_logprobs_wrapper."
        )

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )

@@ -638,3 +647,72 @@ class APIServer(ABC):
        self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
        self.eval_attempts_list.append(stat_dict["attempts"])
        return ret_data

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _logprobs(self, stat_dict, **kwargs) -> Dict[str, Any]:
        """
        Simple retry and stat collection wrapper for get_logprobs.
        """
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            payload = await self._get_logprobs_wrapper(**kwargs)
            stat_dict["end"] = time.time()
            return payload

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _logprobs_eval(self, stat_dict, **kwargs) -> Dict[str, Any]:
        """
        Simple retry and stat collection wrapper for get_logprobs eval.
        """
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.eval_sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            payload = await self._get_logprobs_wrapper(**kwargs)
            stat_dict["end"] = time.time()
            return payload

    async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
        """
        Prompt-logprob API with a strict normalized output schema.

        Returns:
            Dict with:
            - prompt_tokens: List[int]
            - prompt_topk_token_ids: List[List[int]]
            - prompt_topk_logprobs: List[List[float]]
        """
        if not self.initialized:
            # Lazily start the health-check task on first use.
            if self.config.health_check:
                if self.config.base_url is not None:
                    self.check_task = asyncio.create_task(
                        self.check_server_status_task(chat_completion=False)
                    )
                else:
                    self.server_healthy = True
            else:
                self.server_healthy = True
            self.initialized = True

        kwargs["model"] = self.config.model_name
        split = kwargs.pop("split", "train")
        stat_dict = {"attempts": 0}
        if split == "train":
            payload = await self._logprobs(stat_dict, **kwargs)
            self.request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.attempts_list.append(stat_dict["attempts"])
        else:
            payload = await self._logprobs_eval(stat_dict, **kwargs)
            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.eval_attempts_list.append(stat_dict["attempts"])
        return payload

@@ -363,6 +363,32 @@ class ServerManager:
            **kwargs
        )

    async def get_logprobs(self, **kwargs) -> dict:
        """
        Route normalized prompt-logprob requests to the most available server.

        Returns a normalized dict with:
        - prompt_tokens
        - prompt_topk_token_ids
        - prompt_topk_logprobs
        """
        is_train = kwargs.pop("split", "train") == "train"
        most_available_server = 0
        most_available_server_num_slots = -1
        await self.wait_for_sem(is_train)
        # Pick the healthy server with the most free semaphore slots.
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if (
                server.sem._value if is_train else server.eval_sem._value
            ) > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = (
                    server.sem._value if is_train else server.eval_sem._value
                )

        return await self.servers[most_available_server].get_logprobs(**kwargs)

    @asynccontextmanager
    async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
        most_available_server = 0

@@ -3,6 +3,7 @@

import asyncio
import warnings
from typing import Any, Dict, List, Tuple

import aiohttp
import openai

@@ -231,6 +232,129 @@ class VLLMServer(APIServer):
            finish_reasons_list,
        )

    @staticmethod
    def _normalize_topk_entry(
        token_logprobs_entry: Any,
    ) -> Tuple[List[int], List[float]]:
        """
        Normalize a single token-position logprob payload into parallel top-k arrays.

        Supports common structures from vLLM responses:
        - dict: {token_id: logprob, ...}
        - list[dict]: [{token_id: logprob}, ...]
        """
        if isinstance(token_logprobs_entry, dict):
            items = list(token_logprobs_entry.items())
            return [int(k) for k, _ in items], [float(v) for _, v in items]

        if isinstance(token_logprobs_entry, list):
            token_ids: List[int] = []
            logprobs: List[float] = []
            for item in token_logprobs_entry:
                if not isinstance(item, dict):
                    continue
                for key, value in item.items():
                    token_ids.append(int(key))
                    logprobs.append(float(value))
            return token_ids, logprobs

        return [], []

    async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
        """
        Fetch normalized prompt logprobs from vLLM /generate with optional top-k.

        Args:
            top_k / top_logprobs: Optional number of logprobs per position.
                Defaults to 1.
            prompt or input_ids: Input text or token IDs.

        Returns:
            Normalized dict:
            - prompt_tokens
            - prompt_topk_token_ids
            - prompt_topk_logprobs
        """
        assert (
            kwargs.get("prompt", None) is not None
            or kwargs.get("input_ids", None) is not None
        ), "Prompt or input_ids is required for get_logprobs!"

        # Note: both `top_k` and `top_logprobs` are consumed from kwargs here.
        top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
        top_k = max(1, top_k)

        # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
        from_prompt_text = False
        if "input_ids" in kwargs:
            prompt_tokens = kwargs.pop("input_ids")
            kwargs.pop("prompt", None)
        else:
            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
            from_prompt_text = True

        # Only normalize a duplicated BOS token for tokenizer-encoded prompt text.
        if (
            from_prompt_text
            and len(prompt_tokens) >= 2
            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
        ):
            prompt_tokens = prompt_tokens[1:]

        if "max_new_tokens" in kwargs:
            kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
        if "max_completion_tokens" in kwargs:
            kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
        kwargs.pop("model", None)

        request_data = {"prompt": {"prompt_token_ids": prompt_tokens}}
        request_data["prompt_logprobs"] = top_k
        request_data.update(kwargs)
        # This API is prompt-logprobs focused, not generation-focused.
        request_data["n"] = 1
        request_data["temperature"] = 0.0
        request_data["top_p"] = 1.0
        request_data.setdefault("max_tokens", 1)

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.config.base_url.replace('/v1', '')}/generate",
                json=request_data,
                headers=(
                    {"Authorization": f"Bearer {self.config.api_key}"}
                    if self.config.api_key
                    else {}
                ),
                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
            ) as response:
                response.raise_for_status()
                results = await response.json()

        raw_prompt_logprobs = results.get("prompt_logprobs")
        if raw_prompt_logprobs is None:
            raise ValueError(
                "vLLM /generate response missing 'prompt_logprobs'. "
                "Ensure backend supports prompt logprobs."
            )

        # Handle either direct [position] payloads or [sequence][position] payloads.
        if raw_prompt_logprobs and isinstance(raw_prompt_logprobs[0], list):
            prompt_entries = raw_prompt_logprobs[0]
        else:
            prompt_entries = raw_prompt_logprobs

        prompt_topk_token_ids: List[List[int]] = []
        prompt_topk_logprobs: List[List[float]] = []
        for entry in prompt_entries:
            topk_ids, topk_lps = self._normalize_topk_entry(entry)
            prompt_topk_token_ids.append(topk_ids)
            prompt_topk_logprobs.append(topk_lps)

        return {
            "prompt_tokens": prompt_tokens,
            "prompt_topk_token_ids": prompt_topk_token_ids,
            "prompt_topk_logprobs": prompt_topk_logprobs,
        }


def resolve_openai_configs(
    default_server_configs,