diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 24c59eef..10fee35f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,12 +9,12 @@ repos:
       - id: check-merge-conflict

   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 26.1.0
+    rev: 26.3.0
     hooks:
       - id: black

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.15.4
+    rev: v0.15.5
     hooks:
       - id: ruff
         args: ["--select=I", "--fix"]  # Only run import-related rules
@@ -32,7 +32,7 @@ repos:
         args: ['--baseline', '.secrets.baseline']

   - repo: https://github.com/codespell-project/codespell
-    rev: v2.4.1
+    rev: v2.4.2
     hooks:
       - id: codespell
         args: ["--skip", "*.csv,*.html", "-L", "te,ans,sems,lsat,anc,strokin,lod,nam,ques,unparseable,rouge,oll,managin,expressio,re-declare"]
diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md
index 692f3a10..9b1c9c3e 100644
--- a/atroposlib/envs/server_handling/MANAGED_SERVER.md
+++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md
@@ -2,7 +2,7 @@

 ## Overview

-`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
+`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It also exposes a normalized `get_logprobs(...)` API for backend-agnostic logprob access. This eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.

 **Server Compatibility:** ManagedServer works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.

@@ -36,7 +36,8 @@ async with self.server.managed_server(tokenizer=self.tokenizer) as managed:

 - ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions
 - ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0
-- ✅ **Perfect Alignment**: Tokens and logprobs are guaranteed to align correctly
+- ✅ **Perfect Alignment**: Tokens and logprobs align positionally for tracked sequences
+- ✅ **Normalized Alignment Contract**: Tokens/logprobs are shape-normalized for downstream consumers
 - ✅ **Multi-turn Support**: Automatically handles conversation extensions
 - ✅ **Branching Support**: Handles n>1 completions naturally
 - ✅ **Clean API**: Simple context manager pattern
@@ -609,6 +610,28 @@ Intercept completion call and track sequences.
 **Side Effects:**
 - Tracks sequences in internal storage

+#### `async def get_logprobs(**kwargs) -> Dict[str, Any]`
+Fetch logprobs with a normalized schema that is backend-agnostic.
+
+**Args (common):**
+- `messages` or `prompt` or `input_ids`
+- `n`: Number of sampled sequences
+- `max_tokens`
+- Optional backend kwargs such as `top_k` / `top_logprobs`, `temperature`, `stop`
+
+**Returns (normalized):**
+```python
+{
+    "prompt_tokens": List[int],
+    "prompt_topk_token_ids": List[List[int]],   # [pos][k]
+    "prompt_topk_logprobs": List[List[float]],  # [pos][k]
+}
+```
+
+**Notes:**
+- Strict mode: backend must provide real prompt top-k arrays.
+- Missing keys should be treated as backend contract violations.
+
 #### `def get_state() -> Dict[str, Any]`
 Get the current state of tracked sequences.

@@ -712,11 +735,9 @@ export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1

 With this flag set, `managed_server()` will return a `DummyManagedServer` that:
 - Provides the same interface as `ManagedServer`
-- Returns **fixed placeholder values** for tokens and logprobs:
-  - `tokens`: `[1, 2, 3]`
-  - `masked_tokens`: `[-100, 2, 3]`
-  - `logprobs`: `[-0.5, -0.5, -0.5]`
+- Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays)
 - Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
+- Raises `NotImplementedError` from `get_logprobs(...)` in strict mode (no fake prompt-logprob payload)

 ### When to Use DummyManagedServer

@@ -746,8 +767,11 @@ async with self.server.managed_server() as managed:
     # nodes contain placeholder token data - DO NOT use for training
     for node in nodes:
         print(node.full_text)  # Real completion text
