From 16ac6f0936376f8f61f1bc2013812ebe84dbbd59 Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Thu, 26 Feb 2026 21:48:56 +0300 Subject: [PATCH 01/18] fix: use sys.executable instead of hardcoded "python" in tests Tests that launch the API server via subprocess used a hardcoded "python" command which fails on systems where only "python3" is available (e.g. macOS). Using sys.executable ensures the same interpreter running pytest is used for subprocesses. Fixes 36 test errors on macOS environments. --- atroposlib/tests/api_test_utils.py | 3 ++- atroposlib/tests/test_api_compression.py | 3 ++- atroposlib/tests/test_api_messages_handling.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/atroposlib/tests/api_test_utils.py b/atroposlib/tests/api_test_utils.py index c88d51f2..200ada8c 100644 --- a/atroposlib/tests/api_test_utils.py +++ b/atroposlib/tests/api_test_utils.py @@ -1,4 +1,5 @@ import subprocess +import sys import time import requests @@ -16,7 +17,7 @@ def launch_api_for_testing(max_wait_for_api: int = 10) -> subprocess.Popen: # Use subprocess instead of multiprocessing to avoid inheriting pytest args api_proc = subprocess.Popen( [ - "python", + sys.executable, "-m", "atroposlib.cli.run_api", "--host", diff --git a/atroposlib/tests/test_api_compression.py b/atroposlib/tests/test_api_compression.py index 25cbdfe0..0202c3ad 100644 --- a/atroposlib/tests/test_api_compression.py +++ b/atroposlib/tests/test_api_compression.py @@ -5,6 +5,7 @@ import json import os import signal import subprocess +import sys import time import pytest @@ -27,7 +28,7 @@ def wait_for_api_server(max_wait=10): def api_server(): proc = subprocess.Popen( [ - "python", + sys.executable, "-m", "atroposlib.cli.run_api", "--host", diff --git a/atroposlib/tests/test_api_messages_handling.py b/atroposlib/tests/test_api_messages_handling.py index 0f7ac922..228be3ea 100644 --- a/atroposlib/tests/test_api_messages_handling.py +++ b/atroposlib/tests/test_api_messages_handling.py @@ -5,6 +5,7 @@ Tests for API server message handling, particularly for SFT (Supervised Fine-Tun import os import signal import subprocess +import sys import time import pytest @@ -30,7 +31,7 @@ def api_server(): # Start the API server as a subprocess proc = subprocess.Popen( [ - "python", + sys.executable, "-m", "atroposlib.cli.run_api", "--host", From b9291aa29f7de0e9b859279ffa971b428eae36a7 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 11:32:09 -0500 Subject: [PATCH 02/18] init commit --- .../envs/server_handling/managed_server.py | 91 ++++++++++++ .../envs/server_handling/server_baseline.py | 40 +++++ .../envs/server_handling/server_manager.py | 60 ++++++++ .../envs/server_handling/vllm_server.py | 138 ++++++++++++++++++ atroposlib/tests/test_managed_server.py | 28 ++++ 5 files changed, 357 insertions(+) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index c1358dc6..db909c83 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -529,6 +529,78 @@ class ManagedServer: else: self.current_nodes.clear() + async def get_logprobs(self, **kwargs) -> Dict[str, Any]: + """ + Fetch logprobs via wrapped server with a normalized trainer-agnostic schema. + + Supported inputs: + - prompt + - messages (converted to prompt) + - input_ids + + Returns: + Dict with: + - prompt_tokens + - sequence_token_ids + - sequence_logprobs + - sequence_topk_token_ids + - sequence_topk_logprobs + - finish_reasons + """ + request_kwargs = kwargs.copy() + messages = request_kwargs.pop("messages", None) + + if messages is not None: + prompt = self._convert_messages_to_prompt(messages) + request_kwargs["prompt"] = prompt + else: + prompt = request_kwargs.get("prompt") + + # Reuse tracked context in non-tree mode when possible. + if ( + not self.track_tree + and self.tokenizer is not None + and "input_ids" not in request_kwargs + and prompt is not None + ): + extending_node = self._find_extending_node(prompt) + request_kwargs["input_ids"] = self._compute_input_ids(prompt, extending_node) + + if hasattr(self.server, "get_logprobs"): + payload = await self.server.get_logprobs(**request_kwargs) + else: + # Backwards-compatible fallback for harness/test doubles. + ( + prompt_tokens, + output_tokens_list, + output_logprobs_list, + finish_reasons, + ) = await self.server.tokens_and_logprobs_completion(**request_kwargs) + payload = { + "prompt_tokens": prompt_tokens, + "sequence_token_ids": output_tokens_list, + "sequence_logprobs": output_logprobs_list, + "sequence_topk_token_ids": [ + [[tok] for tok in seq] for seq in output_tokens_list + ], + "sequence_topk_logprobs": [ + [[lp] for lp in seq] for seq in output_logprobs_list + ], + "finish_reasons": finish_reasons, + } + + # Normalize required keys if provider omitted top-k arrays. + if "sequence_topk_token_ids" not in payload: + payload["sequence_topk_token_ids"] = [ + [[tok] for tok in seq] for seq in payload["sequence_token_ids"] + ] + if "sequence_topk_logprobs" not in payload: + payload["sequence_topk_logprobs"] = [ + [[lp] for lp in seq] for seq in payload["sequence_logprobs"] + ] + + return payload + class DummyManagedServer: """ @@ -640,6 +712,25 @@ class DummyManagedServer: else: self.current_nodes.clear() + async def get_logprobs(self, **kwargs) -> Dict[str, Any]: + """ + Return interface-compatible dummy logprob payload. + + This keeps interface parity with ManagedServer while making it explicit + that results are placeholders and not suitable for training. + """ + n = int(kwargs.get("n", 1)) + seq_ids = [self.DUMMY_TOKENS[:] for _ in range(n)] + seq_lps = [self.DUMMY_LOGPROBS[:] for _ in range(n)] + return { + "prompt_tokens": [], + "sequence_token_ids": seq_ids, + "sequence_logprobs": seq_lps, + "sequence_topk_token_ids": [[[tok] for tok in seq] for seq in seq_ids], + "sequence_topk_logprobs": [[[lp] for lp in seq] for seq in seq_lps], + "finish_reasons": ["stop"] * n, + } + class ManagedServerAdapter: """ diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index 3040c9ca..281f4fbd 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -638,3 +638,43 @@ class APIServer(ABC): self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) self.eval_attempts_list.append(stat_dict["attempts"]) return ret_data + + async def get_logprobs(self, **kwargs) -> Dict[str, Any]: + """ + Trainer-agnostic logprob API with normalized output schema. + + This default implementation is built from `tokens_and_logprobs_completion` + and returns sampled-token logprobs (top-k singleton per position). + + Returns: + Dict with: + - prompt_tokens: List[int] + - sequence_token_ids: List[List[int]] + - sequence_logprobs: List[List[float]] + - sequence_topk_token_ids: List[List[List[int]]] + - sequence_topk_logprobs: List[List[List[float]]] + - finish_reasons: List[Any] + """ + ( + prompt_tokens, + output_tokens_list, + output_logprobs_list, + finish_reasons, + ) = await self.tokens_and_logprobs_completion(**kwargs) + + topk_token_ids = [ + [[token_id] for token_id in seq_tokens] for seq_tokens in output_tokens_list + ] + topk_logprobs = [ + [[logprob] for logprob in seq_logprobs] + for seq_logprobs in output_logprobs_list + ] + + return { + "prompt_tokens": prompt_tokens, + "sequence_token_ids": output_tokens_list, + "sequence_logprobs": output_logprobs_list, + "sequence_topk_token_ids": topk_token_ids, + "sequence_topk_logprobs": topk_logprobs, + "finish_reasons": finish_reasons, + } diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index e76dea32..6d0a90d5 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -346,6 +346,66 @@ class ServerManager: **kwargs ) + async def get_logprobs(self, **kwargs) -> dict: + """ + Route normalized get_logprobs requests to the most available server. + + Returns a normalized dict with: + - prompt_tokens + - sequence_token_ids + - sequence_logprobs + - sequence_topk_token_ids + - sequence_topk_logprobs + - finish_reasons + """ + n = kwargs.get("n", 1) + if n > self.max_n_completions: + # Split into multiple requests and merge sequence-level outputs. + results = [] + total_n = n + while total_n > 0: + n_to_use = min(total_n, self.max_n_completions) + kwargs["n"] = n_to_use + results.append(self.get_logprobs(**kwargs)) + total_n -= n_to_use + results = await asyncio.gather(*results) + merged = { + "prompt_tokens": results[0]["prompt_tokens"], + "sequence_token_ids": [], + "sequence_logprobs": [], + "sequence_topk_token_ids": [], + "sequence_topk_logprobs": [], + "finish_reasons": [], + } + for result in results: + merged["sequence_token_ids"].extend(result["sequence_token_ids"]) + merged["sequence_logprobs"].extend(result["sequence_logprobs"]) + merged["sequence_topk_token_ids"].extend( + result["sequence_topk_token_ids"] + ) + merged["sequence_topk_logprobs"].extend( + result["sequence_topk_logprobs"] + ) + merged["finish_reasons"].extend(result["finish_reasons"]) + return merged + + is_train = kwargs.pop("split", "train") == "train" + most_available_server = 0 + most_available_server_num_slots = -1 + await self.wait_for_sem(is_train) + for i, server in enumerate(self.servers): + if not server.server_healthy: + continue + if ( + server.sem._value if is_train else server.eval_sem._value + ) > most_available_server_num_slots: + most_available_server = i + most_available_server_num_slots = ( + server.sem._value if is_train else server.eval_sem._value + ) + + return await self.servers[most_available_server].get_logprobs(**kwargs) + @asynccontextmanager async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]: most_available_server = 0 diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 96242754..11ba87e8 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -3,6 +3,7 @@ import asyncio import warnings +from typing import Any, Dict, List, Tuple import aiohttp import openai @@ -231,6 +232,143 @@ class VLLMServer(APIServer): finish_reasons_list, ) + @staticmethod + def _normalize_topk_entry( + token_logprobs_entry: Any, + ) -> Tuple[List[int], List[float]]: + """ + Normalize a single token-position logprob payload into parallel top-k arrays. + + Supports common structures from vLLM responses: + - dict: {token_id: logprob, ...} + - list[dict]: [{token_id: logprob}, ...] + """ + if isinstance(token_logprobs_entry, dict): + items = list(token_logprobs_entry.items()) + return [int(k) for k, _ in items], [float(v) for _, v in items] + + if isinstance(token_logprobs_entry, list): + token_ids: List[int] = [] + logprobs: List[float] = [] + for item in token_logprobs_entry: + if not isinstance(item, dict): + continue + for key, value in item.items(): + token_ids.append(int(key)) + logprobs.append(float(value)) + return token_ids, logprobs + + return [], [] + + async def get_logprobs(self, **kwargs) -> Dict[str, Any]: + """ + Fetch normalized logprobs from vLLM /generate with optional top-k. + + Args: + top_k / top_logprobs: Optional number of logprobs per position. + Defaults to 1. + prompt or input_ids: Input text or token IDs. + + Returns: + Normalized dict: + - prompt_tokens + - sequence_token_ids + - sequence_logprobs + - sequence_topk_token_ids + - sequence_topk_logprobs + - finish_reasons + """ + assert ( + kwargs.get("prompt", None) is not None + or kwargs.get("input_ids", None) is not None + ), "Prompt or input_ids is required for get_logprobs!" + + top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1))) + top_k = max(1, top_k) + + # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt + if "input_ids" in kwargs: + prompt_tokens = kwargs.pop("input_ids") + kwargs.pop("prompt", None) + else: + prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt")) + + # Check for double BOS token. + if ( + len(prompt_tokens) >= 2 + and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1] + ): + prompt_tokens = prompt_tokens[1:] + + if "max_new_tokens" in kwargs: + kwargs["max_tokens"] = kwargs.pop("max_new_tokens") + if "max_completion_tokens" in kwargs: + kwargs["max_tokens"] = kwargs.pop("max_completion_tokens") + kwargs.pop("model", None) + + request_data = { + "prompt": {"prompt_token_ids": prompt_tokens}, + "logprobs": top_k, + } + request_data.update(kwargs) + + # Keep semaphore behavior consistent with other server calls. + split = request_data.pop("split", "train") + sem = self.sem if split == "train" else self.eval_sem + while not self.server_healthy: + await asyncio.sleep(1) + + async with sem: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.config.base_url.replace('/v1', '')}/generate", + json=request_data, + headers=( + {"Authorization": f"Bearer {self.config.api_key}"} + if self.config.api_key + else {} + ), + timeout=aiohttp.ClientTimeout(total=self.config.timeout), + ) as response: + response.raise_for_status() + results = await response.json() + + sequence_topk_token_ids: List[List[List[int]]] = [] + sequence_topk_logprobs: List[List[List[float]]] = [] + sequence_token_ids: List[List[int]] = [] + sequence_logprobs: List[List[float]] = [] + finish_reasons: List[Any] = [] + + for token_logprobs_seq, finish_reason in zip( + results["logprobs"], results["finish_reasons"] + ): + seq_topk_token_ids: List[List[int]] = [] + seq_topk_logprobs: List[List[float]] = [] + seq_token_ids: List[int] = [] + seq_logprobs: List[float] = [] + + for token_logprobs_entry in token_logprobs_seq: + topk_ids, topk_lps = self._normalize_topk_entry(token_logprobs_entry) + seq_topk_token_ids.append(topk_ids) + seq_topk_logprobs.append(topk_lps) + seq_token_ids.append(topk_ids[0] if topk_ids else -1) + seq_logprobs.append(topk_lps[0] if topk_lps else 0.0) + + sequence_topk_token_ids.append(seq_topk_token_ids) + sequence_topk_logprobs.append(seq_topk_logprobs) + sequence_token_ids.append(seq_token_ids) + sequence_logprobs.append(seq_logprobs) + finish_reasons.append(finish_reason) + + return { + "prompt_tokens": prompt_tokens, + "sequence_token_ids": sequence_token_ids, + "sequence_logprobs": sequence_logprobs, + "sequence_topk_token_ids": sequence_topk_token_ids, + "sequence_topk_logprobs": sequence_topk_logprobs, + "finish_reasons": finish_reasons, + } + def resolve_openai_configs( default_server_configs, diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index 75f8d48c..d8ea2aed 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -265,6 +265,34 @@ async def test_bos_token_handling(mock_server): assert mock_server.tokenizer.bos_token_id not in node.tokens[1:] +@pytest.mark.asyncio +async def test_get_logprobs_normalized_schema(mock_server): + """ManagedServer.get_logprobs returns normalized schema.""" + managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) + + prompt = "Hello" + prompt_tokens = mock_server.tokenizer.encode(prompt) + output_tokens = [[ord("!"), ord("?")]] + output_logprobs = [[-0.1, -0.2]] + + mock_server.set_tokens_and_logprobs_response( + prompt=prompt, + prompt_tokens=prompt_tokens, + output_tokens_list=output_tokens, + output_logprobs_list=output_logprobs, + finish_reasons=["stop"], + ) + + payload = await managed.get_logprobs(prompt=prompt, n=1) + + assert payload["prompt_tokens"] == prompt_tokens + assert payload["sequence_token_ids"] == output_tokens + assert payload["sequence_logprobs"] == output_logprobs + assert payload["finish_reasons"] == ["stop"] + assert payload["sequence_topk_token_ids"] == [[[ord("!")], [ord("?")]]] + assert payload["sequence_topk_logprobs"] == [[[-0.1], [-0.2]]] + + @pytest.mark.asyncio async def test_reset_clears_sequences(mock_server): """Test that reset() clears all tracked sequences.""" From 323a8a26017105d197c661e4cb1c8c2e010a21c4 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 12:19:55 -0500 Subject: [PATCH 03/18] readme updates --- .../envs/server_handling/MANAGED_SERVER.md | 42 +++++++++++++++---- atroposlib/envs/server_handling/README.md | 13 ++++++ 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md index 692f3a10..148444b1 100644 --- a/atroposlib/envs/server_handling/MANAGED_SERVER.md +++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md @@ -2,7 +2,7 @@ ## Overview -`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments. +`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It also exposes a normalized `get_logprobs(...)` API for backend-agnostic logprob access. This eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments. **Server Compatibility:** ManagedServer works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection. @@ -36,7 +36,7 @@ async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions - ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0 -- ✅ **Perfect Alignment**: Tokens and logprobs are guaranteed to align correctly +- ✅ **Normalized Alignment Contract**: Tokens/logprobs are shape-normalized for downstream consumers - ✅ **Multi-turn Support**: Automatically handles conversation extensions - ✅ **Branching Support**: Handles n>1 completions naturally - ✅ **Clean API**: Simple context manager pattern @@ -609,6 +609,31 @@ Intercept completion call and track sequences. **Side Effects:** - Tracks sequences in internal storage +#### `async def get_logprobs(**kwargs) -> Dict[str, Any]` +Fetch logprobs with a normalized schema that is backend-agnostic. + +**Args (common):** +- `messages` or `prompt` or `input_ids` +- `n`: Number of sampled sequences +- `max_tokens` +- Optional backend kwargs such as `top_k` / `top_logprobs`, `temperature`, `stop` + +**Returns (normalized):** +```python +{ + "prompt_tokens": List[int], + "sequence_token_ids": List[List[int]], # [seq][pos] + "sequence_logprobs": List[List[float]], # [seq][pos] + "sequence_topk_token_ids": List[List[List[int]]], # [seq][pos][k] + "sequence_topk_logprobs": List[List[List[float]]], # [seq][pos][k] + "finish_reasons": List[Any], +} +``` + +**Notes:** +- If the backend only returns sampled-token logprobs, ManagedServer synthesizes `k=1` singleton top-k arrays. +- This method is for transport/interface consistency; richer top-k depends on backend support. + #### `def get_state() -> Dict[str, Any]` Get the current state of tracked sequences. @@ -712,11 +737,9 @@ export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1 With this flag set, `managed_server()` will return a `DummyManagedServer` that: - Provides the same interface as `ManagedServer` -- Returns **fixed placeholder values** for tokens and logprobs: - - `tokens`: `[1, 2, 3]` - - `masked_tokens`: `[-100, 2, 3]` - - `logprobs`: `[-0.5, -0.5, -0.5]` +- Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays) - Uses simple text formatting for `full_text`: `role:content` joined by `\n\n` +- Implements `get_logprobs(...)` with the same normalized keys as `ManagedServer`, but placeholder values (not suitable for training) ### When to Use DummyManagedServer @@ -746,8 +769,11 @@ async with self.server.managed_server() as managed: # nodes contain placeholder token data - DO NOT use for training for node in nodes: print(node.full_text) # Real completion text - print(node.tokens) # [1, 2, 3] - placeholder! - print(node.logprobs) # [-0.5, -0.5, -0.5] - placeholder! + print(node.tokens[:5]) # placeholder values + print(node.logprobs[:5]) # placeholder values + + payload = await managed.get_logprobs(messages=messages, n=4) + print(payload.keys()) # normalized schema keys ``` ### Recommendation diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md index 447b1127..23d6b5d0 100644 --- a/atroposlib/envs/server_handling/README.md +++ b/atroposlib/envs/server_handling/README.md @@ -8,6 +8,19 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_ > **Note:** OpenAI endpoints do not support token IDs/logprobs required for ManagedServer. Set `ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1` to use a placeholder implementation for testing/evaluation. See [OpenAI Endpoint Limitations](MANAGED_SERVER.md#openai-endpoint-limitations) for details. +### Normalized `get_logprobs` API + +`ManagedServer` and server backends now expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends: + +- `prompt_tokens` +- `sequence_token_ids` +- `sequence_logprobs` +- `sequence_topk_token_ids` +- `sequence_topk_logprobs` +- `finish_reasons` + +For backends that only expose sampled-token logprobs, top-k arrays are synthesized with `k=1` for interface compatibility. + ## Reasoning Model Support The `ReasoningConfig` class enables support for reasoning/thinking models across different providers. From e98100e5f6b20c4d31f947e150a1bc7054398e78 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:20:56 +0000 Subject: [PATCH 04/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- atroposlib/envs/server_handling/managed_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index db909c83..38b94489 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -564,7 +564,9 @@ class ManagedServer: and prompt is not None ): extending_node = self._find_extending_node(prompt) - request_kwargs["input_ids"] = self._compute_input_ids(prompt, extending_node) + request_kwargs["input_ids"] = self._compute_input_ids( + prompt, extending_node + ) if hasattr(self.server, "get_logprobs"): payload = await self.server.get_logprobs(**request_kwargs) From 439b9b129b50636c6082d81660c81357e8b7fa7f Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 21:56:11 -0500 Subject: [PATCH 05/18] prompt logprobs --- .../envs/server_handling/MANAGED_SERVER.md | 12 ++-- atroposlib/envs/server_handling/README.md | 8 +-- .../envs/server_handling/managed_server.py | 53 ++++++---------- .../envs/server_handling/server_baseline.py | 35 ++++------- .../envs/server_handling/server_manager.py | 31 ++------- .../envs/server_handling/vllm_server.py | 63 +++++++------------ atroposlib/tests/test_managed_server.py | 9 +-- 7 files changed, 73 insertions(+), 138 deletions(-) diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md index 148444b1..6b739a83 100644 --- a/atroposlib/envs/server_handling/MANAGED_SERVER.md +++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md @@ -622,17 +622,15 @@ Fetch logprobs with a normalized schema that is backend-agnostic. ```python { "prompt_tokens": List[int], - "sequence_token_ids": List[List[int]], # [seq][pos] - "sequence_logprobs": List[List[float]], # [seq][pos] - "sequence_topk_token_ids": List[List[List[int]]], # [seq][pos][k] - "sequence_topk_logprobs": List[List[List[float]]], # [seq][pos][k] - "finish_reasons": List[Any], + "prompt_topk_token_ids": List[List[int]], # [pos][k] + "prompt_topk_logprobs": List[List[float]], # [pos][k] + "finish_reasons": List[Any], # optional compatibility field } ``` **Notes:** -- If the backend only returns sampled-token logprobs, ManagedServer synthesizes `k=1` singleton top-k arrays. -- This method is for transport/interface consistency; richer top-k depends on backend support. +- If the backend only returns sampled-token logprobs, ManagedServer synthesizes `k=1` singleton prompt top-k arrays. +- This method is prompt-focused; richer top-k depends on backend support. #### `def get_state() -> Dict[str, Any]` Get the current state of tracked sequences. diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md index 23d6b5d0..f19ad272 100644 --- a/atroposlib/envs/server_handling/README.md +++ b/atroposlib/envs/server_handling/README.md @@ -13,13 +13,11 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_ `ManagedServer` and server backends now expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends: - `prompt_tokens` -- `sequence_token_ids` -- `sequence_logprobs` -- `sequence_topk_token_ids` -- `sequence_topk_logprobs` +- `prompt_topk_token_ids` +- `prompt_topk_logprobs` - `finish_reasons` -For backends that only expose sampled-token logprobs, top-k arrays are synthesized with `k=1` for interface compatibility. +For backends that only expose sampled-token logprobs, prompt top-k arrays are synthesized with `k=1` for interface compatibility. ## Reasoning Model Support diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 38b94489..f2231ab7 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -531,7 +531,7 @@ class ManagedServer: async def get_logprobs(self, **kwargs) -> Dict[str, Any]: """ - Fetch logprobs via wrapped server with a normalized trainer-agnostic schema. + Fetch prompt logprobs via wrapped server with a normalized schema. Supported inputs: - prompt @@ -541,11 +541,8 @@ class ManagedServer: Returns: Dict with: - prompt_tokens - - sequence_token_ids - - sequence_logprobs - - sequence_topk_token_ids - - sequence_topk_logprobs - - finish_reasons + - prompt_topk_token_ids + - prompt_topk_logprobs """ request_kwargs = kwargs.copy() messages = request_kwargs.pop("messages", None) @@ -574,31 +571,24 @@ class ManagedServer: # Backwards-compatible fallback for harness/test doubles. ( prompt_tokens, - output_tokens_list, - output_logprobs_list, - finish_reasons, + _output_tokens_list, + _output_logprobs_list, + _finish_reasons, ) = await self.server.tokens_and_logprobs_completion(**request_kwargs) payload = { "prompt_tokens": prompt_tokens, - "sequence_token_ids": output_tokens_list, - "sequence_logprobs": output_logprobs_list, - "sequence_topk_token_ids": [ - [[tok] for tok in seq] for seq in output_tokens_list - ], - "sequence_topk_logprobs": [ - [[lp] for lp in seq] for seq in output_logprobs_list - ], - "finish_reasons": finish_reasons, + "prompt_topk_token_ids": [[tok] for tok in prompt_tokens], + "prompt_topk_logprobs": [[1.0] for _ in prompt_tokens], } - # Normalize required keys if provider omitted top-k arrays. - if "sequence_topk_token_ids" not in payload: - payload["sequence_topk_token_ids"] = [ - [[tok] for tok in seq] for seq in payload["sequence_token_ids"] + # Normalize required keys if provider omitted prompt top-k arrays. + if "prompt_topk_token_ids" not in payload: + payload["prompt_topk_token_ids"] = [ + [tok] for tok in payload.get("prompt_tokens", []) ] - if "sequence_topk_logprobs" not in payload: - payload["sequence_topk_logprobs"] = [ - [[lp] for lp in seq] for seq in payload["sequence_logprobs"] + if "prompt_topk_logprobs" not in payload: + payload["prompt_topk_logprobs"] = [ + [1.0] for _ in payload.get("prompt_tokens", []) ] return payload @@ -722,15 +712,12 @@ class DummyManagedServer: that results are placeholders and not suitable for training. """ n = int(kwargs.get("n", 1)) - seq_ids = [self.DUMMY_TOKENS[:] for _ in range(n)] - seq_lps = [self.DUMMY_LOGPROBS[:] for _ in range(n)] + prompt_tokens = self.DUMMY_TOKENS[:] return { - "prompt_tokens": [], - "sequence_token_ids": seq_ids, - "sequence_logprobs": seq_lps, - "sequence_topk_token_ids": [[[tok] for tok in seq] for seq in seq_ids], - "sequence_topk_logprobs": [[[lp] for lp in seq] for seq in seq_lps], - "finish_reasons": ["stop"] * n, + "prompt_tokens": prompt_tokens, + "prompt_topk_token_ids": [[tok] for tok in prompt_tokens], + "prompt_topk_logprobs": [[self.DUMMY_LOGPROBS[0]] for _ in prompt_tokens], + "finish_reasons": ["stop"] * n, # Retained for compatibility in callers. } diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index 281f4fbd..c387f5dd 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -641,40 +641,31 @@ class APIServer(ABC): async def get_logprobs(self, **kwargs) -> Dict[str, Any]: """ - Trainer-agnostic logprob API with normalized output schema. + Trainer-agnostic prompt-logprob API with normalized output schema. This default implementation is built from `tokens_and_logprobs_completion` - and returns sampled-token logprobs (top-k singleton per position). + and returns prompt-side singleton top-k values. Returns: Dict with: - prompt_tokens: List[int] - - sequence_token_ids: List[List[int]] - - sequence_logprobs: List[List[float]] - - sequence_topk_token_ids: List[List[List[int]]] - - sequence_topk_logprobs: List[List[List[float]]] - - finish_reasons: List[Any] + - prompt_topk_token_ids: List[List[int]] + - prompt_topk_logprobs: List[List[float]] """ ( prompt_tokens, - output_tokens_list, - output_logprobs_list, - finish_reasons, + _output_tokens_list, + _output_logprobs_list, + _finish_reasons, ) = await self.tokens_and_logprobs_completion(**kwargs) - topk_token_ids = [ - [[token_id] for token_id in seq_tokens] for seq_tokens in output_tokens_list - ] - topk_logprobs = [ - [[logprob] for logprob in seq_logprobs] - for seq_logprobs in output_logprobs_list - ] + # Fallback path does not have true prompt-logprobs, so we provide + # interface-compatible singleton values for each prompt token. + prompt_topk_token_ids = [[token_id] for token_id in prompt_tokens] + prompt_topk_logprobs = [[1.0] for _ in prompt_tokens] return { "prompt_tokens": prompt_tokens, - "sequence_token_ids": output_tokens_list, - "sequence_logprobs": output_logprobs_list, - "sequence_topk_token_ids": topk_token_ids, - "sequence_topk_logprobs": topk_logprobs, - "finish_reasons": finish_reasons, + "prompt_topk_token_ids": prompt_topk_token_ids, + "prompt_topk_logprobs": prompt_topk_logprobs, } diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index 6d0a90d5..4c548aac 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -348,19 +348,16 @@ class ServerManager: async def get_logprobs(self, **kwargs) -> dict: """ - Route normalized get_logprobs requests to the most available server. + Route normalized prompt-logprob requests to the most available server. Returns a normalized dict with: - prompt_tokens - - sequence_token_ids - - sequence_logprobs - - sequence_topk_token_ids - - sequence_topk_logprobs - - finish_reasons + - prompt_topk_token_ids + - prompt_topk_logprobs """ n = kwargs.get("n", 1) if n > self.max_n_completions: - # Split into multiple requests and merge sequence-level outputs. + # Prompt logprobs are prompt-level; n-splitting does not change prompt arrays. results = [] total_n = n while total_n > 0: @@ -369,25 +366,7 @@ class ServerManager: results.append(self.get_logprobs(**kwargs)) total_n -= n_to_use results = await asyncio.gather(*results) - merged = { - "prompt_tokens": results[0]["prompt_tokens"], - "sequence_token_ids": [], - "sequence_logprobs": [], - "sequence_topk_token_ids": [], - "sequence_topk_logprobs": [], - "finish_reasons": [], - } - for result in results: - merged["sequence_token_ids"].extend(result["sequence_token_ids"]) - merged["sequence_logprobs"].extend(result["sequence_logprobs"]) - merged["sequence_topk_token_ids"].extend( - result["sequence_topk_token_ids"] - ) - merged["sequence_topk_logprobs"].extend( - result["sequence_topk_logprobs"] - ) - merged["finish_reasons"].extend(result["finish_reasons"]) - return merged + return results[0] is_train = kwargs.pop("split", "train") == "train" most_available_server = 0 diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 11ba87e8..e0091ead 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -262,7 +262,7 @@ class VLLMServer(APIServer): async def get_logprobs(self, **kwargs) -> Dict[str, Any]: """ - Fetch normalized logprobs from vLLM /generate with optional top-k. + Fetch normalized prompt logprobs from vLLM /generate with optional top-k. Args: top_k / top_logprobs: Optional number of logprobs per position. @@ -272,11 +272,8 @@ class VLLMServer(APIServer): Returns: Normalized dict: - prompt_tokens - - sequence_token_ids - - sequence_logprobs - - sequence_topk_token_ids - - sequence_topk_logprobs - - finish_reasons + - prompt_topk_token_ids + - prompt_topk_logprobs """ assert ( kwargs.get("prompt", None) is not None @@ -306,10 +303,8 @@ class VLLMServer(APIServer): kwargs["max_tokens"] = kwargs.pop("max_completion_tokens") kwargs.pop("model", None) - request_data = { - "prompt": {"prompt_token_ids": prompt_tokens}, - "logprobs": top_k, - } + request_data = {"prompt": {"prompt_token_ids": prompt_tokens}} + request_data["prompt_logprobs"] = top_k request_data.update(kwargs) # Keep semaphore behavior consistent with other server calls. @@ -333,40 +328,30 @@ class VLLMServer(APIServer): response.raise_for_status() results = await response.json() - sequence_topk_token_ids: List[List[List[int]]] = [] - sequence_topk_logprobs: List[List[List[float]]] = [] - sequence_token_ids: List[List[int]] = [] - sequence_logprobs: List[List[float]] = [] - finish_reasons: List[Any] = [] + raw_prompt_logprobs = results.get("prompt_logprobs") + if raw_prompt_logprobs is None: + raise ValueError( + "vLLM /generate response missing 'prompt_logprobs'. " + "Ensure backend supports prompt logprobs." + ) - for token_logprobs_seq, finish_reason in zip( - results["logprobs"], results["finish_reasons"] - ): - seq_topk_token_ids: List[List[int]] = [] - seq_topk_logprobs: List[List[float]] = [] - seq_token_ids: List[int] = [] - seq_logprobs: List[float] = [] + # Handle either direct [position] payloads or [sequence][position] payloads. + if raw_prompt_logprobs and isinstance(raw_prompt_logprobs[0], list): + prompt_entries = raw_prompt_logprobs[0] + else: + prompt_entries = raw_prompt_logprobs - for token_logprobs_entry in token_logprobs_seq: - topk_ids, topk_lps = self._normalize_topk_entry(token_logprobs_entry) - seq_topk_token_ids.append(topk_ids) - seq_topk_logprobs.append(topk_lps) - seq_token_ids.append(topk_ids[0] if topk_ids else -1) - seq_logprobs.append(topk_lps[0] if topk_lps else 0.0) - - sequence_topk_token_ids.append(seq_topk_token_ids) - sequence_topk_logprobs.append(seq_topk_logprobs) - sequence_token_ids.append(seq_token_ids) - sequence_logprobs.append(seq_logprobs) - finish_reasons.append(finish_reason) + prompt_topk_token_ids: List[List[int]] = [] + prompt_topk_logprobs: List[List[float]] = [] + for entry in prompt_entries: + topk_ids, topk_lps = self._normalize_topk_entry(entry) + prompt_topk_token_ids.append(topk_ids) + prompt_topk_logprobs.append(topk_lps) return { "prompt_tokens": prompt_tokens, - "sequence_token_ids": sequence_token_ids, - "sequence_logprobs": sequence_logprobs, - "sequence_topk_token_ids": sequence_topk_token_ids, - "sequence_topk_logprobs": sequence_topk_logprobs, - "finish_reasons": finish_reasons, + "prompt_topk_token_ids": prompt_topk_token_ids, + "prompt_topk_logprobs": prompt_topk_logprobs, } diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index d8ea2aed..ce644a05 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -267,7 +267,7 @@ async def test_bos_token_handling(mock_server): @pytest.mark.asyncio async def test_get_logprobs_normalized_schema(mock_server): - """ManagedServer.get_logprobs returns normalized schema.""" + """ManagedServer.get_logprobs returns normalized prompt schema.""" managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) prompt = "Hello" @@ -286,11 +286,8 @@ async def test_get_logprobs_normalized_schema(mock_server): payload = await managed.get_logprobs(prompt=prompt, n=1) assert payload["prompt_tokens"] == prompt_tokens - assert payload["sequence_token_ids"] == output_tokens - assert payload["sequence_logprobs"] == output_logprobs - assert payload["finish_reasons"] == ["stop"] - assert payload["sequence_topk_token_ids"] == [[[ord("!")], [ord("?")]]] - assert payload["sequence_topk_logprobs"] == [[[-0.1], [-0.2]]] + assert payload["prompt_topk_token_ids"] == [[tok] for tok in prompt_tokens] + assert payload["prompt_topk_logprobs"] == [[1.0] for _ in prompt_tokens] @pytest.mark.asyncio From f1c20591b69c314cdaed2e516846d975623e5202 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 21:57:52 -0500 Subject: [PATCH 06/18] prompt logprobs --- atroposlib/envs/server_handling/MANAGED_SERVER.md | 1 - atroposlib/envs/server_handling/README.md | 1 - atroposlib/envs/server_handling/managed_server.py | 2 -- 3 files changed, 4 deletions(-) diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md index 6b739a83..aac66249 100644 --- a/atroposlib/envs/server_handling/MANAGED_SERVER.md +++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md @@ -624,7 +624,6 @@ Fetch logprobs with a normalized schema that is backend-agnostic. "prompt_tokens": List[int], "prompt_topk_token_ids": List[List[int]], # [pos][k] "prompt_topk_logprobs": List[List[float]], # [pos][k] - "finish_reasons": List[Any], # optional compatibility field } ``` diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md index f19ad272..4a25be8d 100644 --- a/atroposlib/envs/server_handling/README.md +++ b/atroposlib/envs/server_handling/README.md @@ -15,7 +15,6 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_ - `prompt_tokens` - `prompt_topk_token_ids` - `prompt_topk_logprobs` -- `finish_reasons` For backends that only expose sampled-token logprobs, prompt top-k arrays are synthesized with `k=1` for interface compatibility. diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index f2231ab7..79b90b4e 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -711,13 +711,11 @@ class DummyManagedServer: This keeps interface parity with ManagedServer while making it explicit that results are placeholders and not suitable for training. """ - n = int(kwargs.get("n", 1)) prompt_tokens = self.DUMMY_TOKENS[:] return { "prompt_tokens": prompt_tokens, "prompt_topk_token_ids": [[tok] for tok in prompt_tokens], "prompt_topk_logprobs": [[self.DUMMY_LOGPROBS[0]] for _ in prompt_tokens], - "finish_reasons": ["stop"] * n, # Retained for compatibility in callers. } From 5aaf7a346cd5c1c6142edfa54e5900deea83b5ef Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 22:06:49 -0500 Subject: [PATCH 07/18] prompt logprobs simplicity --- .../envs/server_handling/MANAGED_SERVER.md | 4 +- atroposlib/envs/server_handling/README.md | 2 +- .../envs/server_handling/managed_server.py | 81 +++++++++++-------- .../envs/server_handling/server_baseline.py | 27 ++----- atroposlib/tests/test_managed_server.py | 24 +++--- 5 files changed, 69 insertions(+), 69 deletions(-) diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md index aac66249..94cfa15e 100644 --- a/atroposlib/envs/server_handling/MANAGED_SERVER.md +++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md @@ -628,8 +628,8 @@ Fetch logprobs with a normalized schema that is backend-agnostic. ``` **Notes:** -- If the backend only returns sampled-token logprobs, ManagedServer synthesizes `k=1` singleton prompt top-k arrays. -- This method is prompt-focused; richer top-k depends on backend support. +- Strict mode: backend must provide real prompt top-k arrays with aligned shapes. +- Missing fields or shape mismatches fail fast with explicit errors. #### `def get_state() -> Dict[str, Any]` Get the current state of tracked sequences. diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md index 4a25be8d..f6e74193 100644 --- a/atroposlib/envs/server_handling/README.md +++ b/atroposlib/envs/server_handling/README.md @@ -16,7 +16,7 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_ - `prompt_topk_token_ids` - `prompt_topk_logprobs` -For backends that only expose sampled-token logprobs, prompt top-k arrays are synthesized with `k=1` for interface compatibility. +Strict mode: backends must return real prompt top-k arrays. Missing keys or malformed shapes fail fast. ## Reasoning Model Support diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 79b90b4e..84752516 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -565,34 +565,52 @@ class ManagedServer: prompt, extending_node ) - if hasattr(self.server, "get_logprobs"): - payload = await self.server.get_logprobs(**request_kwargs) - else: - # Backwards-compatible fallback for harness/test doubles. - ( - prompt_tokens, - _output_tokens_list, - _output_logprobs_list, - _finish_reasons, - ) = await self.server.tokens_and_logprobs_completion(**request_kwargs) - payload = { - "prompt_tokens": prompt_tokens, - "prompt_topk_token_ids": [[tok] for tok in prompt_tokens], - "prompt_topk_logprobs": [[1.0] for _ in prompt_tokens], - } + if not hasattr(self.server, "get_logprobs"): + raise NotImplementedError( + f"{self.server.__class__.__name__} does not implement get_logprobs. " + "Strict mode requires backend prompt logprobs." + ) - # Normalize required keys if provider omitted prompt top-k arrays. - if "prompt_topk_token_ids" not in payload: - payload["prompt_topk_token_ids"] = [ - [tok] for tok in payload.get("prompt_tokens", []) - ] - if "prompt_topk_logprobs" not in payload: - payload["prompt_topk_logprobs"] = [ - [1.0] for _ in payload.get("prompt_tokens", []) - ] + payload = await self.server.get_logprobs(**request_kwargs) + self._validate_prompt_logprob_payload(payload) return payload + @staticmethod + def _validate_prompt_logprob_payload(payload: Dict[str, Any]) -> None: + required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs") + missing = [k for k in required if k not in payload] + if missing: + raise ValueError( + f"get_logprobs response missing required keys: {missing}" + ) + + prompt_tokens = payload["prompt_tokens"] + token_ids = payload["prompt_topk_token_ids"] + logprobs = payload["prompt_topk_logprobs"] + + if not isinstance(prompt_tokens, list): + raise ValueError("prompt_tokens must be a list[int].") + if not isinstance(token_ids, list) or not isinstance(logprobs, list): + raise ValueError( + "prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list." + ) + if len(token_ids) != len(prompt_tokens) or len(logprobs) != len(prompt_tokens): + raise ValueError( + "prompt_topk arrays must align with prompt_tokens length." + ) + + for idx, (tok_row, lp_row) in enumerate(zip(token_ids, logprobs)): + if not isinstance(tok_row, list) or not isinstance(lp_row, list): + raise ValueError( + "prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list." + ) + if len(tok_row) != len(lp_row): + raise ValueError( + f"prompt_topk row mismatch at position {idx}: " + f"{len(tok_row)} token ids vs {len(lp_row)} logprobs." + ) + class DummyManagedServer: """ @@ -706,17 +724,12 @@ class DummyManagedServer: async def get_logprobs(self, **kwargs) -> Dict[str, Any]: """ - Return interface-compatible dummy logprob payload. - - This keeps interface parity with ManagedServer while making it explicit - that results are placeholders and not suitable for training. + Dummy managed server does not provide real prompt logprobs. """ - prompt_tokens = self.DUMMY_TOKENS[:] - return { - "prompt_tokens": prompt_tokens, - "prompt_topk_token_ids": [[tok] for tok in prompt_tokens], - "prompt_topk_logprobs": [[self.DUMMY_LOGPROBS[0]] for _ in prompt_tokens], - } + raise NotImplementedError( + "DummyManagedServer does not support get_logprobs in strict mode. " + "Use a backend with real prompt logprob support." + ) class ManagedServerAdapter: diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index c387f5dd..a8cef0dd 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -641,10 +641,7 @@ class APIServer(ABC): async def get_logprobs(self, **kwargs) -> Dict[str, Any]: """ - Trainer-agnostic prompt-logprob API with normalized output schema. - - This default implementation is built from `tokens_and_logprobs_completion` - and returns prompt-side singleton top-k values. + Prompt-logprob API with strict normalized output schema. Returns: Dict with: @@ -652,20 +649,8 @@ class APIServer(ABC): - prompt_topk_token_ids: List[List[int]] - prompt_topk_logprobs: List[List[float]] """ - ( - prompt_tokens, - _output_tokens_list, - _output_logprobs_list, - _finish_reasons, - ) = await self.tokens_and_logprobs_completion(**kwargs) - - # Fallback path does not have true prompt-logprobs, so we provide - # interface-compatible singleton values for each prompt token. - prompt_topk_token_ids = [[token_id] for token_id in prompt_tokens] - prompt_topk_logprobs = [[1.0] for _ in prompt_tokens] - - return { - "prompt_tokens": prompt_tokens, - "prompt_topk_token_ids": prompt_topk_token_ids, - "prompt_topk_logprobs": prompt_topk_logprobs, - } + raise NotImplementedError( + f"{self.__class__.__name__}.get_logprobs must be implemented by the " + "server backend and must return prompt_tokens, " + "prompt_topk_token_ids, and prompt_topk_logprobs." + ) diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index ce644a05..fefe414c 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -272,22 +272,24 @@ async def test_get_logprobs_normalized_schema(mock_server): prompt = "Hello" prompt_tokens = mock_server.tokenizer.encode(prompt) - output_tokens = [[ord("!"), ord("?")]] - output_logprobs = [[-0.1, -0.2]] + prompt_topk_token_ids = [[t, t + 1] for t in prompt_tokens] + prompt_topk_logprobs = [[-0.1, -0.2] for _ in prompt_tokens] - mock_server.set_tokens_and_logprobs_response( - prompt=prompt, - prompt_tokens=prompt_tokens, - output_tokens_list=output_tokens, - output_logprobs_list=output_logprobs, - finish_reasons=["stop"], - ) + async def _mock_get_logprobs(**kwargs): + assert kwargs.get("prompt") == prompt + return { + "prompt_tokens": prompt_tokens, + "prompt_topk_token_ids": prompt_topk_token_ids, + "prompt_topk_logprobs": prompt_topk_logprobs, + } + + mock_server.get_logprobs = _mock_get_logprobs payload = await managed.get_logprobs(prompt=prompt, n=1) assert payload["prompt_tokens"] == prompt_tokens - assert payload["prompt_topk_token_ids"] == [[tok] for tok in prompt_tokens] - assert payload["prompt_topk_logprobs"] == [[1.0] for _ in prompt_tokens] + assert payload["prompt_topk_token_ids"] == prompt_topk_token_ids + assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs @pytest.mark.asyncio From 8f304d44fd3fda3324dbf2386721b54b520cc69e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 03:08:15 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- atroposlib/envs/server_handling/managed_server.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 84752516..e2828cf3 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -581,9 +581,7 @@ class ManagedServer: required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs") missing = [k for k in required if k not in payload] if missing: - raise ValueError( - f"get_logprobs response missing required keys: {missing}" - ) + raise ValueError(f"get_logprobs response missing required keys: {missing}") prompt_tokens = payload["prompt_tokens"] token_ids = payload["prompt_topk_token_ids"] @@ -596,9 +594,7 @@ class ManagedServer: "prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list." ) if len(token_ids) != len(prompt_tokens) or len(logprobs) != len(prompt_tokens): - raise ValueError( - "prompt_topk arrays must align with prompt_tokens length." - ) + raise ValueError("prompt_topk arrays must align with prompt_tokens length.") for idx, (tok_row, lp_row) in enumerate(zip(token_ids, logprobs)): if not isinstance(tok_row, list) or not isinstance(lp_row, list): From 51088ac24d4d56462ddd6ea1069bc02f3e3d2099 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 22:13:14 -0500 Subject: [PATCH 09/18] add tests --- atroposlib/tests/test_managed_server.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index fefe414c..fe68e4fb 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -292,6 +292,28 @@ async def test_get_logprobs_normalized_schema(mock_server): assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs +@pytest.mark.asyncio +async def test_get_logprobs_strict_mode_rejects_misaligned_payload(mock_server): + """ManagedServer.get_logprobs fails fast on malformed prompt top-k payload.""" + managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) + + prompt = "Hello" + prompt_tokens = mock_server.tokenizer.encode(prompt) + + async def _mock_get_logprobs(**kwargs): + return { + "prompt_tokens": prompt_tokens, + "prompt_topk_token_ids": [[tok] for tok in prompt_tokens], + # Missing one row on purpose -> misaligned with prompt length + "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens[:-1]], + } + + mock_server.get_logprobs = _mock_get_logprobs + + with pytest.raises(ValueError, match="align with prompt_tokens length"): + await managed.get_logprobs(prompt=prompt, n=1) + + @pytest.mark.asyncio async def test_reset_clears_sequences(mock_server): """Test that reset() clears all tracked sequences.""" From 1eeb31065f2a00d66a875d020f36f560950be9ba Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 23:08:28 -0500 Subject: [PATCH 10/18] fixing comments --- .../envs/server_handling/MANAGED_SERVER.md | 11 ++--- atroposlib/envs/server_handling/README.md | 4 +- .../envs/server_handling/managed_server.py | 45 ------------------- .../envs/server_handling/server_manager.py | 13 ------ .../envs/server_handling/vllm_server.py | 11 ++++- atroposlib/tests/test_managed_server.py | 18 ++------ 6 files changed, 21 insertions(+), 81 deletions(-) diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md index 94cfa15e..9b1c9c3e 100644 --- a/atroposlib/envs/server_handling/MANAGED_SERVER.md +++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md @@ -36,6 +36,7 @@ async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions - ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0 +- ✅ **Perfect Alignment**: Tokens and logprobs align positionally for tracked sequences - ✅ **Normalized Alignment Contract**: Tokens/logprobs are shape-normalized for downstream consumers - ✅ **Multi-turn Support**: Automatically handles conversation extensions - ✅ **Branching Support**: Handles n>1 completions naturally @@ -628,8 +629,8 @@ Fetch logprobs with a normalized schema that is backend-agnostic. ``` **Notes:** -- Strict mode: backend must provide real prompt top-k arrays with aligned shapes. -- Missing fields or shape mismatches fail fast with explicit errors. +- Strict mode: backend must provide real prompt top-k arrays. +- Missing keys should be treated as backend contract violations. #### `def get_state() -> Dict[str, Any]` Get the current state of tracked sequences. @@ -736,7 +737,7 @@ With this flag set, `managed_server()` will return a `DummyManagedServer` that: - Provides the same interface as `ManagedServer` - Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays) - Uses simple text formatting for `full_text`: `role:content` joined by `\n\n` -- Implements `get_logprobs(...)` with the same normalized keys as `ManagedServer`, but placeholder values (not suitable for training) +- Raises for `get_logprobs(...)` in strict mode (no fake prompt-logprob payload) ### When to Use DummyManagedServer @@ -769,8 +770,8 @@ async with self.server.managed_server() as managed: print(node.tokens[:5]) # placeholder values print(node.logprobs[:5]) # placeholder values - payload = await managed.get_logprobs(messages=messages, n=4) - print(payload.keys()) # normalized schema keys + # Strict mode: get_logprobs is not available on DummyManagedServer + # and will raise NotImplementedError. ``` ### Recommendation diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md index f6e74193..ee388293 100644 --- a/atroposlib/envs/server_handling/README.md +++ b/atroposlib/envs/server_handling/README.md @@ -10,13 +10,13 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_ ### Normalized `get_logprobs` API -`ManagedServer` and server backends now expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends: +`ManagedServer` and server backends expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends: - `prompt_tokens` - `prompt_topk_token_ids` - `prompt_topk_logprobs` -Strict mode: backends must return real prompt top-k arrays. Missing keys or malformed shapes fail fast. +Backends must return real prompt top-k arrays. Missing keys or malformed shapes fail fast. ## Reasoning Model Support diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index e2828cf3..eb4baecc 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -553,18 +553,6 @@ class ManagedServer: else: prompt = request_kwargs.get("prompt") - # Reuse tracked context in non-tree mode when possible. - if ( - not self.track_tree - and self.tokenizer is not None - and "input_ids" not in request_kwargs - and prompt is not None - ): - extending_node = self._find_extending_node(prompt) - request_kwargs["input_ids"] = self._compute_input_ids( - prompt, extending_node - ) - if not hasattr(self.server, "get_logprobs"): raise NotImplementedError( f"{self.server.__class__.__name__} does not implement get_logprobs. " @@ -572,41 +560,8 @@ class ManagedServer: ) payload = await self.server.get_logprobs(**request_kwargs) - self._validate_prompt_logprob_payload(payload) - return payload - @staticmethod - def _validate_prompt_logprob_payload(payload: Dict[str, Any]) -> None: - required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs") - missing = [k for k in required if k not in payload] - if missing: - raise ValueError(f"get_logprobs response missing required keys: {missing}") - - prompt_tokens = payload["prompt_tokens"] - token_ids = payload["prompt_topk_token_ids"] - logprobs = payload["prompt_topk_logprobs"] - - if not isinstance(prompt_tokens, list): - raise ValueError("prompt_tokens must be a list[int].") - if not isinstance(token_ids, list) or not isinstance(logprobs, list): - raise ValueError( - "prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list." - ) - if len(token_ids) != len(prompt_tokens) or len(logprobs) != len(prompt_tokens): - raise ValueError("prompt_topk arrays must align with prompt_tokens length.") - - for idx, (tok_row, lp_row) in enumerate(zip(token_ids, logprobs)): - if not isinstance(tok_row, list) or not isinstance(lp_row, list): - raise ValueError( - "prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list." - ) - if len(tok_row) != len(lp_row): - raise ValueError( - f"prompt_topk row mismatch at position {idx}: " - f"{len(tok_row)} token ids vs {len(lp_row)} logprobs." - ) - class DummyManagedServer: """ diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index 4c548aac..a7c1416c 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -355,19 +355,6 @@ class ServerManager: - prompt_topk_token_ids - prompt_topk_logprobs """ - n = kwargs.get("n", 1) - if n > self.max_n_completions: - # Prompt logprobs are prompt-level; n-splitting does not change prompt arrays. - results = [] - total_n = n - while total_n > 0: - n_to_use = min(total_n, self.max_n_completions) - kwargs["n"] = n_to_use - results.append(self.get_logprobs(**kwargs)) - total_n -= n_to_use - results = await asyncio.gather(*results) - return results[0] - is_train = kwargs.pop("split", "train") == "train" most_available_server = 0 most_available_server_num_slots = -1 diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index e0091ead..8a3f0c44 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -284,14 +284,18 @@ class VLLMServer(APIServer): top_k = max(1, top_k) # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt + from_prompt_text = False if "input_ids" in kwargs: prompt_tokens = kwargs.pop("input_ids") kwargs.pop("prompt", None) else: prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt")) + from_prompt_text = True - # Check for double BOS token. + # Only normalize BOS for tokenizer-encoded prompt text. if ( + from_prompt_text + and len(prompt_tokens) >= 2 and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1] ): @@ -306,6 +310,11 @@ class VLLMServer(APIServer): request_data = {"prompt": {"prompt_token_ids": prompt_tokens}} request_data["prompt_logprobs"] = top_k request_data.update(kwargs) + # This API is prompt-logprobs focused, not generation-focused. + request_data["n"] = 1 + request_data["temperature"] = 0.0 + request_data["top_p"] = 1.0 + request_data.setdefault("max_tokens", 1) # Keep semaphore behavior consistent with other server calls. split = request_data.pop("split", "train") diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index fe68e4fb..00d1e9b9 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -293,24 +293,12 @@ async def test_get_logprobs_normalized_schema(mock_server): @pytest.mark.asyncio -async def test_get_logprobs_strict_mode_rejects_misaligned_payload(mock_server): - """ManagedServer.get_logprobs fails fast on malformed prompt top-k payload.""" +async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server): + """ManagedServer.get_logprobs requires backend get_logprobs in strict mode.""" managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) prompt = "Hello" - prompt_tokens = mock_server.tokenizer.encode(prompt) - - async def _mock_get_logprobs(**kwargs): - return { - "prompt_tokens": prompt_tokens, - "prompt_topk_token_ids": [[tok] for tok in prompt_tokens], - # Missing one row on purpose -> misaligned with prompt length - "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens[:-1]], - } - - mock_server.get_logprobs = _mock_get_logprobs - - with pytest.raises(ValueError, match="align with prompt_tokens length"): + with pytest.raises(NotImplementedError, match="does not implement get_logprobs"): await managed.get_logprobs(prompt=prompt, n=1) From efc90bfb1bbf8c529945b823c6710b8e1dfc4373 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 04:18:08 +0000 Subject: [PATCH 11/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- atroposlib/envs/server_handling/vllm_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 8a3f0c44..257f3337 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -295,8 +295,7 @@ class VLLMServer(APIServer): # Only normalize BOS for tokenizer-encoded prompt text. if ( from_prompt_text - and - len(prompt_tokens) >= 2 + and len(prompt_tokens) >= 2 and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1] ): prompt_tokens = prompt_tokens[1:] From 1a3d9ee664ab05988362cdd05b13a18386825288 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 23:37:53 -0500 Subject: [PATCH 12/18] testing --- .../tests/test_vllm_api_server_generate.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 atroposlib/tests/test_vllm_api_server_generate.py diff --git a/atroposlib/tests/test_vllm_api_server_generate.py b/atroposlib/tests/test_vllm_api_server_generate.py new file mode 100644 index 00000000..f2c9c48c --- /dev/null +++ b/atroposlib/tests/test_vllm_api_server_generate.py @@ -0,0 +1,68 @@ +"""Optional integration test for example_trainer.vllm_api_server /generate.""" + +from importlib import import_module + +import pytest +from fastapi.testclient import TestClient + + +@pytest.mark.asyncio +async def test_vllm_api_server_generate_endpoint_optional(): + """ + Validate /generate contract on the custom vLLM API server. + + This test only runs when vLLM is installed. + """ + pytest.importorskip("vllm") + + module = import_module("example_trainer.vllm_api_server") + + class _FakeLogprob: + def __init__(self, value: float): + self.logprob = value + + class _FakeOutput: + def __init__(self): + self.text = " world" + self.finish_reason = "stop" + self.logprobs = [{11: _FakeLogprob(-0.3)}] + self.token_ids = [11] + + class _FakeRequestOutput: + def __init__(self): + self.prompt = "hello" + self.prompt_token_ids = [1, 2] + self.outputs = [_FakeOutput()] + + class _FakeEngine: + tokenizer = type("Tok", (), {"decode": staticmethod(lambda _: "hello")})() + + def generate(self, *_args, **_kwargs): + async def _gen(): + yield _FakeRequestOutput() + + return _gen() + + old_engine = module.engine + module.engine = _FakeEngine() + try: + client = TestClient(module.app) + resp = client.post( + "/generate", + json={ + "prompt": "hello", + "max_tokens": 1, + "temperature": 0.0, + "logprobs": 1, + }, + ) + assert resp.status_code == 200 + body = resp.json() + assert "text" in body and body["text"] == [" world"] + assert body["prompt"] == "hello" + assert body["finish_reasons"] == ["stop"] + assert "logprobs" in body + assert "token_ids" in body + assert "prompt_token_ids" in body + finally: + module.engine = old_engine From c85a3e5ee79e29feee003ea5fe94b6cacb374a7f Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 23:44:29 -0500 Subject: [PATCH 13/18] readme language --- atroposlib/envs/server_handling/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md index ee388293..5b5775e9 100644 --- a/atroposlib/envs/server_handling/README.md +++ b/atroposlib/envs/server_handling/README.md @@ -10,13 +10,13 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_ ### Normalized `get_logprobs` API -`ManagedServer` and server backends expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends: +`ManagedServer` and supported server backends expose a normalized `get_logprobs(...)` interface so callers can consume a single schema: - `prompt_tokens` - `prompt_topk_token_ids` - `prompt_topk_logprobs` -Backends must return real prompt top-k arrays. Missing keys or malformed shapes fail fast. +Backends are expected to return real prompt top-k arrays (`[pos][k]`) matching this schema. ## Reasoning Model Support From b91922082e84070e6e9a1a3ab20fbd65f58e0f7f Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 5 Mar 2026 15:46:33 -0500 Subject: [PATCH 14/18] managed_Server pass through and centralize sem logic --- .../envs/server_handling/server_baseline.py | 72 +++++++++++- .../envs/server_handling/vllm_server.py | 35 +++--- atroposlib/tests/test_managed_server.py | 24 ++++ atroposlib/tests/test_server_logprobs.py | 103 ++++++++++++++++++ 4 files changed, 208 insertions(+), 26 deletions(-) create mode 100644 atroposlib/tests/test_server_logprobs.py diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index a8cef0dd..ade862dd 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -421,6 +421,15 @@ class APIServer(ABC): """ pass + async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]: + """ + Wrapper for prompt logprobs. Can be overridden by child classes. + Returns a dict containing prompt_tokens, prompt_topk_token_ids, prompt_topk_logprobs. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement _get_logprobs_wrapper." + ) + @retry( stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) ) @@ -639,6 +648,40 @@ class APIServer(ABC): self.eval_attempts_list.append(stat_dict["attempts"]) return ret_data + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _logprobs(self, stat_dict, **kwargs) -> Dict[str, Any]: + """ + Simple retry and stat collection wrapper for get_logprobs. + """ + while not self.server_healthy: + await asyncio.sleep(1) + async with self.sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + payload = await self._get_logprobs_wrapper(**kwargs) + stat_dict["end"] = time.time() + return payload + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _logprobs_eval(self, stat_dict, **kwargs) -> Dict[str, Any]: + """ + Simple retry and stat collection wrapper for get_logprobs eval. + """ + while not self.server_healthy: + await asyncio.sleep(1) + async with self.eval_sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + payload = await self._get_logprobs_wrapper(**kwargs) + stat_dict["end"] = time.time() + return payload + async def get_logprobs(self, **kwargs) -> Dict[str, Any]: """ Prompt-logprob API with strict normalized output schema. @@ -649,8 +692,27 @@ class APIServer(ABC): - prompt_topk_token_ids: List[List[int]] - prompt_topk_logprobs: List[List[float]] """ - raise NotImplementedError( - f"{self.__class__.__name__}.get_logprobs must be implemented by the " - "server backend and must return prompt_tokens, " - "prompt_topk_token_ids, and prompt_topk_logprobs." - ) + if not self.initialized: + if self.config.health_check: + if self.config.base_url is not None: + self.check_task = asyncio.create_task( + self.check_server_status_task(chat_completion=False) + ) + else: + self.server_healthy = True + else: + self.server_healthy = True + self.initialized = True + + kwargs["model"] = self.config.model_name + split = kwargs.pop("split", "train") + stat_dict = {"attempts": 0} + if split == "train": + payload = await self._logprobs(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + payload = await self._logprobs_eval(stat_dict, **kwargs) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) + return payload diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 257f3337..3c35bebb 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -260,7 +260,7 @@ class VLLMServer(APIServer): return [], [] - async def get_logprobs(self, **kwargs) -> Dict[str, Any]: + async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]: """ Fetch normalized prompt logprobs from vLLM /generate with optional top-k. @@ -315,26 +315,19 @@ class VLLMServer(APIServer): request_data["top_p"] = 1.0 request_data.setdefault("max_tokens", 1) - # Keep semaphore behavior consistent with other server calls. - split = request_data.pop("split", "train") - sem = self.sem if split == "train" else self.eval_sem - while not self.server_healthy: - await asyncio.sleep(1) - - async with sem: - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.config.base_url.replace('/v1', '')}/generate", - json=request_data, - headers=( - {"Authorization": f"Bearer {self.config.api_key}"} - if self.config.api_key - else {} - ), - timeout=aiohttp.ClientTimeout(total=self.config.timeout), - ) as response: - response.raise_for_status() - results = await response.json() + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.config.base_url.replace('/v1', '')}/generate", + json=request_data, + headers=( + {"Authorization": f"Bearer {self.config.api_key}"} + if self.config.api_key + else {} + ), + timeout=aiohttp.ClientTimeout(total=self.config.timeout), + ) as response: + response.raise_for_status() + results = await response.json() raw_prompt_logprobs = results.get("prompt_logprobs") if raw_prompt_logprobs is None: diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index 00d1e9b9..9d7221e7 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -292,6 +292,30 @@ async def test_get_logprobs_normalized_schema(mock_server): assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs +@pytest.mark.asyncio +async def test_get_logprobs_messages_passthrough(mock_server): + """ManagedServer.get_logprobs converts messages and passes prompt through.""" + managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) + messages = [{"role": "user", "content": "Hello"}] + expected_prompt = managed._convert_messages_to_prompt(messages) + prompt_tokens = mock_server.tokenizer.encode(expected_prompt) + + async def _mock_get_logprobs(**kwargs): + assert kwargs.get("prompt") == expected_prompt + return { + "prompt_tokens": prompt_tokens, + "prompt_topk_token_ids": [[t] for t in prompt_tokens], + "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens], + } + + mock_server.get_logprobs = _mock_get_logprobs + payload = await managed.get_logprobs(messages=messages, top_k=1) + + assert payload["prompt_tokens"] == prompt_tokens + assert len(payload["prompt_topk_token_ids"]) == len(prompt_tokens) + assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens) + + @pytest.mark.asyncio async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server): """ManagedServer.get_logprobs requires backend get_logprobs in strict mode.""" diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py new file mode 100644 index 00000000..2da50b42 --- /dev/null +++ b/atroposlib/tests/test_server_logprobs.py @@ -0,0 +1,103 @@ +"""Tests for get_logprobs wrappers and server-manager routing.""" + +import pytest + +from atroposlib.envs.server_handling.server_baseline import ( + APIServer, + APIServerConfig, + AsyncSemWithAdaptiveWeight, +) +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class _FakeAPIServer(APIServer): + def __init__(self, config: APIServerConfig): + super().__init__(config=config, reasoning_config=None) + self.calls = 0 + self.last_kwargs = None + + async def check_server_status_task(self, chat_completion: bool = True): + self.server_healthy = True + + async def _chat_completion_wrapper(self, **kwargs): + raise NotImplementedError + + async def _completion_wrapper(self, **kwargs): + raise NotImplementedError + + async def _tokens_and_logprobs_completion_wrapper(self, **kwargs): + raise NotImplementedError + + async def _get_logprobs_wrapper(self, **kwargs): + self.calls += 1 + self.last_kwargs = kwargs + prompt = kwargs.get("prompt", "") + prompt_tokens = [ord(c) for c in prompt] + return { + "prompt_tokens": prompt_tokens, + "prompt_topk_token_ids": [[t] for t in prompt_tokens], + "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens], + } + + +class _FakeRoutedServer: + def __init__(self, name: str, train_slots: int, eval_slots: int, healthy: bool = True): + self.name = name + self.server_healthy = healthy + self.sem = AsyncSemWithAdaptiveWeight(4) + self.eval_sem = AsyncSemWithAdaptiveWeight(4) + self.sem._value = train_slots + self.eval_sem._value = eval_slots + self.calls = 0 + + async def get_logprobs(self, **kwargs): + self.calls += 1 + return { + "server": self.name, + "prompt_tokens": [1], + "prompt_topk_token_ids": [[1]], + "prompt_topk_logprobs": [[-0.1]], + } + + +@pytest.mark.asyncio +async def test_apiserver_get_logprobs_train_eval_wrappers(): + cfg = APIServerConfig( + model_name="test-model", + base_url="", + health_check=False, + ) + server = _FakeAPIServer(cfg) + + train_out = await server.get_logprobs(prompt="hi", split="train") + assert train_out["prompt_tokens"] == [ord("h"), ord("i")] + assert server.calls == 1 + assert server.last_kwargs["model"] == "test-model" + assert len(server.request_timings) == 1 + assert len(server.attempts_list) == 1 + assert len(server.eval_request_timings) == 0 + assert len(server.eval_attempts_list) == 0 + + eval_out = await server.get_logprobs(prompt="ok", split="eval") + assert eval_out["prompt_tokens"] == [ord("o"), ord("k")] + assert server.calls == 2 + assert len(server.eval_request_timings) == 1 + assert len(server.eval_attempts_list) == 1 + + +@pytest.mark.asyncio +async def test_server_manager_get_logprobs_routes_to_most_available_server(): + s1 = _FakeRoutedServer("s1", train_slots=1, eval_slots=4, healthy=True) + s2 = _FakeRoutedServer("s2", train_slots=3, eval_slots=1, healthy=True) + s3 = _FakeRoutedServer("s3", train_slots=4, eval_slots=4, healthy=False) + + manager = ServerManager.__new__(ServerManager) + manager.servers = [s1, s2, s3] + + out_train = await ServerManager.get_logprobs(manager, prompt="x", split="train") + assert out_train["server"] == "s2" + assert s2.calls == 1 + + out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval") + assert out_eval["server"] == "s1" + assert s1.calls == 1 From b166c3a9d926a5c54f81aab0fbbc8310493433b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 20:48:13 +0000 Subject: [PATCH 15/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- atroposlib/tests/test_server_logprobs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py index 2da50b42..8cbd84ad 100644 --- a/atroposlib/tests/test_server_logprobs.py +++ b/atroposlib/tests/test_server_logprobs.py @@ -41,7 +41,9 @@ class _FakeAPIServer(APIServer): class _FakeRoutedServer: - def __init__(self, name: str, train_slots: int, eval_slots: int, healthy: bool = True): + def __init__( + self, name: str, train_slots: int, eval_slots: int, healthy: bool = True + ): self.name = name self.server_healthy = healthy self.sem = AsyncSemWithAdaptiveWeight(4) From 4d8e9b8086d510b56e4f77b9518274027562e19f Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Thu, 26 Feb 2026 21:48:56 +0300 Subject: [PATCH 16/18] fix: use sys.executable instead of hardcoded "python" in tests Tests that launch the API server via subprocess used a hardcoded "python" command which fails on systems where only "python3" is available (e.g. macOS). Using sys.executable ensures the same interpreter running pytest is used for subprocesses. Fixes 36 test errors on macOS environments. --- atroposlib/tests/api_test_utils.py | 3 ++- atroposlib/tests/test_api_compression.py | 3 ++- atroposlib/tests/test_api_messages_handling.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/atroposlib/tests/api_test_utils.py b/atroposlib/tests/api_test_utils.py index c88d51f2..200ada8c 100644 --- a/atroposlib/tests/api_test_utils.py +++ b/atroposlib/tests/api_test_utils.py @@ -1,4 +1,5 @@ import subprocess +import sys import time import requests @@ -16,7 +17,7 @@ def launch_api_for_testing(max_wait_for_api: int = 10) -> subprocess.Popen: # Use subprocess instead of multiprocessing to avoid inheriting pytest args api_proc = subprocess.Popen( [ - "python", + sys.executable, "-m", "atroposlib.cli.run_api", "--host", diff --git a/atroposlib/tests/test_api_compression.py b/atroposlib/tests/test_api_compression.py index 25cbdfe0..0202c3ad 100644 --- a/atroposlib/tests/test_api_compression.py +++ b/atroposlib/tests/test_api_compression.py @@ -5,6 +5,7 @@ import json import os import signal import subprocess +import sys import time import pytest @@ -27,7 +28,7 @@ def wait_for_api_server(max_wait=10): def api_server(): proc = subprocess.Popen( [ - "python", + sys.executable, "-m", "atroposlib.cli.run_api", "--host", diff --git a/atroposlib/tests/test_api_messages_handling.py b/atroposlib/tests/test_api_messages_handling.py index 0f7ac922..228be3ea 100644 --- a/atroposlib/tests/test_api_messages_handling.py +++ b/atroposlib/tests/test_api_messages_handling.py @@ -5,6 +5,7 @@ Tests for API server message handling, particularly for SFT (Supervised Fine-Tun import os import signal import subprocess +import sys import time import pytest @@ -30,7 +31,7 @@ def api_server(): # Start the API server as a subprocess proc = subprocess.Popen( [ - "python", + sys.executable, "-m", "atroposlib.cli.run_api", "--host", From eb500993616f960df19e751ba243d63bc7e55c82 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 5 Mar 2026 17:01:52 -0500 Subject: [PATCH 17/18] test_get_logprobs_input_ids_only_passthrough --- atroposlib/tests/test_managed_server.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index 9d7221e7..624b46cd 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -316,6 +316,30 @@ async def test_get_logprobs_messages_passthrough(mock_server): assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens) +@pytest.mark.asyncio +async def test_get_logprobs_input_ids_only_passthrough(mock_server): + """ManagedServer.get_logprobs supports input_ids-only without requiring prompt.""" + managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) + input_ids = [10, 20, 30] + + async def _mock_get_logprobs(**kwargs): + assert "input_ids" in kwargs + assert kwargs["input_ids"] == input_ids + assert kwargs.get("prompt") is None + return { + "prompt_tokens": input_ids, + "prompt_topk_token_ids": [[t] for t in input_ids], + "prompt_topk_logprobs": [[-0.1] for _ in input_ids], + } + + mock_server.get_logprobs = _mock_get_logprobs + payload = await managed.get_logprobs(input_ids=input_ids, top_k=1) + + assert payload["prompt_tokens"] == input_ids + assert payload["prompt_topk_token_ids"] == [[10], [20], [30]] + assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]] + + @pytest.mark.asyncio async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server): """ManagedServer.get_logprobs requires backend get_logprobs in strict mode.""" From 880bb4a632a72c33553d11a70b42020dae8b9aeb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 16:44:14 +0000 Subject: [PATCH 18/18] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black-pre-commit-mirror: 26.1.0 → 26.3.0](https://github.com/psf/black-pre-commit-mirror/compare/26.1.0...26.3.0) - [github.com/astral-sh/ruff-pre-commit: v0.15.4 → v0.15.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.4...v0.15.5) - [github.com/codespell-project/codespell: v2.4.1 → v2.4.2](https://github.com/codespell-project/codespell/compare/v2.4.1...v2.4.2) --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 24c59eef..10fee35f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,12 +9,12 @@ repos: - id: check-merge-conflict - repo: https://github.com/psf/black-pre-commit-mirror - rev: 26.1.0 + rev: 26.3.0 hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.4 + rev: v0.15.5 hooks: - id: ruff args: ["--select=I", "--fix"] # Only run import-related rules @@ -32,7 +32,7 @@ repos: args: ['--baseline', '.secrets.baseline'] - repo: https://github.com/codespell-project/codespell - rev: v2.4.1 + rev: v2.4.2 hooks: - id: codespell args: ["--skip", "*.csv,*.html", "-L", "te,ans,sems,lsat,anc,strokin,lod,nam,ques,unparseable,rouge,oll,managin,expressio,re-declare"]