mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge conflict commit
This commit is contained in:
commit
f198c1738e
13 changed files with 579 additions and 14 deletions
|
|
@ -9,12 +9,12 @@ repos:
|
|||
- id: check-merge-conflict
|
||||
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 26.1.0
|
||||
rev: 26.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.15.4
|
||||
rev: v0.15.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: ["--select=I", "--fix"] # Only run import-related rules
|
||||
|
|
@ -32,7 +32,7 @@ repos:
|
|||
args: ['--baseline', '.secrets.baseline']
|
||||
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.4.1
|
||||
rev: v2.4.2
|
||||
hooks:
|
||||
- id: codespell
|
||||
args: ["--skip", "*.csv,*.html", "-L", "te,ans,sems,lsat,anc,strokin,lod,nam,ques,unparseable,rouge,oll,managin,expressio,re-declare"]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
## Overview
|
||||
|
||||
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
|
||||
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It also exposes a normalized `get_logprobs(...)` API for backend-agnostic logprob access. This eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
|
||||
|
||||
**Server Compatibility:** ManagedServer works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
|
||||
|
||||
|
|
@ -36,7 +36,8 @@ async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
|||
|
||||
- ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions
|
||||
- ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0
|
||||
- ✅ **Perfect Alignment**: Tokens and logprobs are guaranteed to align correctly
|
||||
- ✅ **Perfect Alignment**: Tokens and logprobs align positionally for tracked sequences
|
||||
- ✅ **Normalized Alignment Contract**: Tokens/logprobs are shape-normalized for downstream consumers
|
||||
- ✅ **Multi-turn Support**: Automatically handles conversation extensions
|
||||
- ✅ **Branching Support**: Handles n>1 completions naturally
|
||||
- ✅ **Clean API**: Simple context manager pattern
|
||||
|
|
@ -609,6 +610,28 @@ Intercept completion call and track sequences.
|
|||
**Side Effects:**
|
||||
- Tracks sequences in internal storage
|
||||
|
||||
#### `async def get_logprobs(**kwargs) -> Dict[str, Any]`
|
||||
Fetch logprobs with a normalized schema that is backend-agnostic.
|
||||
|
||||
**Args (common):**
|
||||
- `messages` or `prompt` or `input_ids`
|
||||
- `n`: Number of sampled sequences
|
||||
- `max_tokens`
|
||||
- Optional backend kwargs such as `top_k` / `top_logprobs`, `temperature`, `stop`
|
||||
|
||||
**Returns (normalized):**
|
||||
```python
|
||||
{
|
||||
"prompt_tokens": List[int],
|
||||
"prompt_topk_token_ids": List[List[int]], # [pos][k]
|
||||
"prompt_topk_logprobs": List[List[float]], # [pos][k]
|
||||
}
|
||||
```
|
||||
|
||||
**Notes:**
|
||||
- Strict mode: backend must provide real prompt top-k arrays.
|
||||
- Missing keys should be treated as backend contract violations.
|
||||
|
||||
#### `def get_state() -> Dict[str, Any]`
|
||||
Get the current state of tracked sequences.
|
||||
|
||||
|
|
@ -712,11 +735,9 @@ export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1
|
|||
|
||||
With this flag set, `managed_server()` will return a `DummyManagedServer` that:
|
||||
- Provides the same interface as `ManagedServer`
|
||||
- Returns **fixed placeholder values** for tokens and logprobs:
|
||||
- `tokens`: `[1, 2, 3]`
|
||||
- `masked_tokens`: `[-100, 2, 3]`
|
||||
- `logprobs`: `[-0.5, -0.5, -0.5]`
|
||||
- Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays)
|
||||
- Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
|
||||
- Raises for `get_logprobs(...)` in strict mode (no fake prompt-logprob payload)
|
||||
|
||||
### When to Use DummyManagedServer
|
||||
|
||||
|
|
@ -746,8 +767,11 @@ async with self.server.managed_server() as managed:
|
|||
# nodes contain placeholder token data - DO NOT use for training
|
||||
for node in nodes:
|
||||
print(node.full_text) # Real completion text
|
||||
print(node.tokens) # [1, 2, 3] - placeholder!
|
||||
print(node.logprobs) # [-0.5, -0.5, -0.5] - placeholder!
|
||||
print(node.tokens[:5]) # placeholder values
|
||||
print(node.logprobs[:5]) # placeholder values
|
||||
|
||||
# Strict mode: get_logprobs is not available on DummyManagedServer
|
||||
# and will raise NotImplementedError.
|
||||
```
|
||||
|
||||
### Recommendation
|
||||
|
|
|
|||
|
|
@ -8,6 +8,16 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_
|
|||
|
||||
> **Note:** OpenAI endpoints do not support token IDs/logprobs required for ManagedServer. Set `ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1` to use a placeholder implementation for testing/evaluation. See [OpenAI Endpoint Limitations](MANAGED_SERVER.md#openai-endpoint-limitations) for details.
|
||||
|
||||
### Normalized `get_logprobs` API
|
||||
|
||||
`ManagedServer` and supported server backends expose a normalized `get_logprobs(...)` interface so callers can consume a single schema:
|
||||
|
||||
- `prompt_tokens`
|
||||
- `prompt_topk_token_ids`
|
||||
- `prompt_topk_logprobs`
|
||||
|
||||
Backends are expected to return real prompt top-k arrays (`[pos][k]`) matching this schema.
|
||||
|
||||
## Tool Call Support
|
||||
|
||||
ManagedServer supports OpenAI-style tool calling via vLLM's tool parsers. Pass `tool_parser` at init:
|
||||
|
|
|
|||
|
|
@ -704,6 +704,39 @@ class ManagedServer:
|
|||
else:
|
||||
self.current_nodes.clear()
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch prompt logprobs via wrapped server with a normalized schema.
|
||||
|
||||
Supported inputs:
|
||||
- prompt
|
||||
- messages (converted to prompt)
|
||||
- input_ids
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- prompt_tokens
|
||||
- prompt_topk_token_ids
|
||||
- prompt_topk_logprobs
|
||||
"""
|
||||
request_kwargs = kwargs.copy()
|
||||
messages = request_kwargs.pop("messages", None)
|
||||
|
||||
if messages is not None:
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
request_kwargs["prompt"] = prompt
|
||||
else:
|
||||
prompt = request_kwargs.get("prompt")
|
||||
|
||||
if not hasattr(self.server, "get_logprobs"):
|
||||
raise NotImplementedError(
|
||||
f"{self.server.__class__.__name__} does not implement get_logprobs. "
|
||||
"Strict mode requires backend prompt logprobs."
|
||||
)
|
||||
|
||||
payload = await self.server.get_logprobs(**request_kwargs)
|
||||
return payload
|
||||
|
||||
|
||||
class DummyManagedServer:
|
||||
"""
|
||||
|
|
@ -815,6 +848,15 @@ class DummyManagedServer:
|
|||
else:
|
||||
self.current_nodes.clear()
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Dummy managed server does not provide real prompt logprobs.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"DummyManagedServer does not support get_logprobs in strict mode. "
|
||||
"Use a backend with real prompt logprob support."
|
||||
)
|
||||
|
||||
|
||||
class ManagedServerAdapter:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -421,6 +421,15 @@ class APIServer(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Wrapper for prompt logprobs. Can be overridden by child classes.
|
||||
Returns a dict containing prompt_tokens, prompt_topk_token_ids, prompt_topk_logprobs.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement _get_logprobs_wrapper."
|
||||
)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
|
|
@ -638,3 +647,72 @@ class APIServer(ABC):
|
|||
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
|
||||
self.eval_attempts_list.append(stat_dict["attempts"])
|
||||
return ret_data
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
async def _logprobs(self, stat_dict, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Simple retry and stat collection wrapper for get_logprobs.
|
||||
"""
|
||||
while not self.server_healthy:
|
||||
await asyncio.sleep(1)
|
||||
async with self.sem:
|
||||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
payload = await self._get_logprobs_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return payload
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
async def _logprobs_eval(self, stat_dict, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Simple retry and stat collection wrapper for get_logprobs eval.
|
||||
"""
|
||||
while not self.server_healthy:
|
||||
await asyncio.sleep(1)
|
||||
async with self.eval_sem:
|
||||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
payload = await self._get_logprobs_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return payload
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Prompt-logprob API with strict normalized output schema.
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- prompt_tokens: List[int]
|
||||
- prompt_topk_token_ids: List[List[int]]
|
||||
- prompt_topk_logprobs: List[List[float]]
|
||||
"""
|
||||
if not self.initialized:
|
||||
if self.config.health_check:
|
||||
if self.config.base_url is not None:
|
||||
self.check_task = asyncio.create_task(
|
||||
self.check_server_status_task(chat_completion=False)
|
||||
)
|
||||
else:
|
||||
self.server_healthy = True
|
||||
else:
|
||||
self.server_healthy = True
|
||||
self.initialized = True
|
||||
|
||||
kwargs["model"] = self.config.model_name
|
||||
split = kwargs.pop("split", "train")
|
||||
stat_dict = {"attempts": 0}
|
||||
if split == "train":
|
||||
payload = await self._logprobs(stat_dict, **kwargs)
|
||||
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
|
||||
self.attempts_list.append(stat_dict["attempts"])
|
||||
else:
|
||||
payload = await self._logprobs_eval(stat_dict, **kwargs)
|
||||
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
|
||||
self.eval_attempts_list.append(stat_dict["attempts"])
|
||||
return payload
|
||||
|
|
|
|||
|
|
@ -363,6 +363,32 @@ class ServerManager:
|
|||
**kwargs
|
||||
)
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> dict:
|
||||
"""
|
||||
Route normalized prompt-logprob requests to the most available server.
|
||||
|
||||
Returns a normalized dict with:
|
||||
- prompt_tokens
|
||||
- prompt_topk_token_ids
|
||||
- prompt_topk_logprobs
|
||||
"""
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
await self.wait_for_sem(is_train)
|
||||
for i, server in enumerate(self.servers):
|
||||
if not server.server_healthy:
|
||||
continue
|
||||
if (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
) > most_available_server_num_slots:
|
||||
most_available_server = i
|
||||
most_available_server_num_slots = (
|
||||
server.sem._value if is_train else server.eval_sem._value
|
||||
)
|
||||
|
||||
return await self.servers[most_available_server].get_logprobs(**kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
|
||||
most_available_server = 0
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import asyncio
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
|
|
@ -231,6 +232,129 @@ class VLLMServer(APIServer):
|
|||
finish_reasons_list,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_topk_entry(
|
||||
token_logprobs_entry: Any,
|
||||
) -> Tuple[List[int], List[float]]:
|
||||
"""
|
||||
Normalize a single token-position logprob payload into parallel top-k arrays.
|
||||
|
||||
Supports common structures from vLLM responses:
|
||||
- dict: {token_id: logprob, ...}
|
||||
- list[dict]: [{token_id: logprob}, ...]
|
||||
"""
|
||||
if isinstance(token_logprobs_entry, dict):
|
||||
items = list(token_logprobs_entry.items())
|
||||
return [int(k) for k, _ in items], [float(v) for _, v in items]
|
||||
|
||||
if isinstance(token_logprobs_entry, list):
|
||||
token_ids: List[int] = []
|
||||
logprobs: List[float] = []
|
||||
for item in token_logprobs_entry:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
for key, value in item.items():
|
||||
token_ids.append(int(key))
|
||||
logprobs.append(float(value))
|
||||
return token_ids, logprobs
|
||||
|
||||
return [], []
|
||||
|
||||
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch normalized prompt logprobs from vLLM /generate with optional top-k.
|
||||
|
||||
Args:
|
||||
top_k / top_logprobs: Optional number of logprobs per position.
|
||||
Defaults to 1.
|
||||
prompt or input_ids: Input text or token IDs.
|
||||
|
||||
Returns:
|
||||
Normalized dict:
|
||||
- prompt_tokens
|
||||
- prompt_topk_token_ids
|
||||
- prompt_topk_logprobs
|
||||
"""
|
||||
assert (
|
||||
kwargs.get("prompt", None) is not None
|
||||
or kwargs.get("input_ids", None) is not None
|
||||
), "Prompt or input_ids is required for get_logprobs!"
|
||||
|
||||
top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
|
||||
top_k = max(1, top_k)
|
||||
|
||||
# Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
|
||||
from_prompt_text = False
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
kwargs.pop("prompt", None)
|
||||
else:
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
from_prompt_text = True
|
||||
|
||||
# Only normalize BOS for tokenizer-encoded prompt text.
|
||||
if (
|
||||
from_prompt_text
|
||||
and len(prompt_tokens) >= 2
|
||||
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
|
||||
):
|
||||
prompt_tokens = prompt_tokens[1:]
|
||||
|
||||
if "max_new_tokens" in kwargs:
|
||||
kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
|
||||
if "max_completion_tokens" in kwargs:
|
||||
kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
|
||||
kwargs.pop("model", None)
|
||||
|
||||
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}}
|
||||
request_data["prompt_logprobs"] = top_k
|
||||
request_data.update(kwargs)
|
||||
# This API is prompt-logprobs focused, not generation-focused.
|
||||
request_data["n"] = 1
|
||||
request_data["temperature"] = 0.0
|
||||
request_data["top_p"] = 1.0
|
||||
request_data.setdefault("max_tokens", 1)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.config.base_url.replace('/v1', '')}/generate",
|
||||
json=request_data,
|
||||
headers=(
|
||||
{"Authorization": f"Bearer {self.config.api_key}"}
|
||||
if self.config.api_key
|
||||
else {}
|
||||
),
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
results = await response.json()
|
||||
|
||||
raw_prompt_logprobs = results.get("prompt_logprobs")
|
||||
if raw_prompt_logprobs is None:
|
||||
raise ValueError(
|
||||
"vLLM /generate response missing 'prompt_logprobs'. "
|
||||
"Ensure backend supports prompt logprobs."
|
||||
)
|
||||
|
||||
# Handle either direct [position] payloads or [sequence][position] payloads.
|
||||
if raw_prompt_logprobs and isinstance(raw_prompt_logprobs[0], list):
|
||||
prompt_entries = raw_prompt_logprobs[0]
|
||||
else:
|
||||
prompt_entries = raw_prompt_logprobs
|
||||
|
||||
prompt_topk_token_ids: List[List[int]] = []
|
||||
prompt_topk_logprobs: List[List[float]] = []
|
||||
for entry in prompt_entries:
|
||||
topk_ids, topk_lps = self._normalize_topk_entry(entry)
|
||||
prompt_topk_token_ids.append(topk_ids)
|
||||
prompt_topk_logprobs.append(topk_lps)
|
||||
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_topk_token_ids": prompt_topk_token_ids,
|
||||
"prompt_topk_logprobs": prompt_topk_logprobs,
|
||||
}
|
||||
|
||||
|
||||
def resolve_openai_configs(
|
||||
default_server_configs,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
|
@ -16,7 +17,7 @@ def launch_api_for_testing(max_wait_for_api: int = 10) -> subprocess.Popen:
|
|||
# Use subprocess instead of multiprocessing to avoid inheriting pytest args
|
||||
api_proc = subprocess.Popen(
|
||||
[
|
||||
"python",
|
||||
sys.executable,
|
||||
"-m",
|
||||
"atroposlib.cli.run_api",
|
||||
"--host",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import json
|
|||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
|
@ -27,7 +28,7 @@ def wait_for_api_server(max_wait=10):
|
|||
def api_server():
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"python",
|
||||
sys.executable,
|
||||
"-m",
|
||||
"atroposlib.cli.run_api",
|
||||
"--host",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Tests for API server message handling, particularly for SFT (Supervised Fine-Tun
|
|||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
|
@ -30,7 +31,7 @@ def api_server():
|
|||
# Start the API server as a subprocess
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"python",
|
||||
sys.executable,
|
||||
"-m",
|
||||
"atroposlib.cli.run_api",
|
||||
"--host",
|
||||
|
|
|
|||
|
|
@ -268,6 +268,91 @@ async def test_bos_token_handling(mock_server):
|
|||
assert mock_server.tokenizer.bos_token_id not in node.tokens[1:]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logprobs_normalized_schema(mock_server):
|
||||
"""ManagedServer.get_logprobs returns normalized prompt schema."""
|
||||
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
|
||||
|
||||
prompt = "Hello"
|
||||
prompt_tokens = mock_server.tokenizer.encode(prompt)
|
||||
prompt_topk_token_ids = [[t, t + 1] for t in prompt_tokens]
|
||||
prompt_topk_logprobs = [[-0.1, -0.2] for _ in prompt_tokens]
|
||||
|
||||
async def _mock_get_logprobs(**kwargs):
|
||||
assert kwargs.get("prompt") == prompt
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_topk_token_ids": prompt_topk_token_ids,
|
||||
"prompt_topk_logprobs": prompt_topk_logprobs,
|
||||
}
|
||||
|
||||
mock_server.get_logprobs = _mock_get_logprobs
|
||||
|
||||
payload = await managed.get_logprobs(prompt=prompt, n=1)
|
||||
|
||||
assert payload["prompt_tokens"] == prompt_tokens
|
||||
assert payload["prompt_topk_token_ids"] == prompt_topk_token_ids
|
||||
assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logprobs_messages_passthrough(mock_server):
|
||||
"""ManagedServer.get_logprobs converts messages and passes prompt through."""
|
||||
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
expected_prompt = managed._convert_messages_to_prompt(messages)
|
||||
prompt_tokens = mock_server.tokenizer.encode(expected_prompt)
|
||||
|
||||
async def _mock_get_logprobs(**kwargs):
|
||||
assert kwargs.get("prompt") == expected_prompt
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_topk_token_ids": [[t] for t in prompt_tokens],
|
||||
"prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
|
||||
}
|
||||
|
||||
mock_server.get_logprobs = _mock_get_logprobs
|
||||
payload = await managed.get_logprobs(messages=messages, top_k=1)
|
||||
|
||||
assert payload["prompt_tokens"] == prompt_tokens
|
||||
assert len(payload["prompt_topk_token_ids"]) == len(prompt_tokens)
|
||||
assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logprobs_input_ids_only_passthrough(mock_server):
|
||||
"""ManagedServer.get_logprobs supports input_ids-only without requiring prompt."""
|
||||
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
|
||||
input_ids = [10, 20, 30]
|
||||
|
||||
async def _mock_get_logprobs(**kwargs):
|
||||
assert "input_ids" in kwargs
|
||||
assert kwargs["input_ids"] == input_ids
|
||||
assert kwargs.get("prompt") is None
|
||||
return {
|
||||
"prompt_tokens": input_ids,
|
||||
"prompt_topk_token_ids": [[t] for t in input_ids],
|
||||
"prompt_topk_logprobs": [[-0.1] for _ in input_ids],
|
||||
}
|
||||
|
||||
mock_server.get_logprobs = _mock_get_logprobs
|
||||
payload = await managed.get_logprobs(input_ids=input_ids, top_k=1)
|
||||
|
||||
assert payload["prompt_tokens"] == input_ids
|
||||
assert payload["prompt_topk_token_ids"] == [[10], [20], [30]]
|
||||
assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server):
|
||||
"""ManagedServer.get_logprobs requires backend get_logprobs in strict mode."""
|
||||
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
|
||||
|
||||
prompt = "Hello"
|
||||
with pytest.raises(NotImplementedError, match="does not implement get_logprobs"):
|
||||
await managed.get_logprobs(prompt=prompt, n=1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_clears_sequences(mock_server):
|
||||
"""Test that reset() clears all tracked sequences."""
|
||||
|
|
|
|||
105
atroposlib/tests/test_server_logprobs.py
Normal file
105
atroposlib/tests/test_server_logprobs.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""Tests for get_logprobs wrappers and server-manager routing."""
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
AsyncSemWithAdaptiveWeight,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
|
||||
class _FakeAPIServer(APIServer):
|
||||
def __init__(self, config: APIServerConfig):
|
||||
super().__init__(config=config, reasoning_config=None)
|
||||
self.calls = 0
|
||||
self.last_kwargs = None
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
self.server_healthy = True
|
||||
|
||||
async def _chat_completion_wrapper(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def _completion_wrapper(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def _tokens_and_logprobs_completion_wrapper(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def _get_logprobs_wrapper(self, **kwargs):
|
||||
self.calls += 1
|
||||
self.last_kwargs = kwargs
|
||||
prompt = kwargs.get("prompt", "")
|
||||
prompt_tokens = [ord(c) for c in prompt]
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_topk_token_ids": [[t] for t in prompt_tokens],
|
||||
"prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
|
||||
}
|
||||
|
||||
|
||||
class _FakeRoutedServer:
|
||||
def __init__(
|
||||
self, name: str, train_slots: int, eval_slots: int, healthy: bool = True
|
||||
):
|
||||
self.name = name
|
||||
self.server_healthy = healthy
|
||||
self.sem = AsyncSemWithAdaptiveWeight(4)
|
||||
self.eval_sem = AsyncSemWithAdaptiveWeight(4)
|
||||
self.sem._value = train_slots
|
||||
self.eval_sem._value = eval_slots
|
||||
self.calls = 0
|
||||
|
||||
async def get_logprobs(self, **kwargs):
|
||||
self.calls += 1
|
||||
return {
|
||||
"server": self.name,
|
||||
"prompt_tokens": [1],
|
||||
"prompt_topk_token_ids": [[1]],
|
||||
"prompt_topk_logprobs": [[-0.1]],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apiserver_get_logprobs_train_eval_wrappers():
|
||||
cfg = APIServerConfig(
|
||||
model_name="test-model",
|
||||
base_url="",
|
||||
health_check=False,
|
||||
)
|
||||
server = _FakeAPIServer(cfg)
|
||||
|
||||
train_out = await server.get_logprobs(prompt="hi", split="train")
|
||||
assert train_out["prompt_tokens"] == [ord("h"), ord("i")]
|
||||
assert server.calls == 1
|
||||
assert server.last_kwargs["model"] == "test-model"
|
||||
assert len(server.request_timings) == 1
|
||||
assert len(server.attempts_list) == 1
|
||||
assert len(server.eval_request_timings) == 0
|
||||
assert len(server.eval_attempts_list) == 0
|
||||
|
||||
eval_out = await server.get_logprobs(prompt="ok", split="eval")
|
||||
assert eval_out["prompt_tokens"] == [ord("o"), ord("k")]
|
||||
assert server.calls == 2
|
||||
assert len(server.eval_request_timings) == 1
|
||||
assert len(server.eval_attempts_list) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_manager_get_logprobs_routes_to_most_available_server():
|
||||
s1 = _FakeRoutedServer("s1", train_slots=1, eval_slots=4, healthy=True)
|
||||
s2 = _FakeRoutedServer("s2", train_slots=3, eval_slots=1, healthy=True)
|
||||
s3 = _FakeRoutedServer("s3", train_slots=4, eval_slots=4, healthy=False)
|
||||
|
||||
manager = ServerManager.__new__(ServerManager)
|
||||
manager.servers = [s1, s2, s3]
|
||||
|
||||
out_train = await ServerManager.get_logprobs(manager, prompt="x", split="train")
|
||||
assert out_train["server"] == "s2"
|
||||
assert s2.calls == 1
|
||||
|
||||
out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
|
||||
assert out_eval["server"] == "s1"
|
||||
assert s1.calls == 1
|
||||
68
atroposlib/tests/test_vllm_api_server_generate.py
Normal file
68
atroposlib/tests/test_vllm_api_server_generate.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""Optional integration test for example_trainer.vllm_api_server /generate."""
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_api_server_generate_endpoint_optional():
|
||||
"""
|
||||
Validate /generate contract on the custom vLLM API server.
|
||||
|
||||
This test only runs when vLLM is installed.
|
||||
"""
|
||||
pytest.importorskip("vllm")
|
||||
|
||||
module = import_module("example_trainer.vllm_api_server")
|
||||
|
||||
class _FakeLogprob:
|
||||
def __init__(self, value: float):
|
||||
self.logprob = value
|
||||
|
||||
class _FakeOutput:
|
||||
def __init__(self):
|
||||
self.text = " world"
|
||||
self.finish_reason = "stop"
|
||||
self.logprobs = [{11: _FakeLogprob(-0.3)}]
|
||||
self.token_ids = [11]
|
||||
|
||||
class _FakeRequestOutput:
|
||||
def __init__(self):
|
||||
self.prompt = "hello"
|
||||
self.prompt_token_ids = [1, 2]
|
||||
self.outputs = [_FakeOutput()]
|
||||
|
||||
class _FakeEngine:
|
||||
tokenizer = type("Tok", (), {"decode": staticmethod(lambda _: "hello")})()
|
||||
|
||||
def generate(self, *_args, **_kwargs):
|
||||
async def _gen():
|
||||
yield _FakeRequestOutput()
|
||||
|
||||
return _gen()
|
||||
|
||||
old_engine = module.engine
|
||||
module.engine = _FakeEngine()
|
||||
try:
|
||||
client = TestClient(module.app)
|
||||
resp = client.post(
|
||||
"/generate",
|
||||
json={
|
||||
"prompt": "hello",
|
||||
"max_tokens": 1,
|
||||
"temperature": 0.0,
|
||||
"logprobs": 1,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "text" in body and body["text"] == [" world"]
|
||||
assert body["prompt"] == "hello"
|
||||
assert body["finish_reasons"] == ["stop"]
|
||||
assert "logprobs" in body
|
||||
assert "token_ids" in body
|
||||
assert "prompt_token_ids" in body
|
||||
finally:
|
||||
module.engine = old_engine
|
||||
Loading…
Add table
Add a link
Reference in a new issue