-        print(node.tokens)     # [1, 2, 3] - placeholder!
-        print(node.logprobs)   # [-0.5, -0.5, -0.5] - placeholder!
+        print(node.tokens[:5])    # placeholder values
+        print(node.logprobs[:5])  # placeholder values
+
+    # Strict mode: get_logprobs is not available on DummyManagedServer
+    # and will raise NotImplementedError.
 ```

 ### Recommendation
diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md
index aff3b096..a2c053fa 100644
--- a/atroposlib/envs/server_handling/README.md
+++ b/atroposlib/envs/server_handling/README.md
@@ -8,6 +8,16 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_

 > **Note:** OpenAI endpoints do not support token IDs/logprobs required for ManagedServer. Set `ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1` to use a placeholder implementation for testing/evaluation. See [OpenAI Endpoint Limitations](MANAGED_SERVER.md#openai-endpoint-limitations) for details.

+### Normalized `get_logprobs` API
+
+`ManagedServer` and supported server backends expose a normalized `get_logprobs(...)` interface so callers can consume a single schema:
+
+- `prompt_tokens`
+- `prompt_topk_token_ids`
+- `prompt_topk_logprobs`
+
+Backends are expected to return real prompt top-k arrays (`[pos][k]`) matching this schema.
+
 ## Tool Call Support

 ManagedServer supports OpenAI-style tool calling via vLLM's tool parsers. Pass `tool_parser` at init:
diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py
index 506f5bf4..9d46f265 100644
--- a/atroposlib/envs/server_handling/managed_server.py
+++ b/atroposlib/envs/server_handling/managed_server.py
@@ -704,6 +704,39 @@ class ManagedServer:
         else:
             self.current_nodes.clear()

+    async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
+        """
+        Fetch prompt logprobs via wrapped server with a normalized schema.
+
+        Supported inputs:
+        - prompt
+        - messages (converted to prompt)
+        - input_ids
+
+        Returns:
+            Dict with:
+            - prompt_tokens
+            - prompt_topk_token_ids
+            - prompt_topk_logprobs
+        """
+        request_kwargs = kwargs.copy()
+        messages = request_kwargs.pop("messages", None)
+
+        if messages is not None:
+            prompt = self._convert_messages_to_prompt(messages)
+            request_kwargs["prompt"] = prompt
+        else:
+            prompt = request_kwargs.get("prompt")
+
+        if not hasattr(self.server, "get_logprobs"):
+            raise NotImplementedError(
+                f"{self.server.__class__.__name__} does not implement get_logprobs. "
+                "Strict mode requires backend prompt logprobs."
+            )
+
+        payload = await self.server.get_logprobs(**request_kwargs)
+        return payload
+

 class DummyManagedServer:
     """
@@ -815,6 +848,15 @@ class DummyManagedServer:
         else:
             self.current_nodes.clear()

+    async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
+        """
+        Dummy managed server does not provide real prompt logprobs.
+        """
+        raise NotImplementedError(
+            "DummyManagedServer does not support get_logprobs in strict mode. "
+            "Use a backend with real prompt logprob support."
+        )
+

 class ManagedServerAdapter:
     """
diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py
index 3040c9ca..ade862dd 100644
--- a/atroposlib/envs/server_handling/server_baseline.py
+++ b/atroposlib/envs/server_handling/server_baseline.py
@@ -421,6 +421,15 @@ class APIServer(ABC):
         """
         pass

+    async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
+        """
+        Wrapper for prompt logprobs. Can be overridden by child classes.
+        Returns a dict containing prompt_tokens, prompt_topk_token_ids, prompt_topk_logprobs.
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement _get_logprobs_wrapper."
+        )
+
     @retry(
         stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
     )
@@ -638,3 +647,72 @@ class APIServer(ABC):
             self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
             self.eval_attempts_list.append(stat_dict["attempts"])
         return ret_data
+
+    @retry(
+        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
+    )
+    async def _logprobs(self, stat_dict, **kwargs) -> Dict[str, Any]:
+        """
+        Simple retry and stat collection wrapper for get_logprobs.
+        """
+        while not self.server_healthy:
+            await asyncio.sleep(1)
+        async with self.sem:
+            if stat_dict.get("start", None) is None:
+                stat_dict["start"] = time.time()
+            stat_dict["attempts"] += 1
+            payload = await self._get_logprobs_wrapper(**kwargs)
+            stat_dict["end"] = time.time()
+            return payload
+
+    @retry(
+        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
+    )
+    async def _logprobs_eval(self, stat_dict, **kwargs) -> Dict[str, Any]:
+        """
+        Simple retry and stat collection wrapper for get_logprobs eval.
+        """
+        while not self.server_healthy:
+            await asyncio.sleep(1)
+        async with self.eval_sem:
+            if stat_dict.get("start", None) is None:
+                stat_dict["start"] = time.time()
+            stat_dict["attempts"] += 1
+            payload = await self._get_logprobs_wrapper(**kwargs)
+            stat_dict["end"] = time.time()
+            return payload
+
+    async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
+        """
+        Prompt-logprob API with strict normalized output schema.
+
+        Returns:
+            Dict with:
+            - prompt_tokens: List[int]
+            - prompt_topk_token_ids: List[List[int]]
+            - prompt_topk_logprobs: List[List[float]]
+        """
+        if not self.initialized:
+            if self.config.health_check:
+                if self.config.base_url is not None:
+                    self.check_task = asyncio.create_task(
+                        self.check_server_status_task(chat_completion=False)
+                    )
+                else:
+                    self.server_healthy = True
+            else:
+                self.server_healthy = True
+            self.initialized = True
+
+        kwargs["model"] = self.config.model_name
+        split = kwargs.pop("split", "train")
+        stat_dict = {"attempts": 0}
+        if split == "train":
+            payload = await self._logprobs(stat_dict, **kwargs)
+            self.request_timings.append(stat_dict["end"] - stat_dict["start"])
+            self.attempts_list.append(stat_dict["attempts"])
+        else:
+            payload = await self._logprobs_eval(stat_dict, **kwargs)
+            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
+            self.eval_attempts_list.append(stat_dict["attempts"])
+        return payload
diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py
index 9d2ca48b..b9c493f9 100644
--- a/atroposlib/envs/server_handling/server_manager.py
+++ b/atroposlib/envs/server_handling/server_manager.py
@@ -363,6 +363,32 @@ class ServerManager:
             **kwargs
         )

+    async def get_logprobs(self, **kwargs) -> dict:
+        """
+        Route normalized prompt-logprob requests to the most available server.
+
+        Returns a normalized dict with:
+        - prompt_tokens
+        - prompt_topk_token_ids
+        - prompt_topk_logprobs
+        """
+        is_train = kwargs.pop("split", "train") == "train"
+        most_available_server = 0
+        most_available_server_num_slots = -1
+        await self.wait_for_sem(is_train)
+        for i, server in enumerate(self.servers):
+            if not server.server_healthy:
+                continue
+            if (
+                server.sem._value if is_train else server.eval_sem._value
+            ) > most_available_server_num_slots:
+                most_available_server = i
+                most_available_server_num_slots = (
+                    server.sem._value if is_train else server.eval_sem._value
+                )
+
+        return await self.servers[most_available_server].get_logprobs(**kwargs)
+
     @asynccontextmanager
     async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
         most_available_server = 0
diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py
index 96242754..3c35bebb 100644
--- a/atroposlib/envs/server_handling/vllm_server.py
+++ b/atroposlib/envs/server_handling/vllm_server.py
@@ -3,6 +3,7 @@
 import asyncio
 import warnings
+from typing import Any, Dict, List, Tuple

 import aiohttp
 import openai
@@ -231,6 +232,129 @@ class VLLMServer(APIServer):
             finish_reasons_list,
         )

+    @staticmethod
+    def _normalize_topk_entry(
+        token_logprobs_entry: Any,
+    ) -> Tuple[List[int], List[float]]:
+        """
+        Normalize a single token-position logprob payload into parallel top-k arrays.
+
+        Supports common structures from vLLM responses:
+        - dict: {token_id: logprob, ...}
+        - list[dict]: [{token_id: logprob}, ...]
+        """
+        if isinstance(token_logprobs_entry, dict):
+            items = list(token_logprobs_entry.items())
+            return [int(k) for k, _ in items], [float(v) for _, v in items]
+
+        if isinstance(token_logprobs_entry, list):
+            token_ids: List[int] = []
+            logprobs: List[float] = []
+            for item in token_logprobs_entry:
+                if not isinstance(item, dict):
+                    continue
+                for key, value in item.items():
+                    token_ids.append(int(key))
+                    logprobs.append(float(value))
+            return token_ids, logprobs
+
+        return [], []
+
+    async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
+        """
+        Fetch normalized prompt logprobs from vLLM /generate with optional top-k.
+
+        Args:
+            top_k / top_logprobs: Optional number of logprobs per position.
+                Defaults to 1.
+            prompt or input_ids: Input text or token IDs.
+
+        Returns:
+            Normalized dict:
+            - prompt_tokens
+            - prompt_topk_token_ids
+            - prompt_topk_logprobs
+        """
+        assert (
+            kwargs.get("prompt", None) is not None
+            or kwargs.get("input_ids", None) is not None
+        ), "Prompt or input_ids is required for get_logprobs!"
+
+        top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
+        top_k = max(1, top_k)
+
+        # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
+        from_prompt_text = False
+        if "input_ids" in kwargs:
+            prompt_tokens = kwargs.pop("input_ids")
+            kwargs.pop("prompt", None)
+        else:
+            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
+            from_prompt_text = True
+
+        # Only normalize BOS for tokenizer-encoded prompt text.
+        if (
+            from_prompt_text
+            and len(prompt_tokens) >= 2
+            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
+        ):
+            prompt_tokens = prompt_tokens[1:]
+
+        if "max_new_tokens" in kwargs:
+            kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
+        if "max_completion_tokens" in kwargs:
+            kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
+        kwargs.pop("model", None)
+
+        request_data = {"prompt": {"prompt_token_ids": prompt_tokens}}
+        request_data["prompt_logprobs"] = top_k
+        request_data.update(kwargs)
+        # This API is prompt-logprobs focused, not generation-focused.
+        request_data["n"] = 1
+        request_data["temperature"] = 0.0
+        request_data["top_p"] = 1.0
+        request_data.setdefault("max_tokens", 1)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"{self.config.base_url.replace('/v1', '')}/generate",
+                json=request_data,
+                headers=(
+                    {"Authorization": f"Bearer {self.config.api_key}"}
+                    if self.config.api_key
+                    else {}
+                ),
+                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
+            ) as response:
+                response.raise_for_status()
+                results = await response.json()
+
+        raw_prompt_logprobs = results.get("prompt_logprobs")
+        if raw_prompt_logprobs is None:
+            raise ValueError(
+                "vLLM /generate response missing 'prompt_logprobs'. "
+                "Ensure backend supports prompt logprobs."
+            )
+
+        # Handle either direct [position] payloads or [sequence][position] payloads.
+        if raw_prompt_logprobs and isinstance(raw_prompt_logprobs[0], list):
+            prompt_entries = raw_prompt_logprobs[0]
+        else:
+            prompt_entries = raw_prompt_logprobs
+
+        prompt_topk_token_ids: List[List[int]] = []
+        prompt_topk_logprobs: List[List[float]] = []
+        for entry in prompt_entries:
+            topk_ids, topk_lps = self._normalize_topk_entry(entry)
+            prompt_topk_token_ids.append(topk_ids)
+            prompt_topk_logprobs.append(topk_lps)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "prompt_topk_token_ids": prompt_topk_token_ids,
+            "prompt_topk_logprobs": prompt_topk_logprobs,
+        }
+

 def resolve_openai_configs(
     default_server_configs,
diff --git a/atroposlib/tests/api_test_utils.py b/atroposlib/tests/api_test_utils.py
index c88d51f2..200ada8c 100644
--- a/atroposlib/tests/api_test_utils.py
+++ b/atroposlib/tests/api_test_utils.py
@@ -1,4 +1,5 @@
 import subprocess
+import sys
 import time

 import requests
@@ -16,7 +17,7 @@ def launch_api_for_testing(max_wait_for_api: int = 10) -> subprocess.Popen:
     # Use subprocess instead of multiprocessing to avoid inheriting pytest args
     api_proc = subprocess.Popen(
         [
-            "python",
+            sys.executable,
             "-m",
             "atroposlib.cli.run_api",
             "--host",
diff --git a/atroposlib/tests/test_api_compression.py b/atroposlib/tests/test_api_compression.py
index 25cbdfe0..0202c3ad 100644
--- a/atroposlib/tests/test_api_compression.py
+++ b/atroposlib/tests/test_api_compression.py
@@ -5,6 +5,7 @@ import json
 import os
 import signal
 import subprocess
+import sys
 import time

 import pytest
@@ -27,7 +28,7 @@ def wait_for_api_server(max_wait=10):
 def api_server():
     proc = subprocess.Popen(
         [
-            "python",
+            sys.executable,
             "-m",
             "atroposlib.cli.run_api",
             "--host",
diff --git a/atroposlib/tests/test_api_messages_handling.py b/atroposlib/tests/test_api_messages_handling.py
index 0f7ac922..228be3ea 100644
--- a/atroposlib/tests/test_api_messages_handling.py
+++ b/atroposlib/tests/test_api_messages_handling.py
@@ -5,6 +5,7 @@ Tests for API server message handling, particularly for SFT (Supervised Fine-Tun
 import os
 import signal
 import subprocess
+import sys
 import time

 import pytest
@@ -30,7 +31,7 @@ def api_server():
     # Start the API server as a subprocess
     proc = subprocess.Popen(
         [
-            "python",
+            sys.executable,
             "-m",
             "atroposlib.cli.run_api",
             "--host",
diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py
index 64f05716..1524aaf7 100644
--- a/atroposlib/tests/test_managed_server.py
+++ b/atroposlib/tests/test_managed_server.py
@@ -268,6 +268,91 @@ async def test_bos_token_handling(mock_server):
     assert mock_server.tokenizer.bos_token_id not in node.tokens[1:]

+
+@pytest.mark.asyncio
+async def test_get_logprobs_normalized_schema(mock_server):
+    """ManagedServer.get_logprobs returns normalized prompt schema."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    prompt = "Hello"
+    prompt_tokens = mock_server.tokenizer.encode(prompt)
+    prompt_topk_token_ids = [[t, t + 1] for t in prompt_tokens]
+    prompt_topk_logprobs = [[-0.1, -0.2] for _ in prompt_tokens]
+
+    async def _mock_get_logprobs(**kwargs):
+        assert kwargs.get("prompt") == prompt
+        return {
+            "prompt_tokens": prompt_tokens,
+            "prompt_topk_token_ids": prompt_topk_token_ids,
+            "prompt_topk_logprobs": prompt_topk_logprobs,
+        }
+
+    mock_server.get_logprobs = _mock_get_logprobs
+
+    payload = await managed.get_logprobs(prompt=prompt, n=1)
+
+    assert payload["prompt_tokens"] == prompt_tokens
+    assert payload["prompt_topk_token_ids"] == prompt_topk_token_ids
+    assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
+
+
+@pytest.mark.asyncio
+async def test_get_logprobs_messages_passthrough(mock_server):
+    """ManagedServer.get_logprobs converts messages and passes prompt through."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+    messages = [{"role": "user", "content": "Hello"}]
+    expected_prompt = managed._convert_messages_to_prompt(messages)
+    prompt_tokens = mock_server.tokenizer.encode(expected_prompt)
+
+    async def _mock_get_logprobs(**kwargs):
+        assert kwargs.get("prompt") == expected_prompt
+        return {
+            "prompt_tokens": prompt_tokens,
+            "prompt_topk_token_ids": [[t] for t in prompt_tokens],
+            "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
+        }
+
+    mock_server.get_logprobs = _mock_get_logprobs
+    payload = await managed.get_logprobs(messages=messages, top_k=1)
+
+    assert payload["prompt_tokens"] == prompt_tokens
+    assert len(payload["prompt_topk_token_ids"]) == len(prompt_tokens)
+    assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens)
+
+
+@pytest.mark.asyncio
+async def test_get_logprobs_input_ids_only_passthrough(mock_server):
+    """ManagedServer.get_logprobs supports input_ids-only without requiring prompt."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+    input_ids = [10, 20, 30]
+
+    async def _mock_get_logprobs(**kwargs):
+        assert "input_ids" in kwargs
+        assert kwargs["input_ids"] == input_ids
+        assert kwargs.get("prompt") is None
+        return {
+            "prompt_tokens": input_ids,
+            "prompt_topk_token_ids": [[t] for t in input_ids],
+            "prompt_topk_logprobs": [[-0.1] for _ in input_ids],
+        }
+
+    mock_server.get_logprobs = _mock_get_logprobs
+    payload = await managed.get_logprobs(input_ids=input_ids, top_k=1)
+
+    assert payload["prompt_tokens"] == input_ids
+    assert payload["prompt_topk_token_ids"] == [[10], [20], [30]]
+    assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]]
+
+
+@pytest.mark.asyncio
+async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server):
+    """ManagedServer.get_logprobs requires backend get_logprobs in strict mode."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    prompt = "Hello"
+    with pytest.raises(NotImplementedError, match="does not implement get_logprobs"):
+        await managed.get_logprobs(prompt=prompt, n=1)
+
+
 @pytest.mark.asyncio
 async def test_reset_clears_sequences(mock_server):
     """Test that reset() clears all tracked sequences."""
diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py
new file mode 100644
index 00000000..8cbd84ad
--- /dev/null
+++ b/atroposlib/tests/test_server_logprobs.py
@@ -0,0 +1,105 @@
+"""Tests for get_logprobs wrappers and server-manager routing."""
+
+import pytest
+
+from atroposlib.envs.server_handling.server_baseline import (
+    APIServer,
+    APIServerConfig,
+    AsyncSemWithAdaptiveWeight,
+)
+from atroposlib.envs.server_handling.server_manager import ServerManager
+
+
+class _FakeAPIServer(APIServer):
+    def __init__(self, config: APIServerConfig):
+        super().__init__(config=config, reasoning_config=None)
+        self.calls = 0
+        self.last_kwargs = None
+
+    async def check_server_status_task(self, chat_completion: bool = True):
+        self.server_healthy = True
+
+    async def _chat_completion_wrapper(self, **kwargs):
+        raise NotImplementedError
+
+    async def _completion_wrapper(self, **kwargs):
+        raise NotImplementedError
+
+    async def _tokens_and_logprobs_completion_wrapper(self, **kwargs):
+        raise NotImplementedError
+
+    async def _get_logprobs_wrapper(self, **kwargs):
+        self.calls += 1
+        self.last_kwargs = kwargs
+        prompt = kwargs.get("prompt", "")
+        prompt_tokens = [ord(c) for c in prompt]
+        return {
+            "prompt_tokens": prompt_tokens,
+            "prompt_topk_token_ids": [[t] for t in prompt_tokens],
+            "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
+        }
+
+
+class _FakeRoutedServer:
+    def __init__(
+        self, name: str, train_slots: int, eval_slots: int, healthy: bool = True
+    ):
+        self.name = name
+        self.server_healthy = healthy
+        self.sem = AsyncSemWithAdaptiveWeight(4)
+        self.eval_sem = AsyncSemWithAdaptiveWeight(4)
+        self.sem._value = train_slots
+        self.eval_sem._value = eval_slots
+        self.calls = 0
+
+    async def get_logprobs(self, **kwargs):
+        self.calls += 1
+        return {
+            "server": self.name,
+            "prompt_tokens": [1],
+            "prompt_topk_token_ids": [[1]],
+            "prompt_topk_logprobs": [[-0.1]],
+        }
+
+
+@pytest.mark.asyncio
+async def test_apiserver_get_logprobs_train_eval_wrappers():
+    cfg = APIServerConfig(
+        model_name="test-model",
+        base_url="",
+        health_check=False,
+    )
+    server = _FakeAPIServer(cfg)
+
+    train_out = await server.get_logprobs(prompt="hi", split="train")
+    assert train_out["prompt_tokens"] == [ord("h"), ord("i")]
+    assert server.calls == 1
+    assert server.last_kwargs["model"] == "test-model"
+    assert len(server.request_timings) == 1
+    assert len(server.attempts_list) == 1
+    assert len(server.eval_request_timings) == 0
+    assert len(server.eval_attempts_list) == 0
+
+    eval_out = await server.get_logprobs(prompt="ok", split="eval")
+    assert eval_out["prompt_tokens"] == [ord("o"), ord("k")]
+    assert server.calls == 2
+    assert len(server.eval_request_timings) == 1
+    assert len(server.eval_attempts_list) == 1
+
+
+@pytest.mark.asyncio
+async def test_server_manager_get_logprobs_routes_to_most_available_server():
+    s1 = _FakeRoutedServer("s1", train_slots=1, eval_slots=4, healthy=True)
+    s2 = _FakeRoutedServer("s2", train_slots=3, eval_slots=1, healthy=True)
+    s3 = _FakeRoutedServer("s3", train_slots=4, eval_slots=4, healthy=False)
+
+    manager = ServerManager.__new__(ServerManager)
+    manager.servers = [s1, s2, s3]
+
+    out_train = await ServerManager.get_logprobs(manager, prompt="x", split="train")
+    assert out_train["server"] == "s2"
+    assert s2.calls == 1
+
+    out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
+    assert out_eval["server"] == "s1"
+    assert s1.calls == 1
diff --git a/atroposlib/tests/test_vllm_api_server_generate.py b/atroposlib/tests/test_vllm_api_server_generate.py
new file mode 100644
index 00000000..f2c9c48c
--- /dev/null
+++ b/atroposlib/tests/test_vllm_api_server_generate.py
@@ -0,0 +1,68 @@
+"""Optional integration test for example_trainer.vllm_api_server /generate."""
+
+from importlib import import_module
+
+import pytest
+from fastapi.testclient import TestClient
+
+
+@pytest.mark.asyncio
+async def test_vllm_api_server_generate_endpoint_optional():
+    """
+    Validate /generate contract on the custom vLLM API server.
+
+    This test only runs when vLLM is installed.
+    """
+    pytest.importorskip("vllm")
+
+    module = import_module("example_trainer.vllm_api_server")
+
+    class _FakeLogprob:
+        def __init__(self, value: float):
+            self.logprob = value
+
+    class _FakeOutput:
+        def __init__(self):
+            self.text = " world"
+            self.finish_reason = "stop"
+            self.logprobs = [{11: _FakeLogprob(-0.3)}]
+            self.token_ids = [11]
+
+    class _FakeRequestOutput:
+        def __init__(self):
+            self.prompt = "hello"
+            self.prompt_token_ids = [1, 2]
+            self.outputs = [_FakeOutput()]
+
+    class _FakeEngine:
+        tokenizer = type("Tok", (), {"decode": staticmethod(lambda _: "hello")})()
+
+        def generate(self, *_args, **_kwargs):
+            async def _gen():
+                yield _FakeRequestOutput()
+
+            return _gen()
+
+    old_engine = module.engine
+    module.engine = _FakeEngine()
+    try:
+        client = TestClient(module.app)
+        resp = client.post(
+            "/generate",
+            json={
+                "prompt": "hello",
+                "max_tokens": 1,
+                "temperature": 0.0,
+                "logprobs": 1,
+            },
+        )
+        assert resp.status_code == 200
+        body = resp.json()
+        assert "text" in body and body["text"] == [" world"]
+        assert body["prompt"] == "hello"
+        assert body["finish_reasons"] == ["stop"]
+        assert "logprobs" in body
+        assert "token_ids" in body
+        assert "prompt_token_ids" in body
+    finally:
+        module.engine = old_engine