From 439b9b129b50636c6082d81660c81357e8b7fa7f Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 3 Mar 2026 21:56:11 -0500 Subject: [PATCH] prompt logprobs --- .../envs/server_handling/MANAGED_SERVER.md | 12 ++-- atroposlib/envs/server_handling/README.md | 8 +-- .../envs/server_handling/managed_server.py | 53 ++++++---------- .../envs/server_handling/server_baseline.py | 35 ++++------- .../envs/server_handling/server_manager.py | 31 ++------- .../envs/server_handling/vllm_server.py | 63 +++++++------------ atroposlib/tests/test_managed_server.py | 9 +-- 7 files changed, 73 insertions(+), 138 deletions(-) diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md index 148444b1..6b739a83 100644 --- a/atroposlib/envs/server_handling/MANAGED_SERVER.md +++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md @@ -622,17 +622,15 @@ Fetch logprobs with a normalized schema that is backend-agnostic. ```python { "prompt_tokens": List[int], - "sequence_token_ids": List[List[int]], # [seq][pos] - "sequence_logprobs": List[List[float]], # [seq][pos] - "sequence_topk_token_ids": List[List[List[int]]], # [seq][pos][k] - "sequence_topk_logprobs": List[List[List[float]]], # [seq][pos][k] - "finish_reasons": List[Any], + "prompt_topk_token_ids": List[List[int]], # [pos][k] + "prompt_topk_logprobs": List[List[float]], # [pos][k] + "finish_reasons": List[Any], # optional compatibility field } ``` **Notes:** -- If the backend only returns sampled-token logprobs, ManagedServer synthesizes `k=1` singleton top-k arrays. -- This method is for transport/interface consistency; richer top-k depends on backend support. +- If the backend only returns sampled-token logprobs, ManagedServer synthesizes `k=1` singleton prompt top-k arrays. +- This method is prompt-focused; richer top-k depends on backend support. #### `def get_state() -> Dict[str, Any]` Get the current state of tracked sequences. diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md index 23d6b5d0..f19ad272 100644 --- a/atroposlib/envs/server_handling/README.md +++ b/atroposlib/envs/server_handling/README.md @@ -13,13 +13,11 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_ `ManagedServer` and server backends now expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends: - `prompt_tokens` -- `sequence_token_ids` -- `sequence_logprobs` -- `sequence_topk_token_ids` -- `sequence_topk_logprobs` +- `prompt_topk_token_ids` +- `prompt_topk_logprobs` - `finish_reasons` -For backends that only expose sampled-token logprobs, top-k arrays are synthesized with `k=1` for interface compatibility. +For backends that only expose sampled-token logprobs, prompt top-k arrays are synthesized with `k=1` for interface compatibility. ## Reasoning Model Support diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 38b94489..f2231ab7 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -531,7 +531,7 @@ class ManagedServer: async def get_logprobs(self, **kwargs) -> Dict[str, Any]: """ - Fetch logprobs via wrapped server with a normalized trainer-agnostic schema. + Fetch prompt logprobs via wrapped server with a normalized schema. 
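+
+        Example (illustrative sketch; `top_logprobs` support depends on the
+        wrapped backend):
+
+            payload = await managed.get_logprobs(prompt="Hello", top_logprobs=4)
+            # One top-k list per prompt position:
+            assert len(payload["prompt_topk_logprobs"]) == len(payload["prompt_tokens"])
+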
         Supported inputs:
         - prompt
@@ -541,11 +541,8 @@ class ManagedServer:
         Returns:
             Dict with:
             - prompt_tokens
-            - sequence_token_ids
-            - sequence_logprobs
-            - sequence_topk_token_ids
-            - sequence_topk_logprobs
-            - finish_reasons
+            - prompt_topk_token_ids
+            - prompt_topk_logprobs
         """
         request_kwargs = kwargs.copy()
         messages = request_kwargs.pop("messages", None)
@@ -574,31 +571,24 @@ class ManagedServer:
             # Backwards-compatible fallback for harness/test doubles.
             (
                 prompt_tokens,
-                output_tokens_list,
-                output_logprobs_list,
-                finish_reasons,
+                _output_tokens_list,
+                _output_logprobs_list,
+                _finish_reasons,
             ) = await self.server.tokens_and_logprobs_completion(**request_kwargs)
             payload = {
                 "prompt_tokens": prompt_tokens,
-                "sequence_token_ids": output_tokens_list,
-                "sequence_logprobs": output_logprobs_list,
-                "sequence_topk_token_ids": [
-                    [[tok] for tok in seq] for seq in output_tokens_list
-                ],
-                "sequence_topk_logprobs": [
-                    [[lp] for lp in seq] for seq in output_logprobs_list
-                ],
-                "finish_reasons": finish_reasons,
+                "prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
+                "prompt_topk_logprobs": [[1.0] for _ in prompt_tokens],  # placeholders
             }

-        # Normalize required keys if provider omitted top-k arrays.
-        if "sequence_topk_token_ids" not in payload:
-            payload["sequence_topk_token_ids"] = [
-                [[tok] for tok in seq] for seq in payload["sequence_token_ids"]
+        # Normalize required keys if provider omitted prompt top-k arrays.
+        if "prompt_topk_token_ids" not in payload:
+            payload["prompt_topk_token_ids"] = [
+                [tok] for tok in payload.get("prompt_tokens", [])
             ]
-        if "sequence_topk_logprobs" not in payload:
-            payload["sequence_topk_logprobs"] = [
-                [[lp] for lp in seq] for seq in payload["sequence_logprobs"]
+        if "prompt_topk_logprobs" not in payload:
+            payload["prompt_topk_logprobs"] = [
+                [1.0] for _ in payload.get("prompt_tokens", [])
             ]
         return payload
@@ -722,15 +712,12 @@ class DummyManagedServer:
         that results are placeholders and not suitable for training.
         """
         n = int(kwargs.get("n", 1))
-        seq_ids = [self.DUMMY_TOKENS[:] for _ in range(n)]
-        seq_lps = [self.DUMMY_LOGPROBS[:] for _ in range(n)]
+        prompt_tokens = self.DUMMY_TOKENS[:]
         return {
-            "prompt_tokens": [],
-            "sequence_token_ids": seq_ids,
-            "sequence_logprobs": seq_lps,
-            "sequence_topk_token_ids": [[[tok] for tok in seq] for seq in seq_ids],
-            "sequence_topk_logprobs": [[[lp] for lp in seq] for seq in seq_lps],
-            "finish_reasons": ["stop"] * n,
+            "prompt_tokens": prompt_tokens,
+            "prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
+            "prompt_topk_logprobs": [[self.DUMMY_LOGPROBS[0]] for _ in prompt_tokens],
+            "finish_reasons": ["stop"] * n,  # Retained for compatibility in callers.
         }


diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py
index 281f4fbd..c387f5dd 100644
--- a/atroposlib/envs/server_handling/server_baseline.py
+++ b/atroposlib/envs/server_handling/server_baseline.py
@@ -641,40 +641,31 @@ class APIServer(ABC):

     async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
         """
-        Trainer-agnostic logprob API with normalized output schema.
+        Trainer-agnostic prompt-logprob API with normalized output schema.

         This default implementation is built from `tokens_and_logprobs_completion`
-        and returns sampled-token logprobs (top-k singleton per position).
+        and synthesizes singleton (k=1) top-k values for each prompt token.
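+
+        Shape sketch (illustrative, for a 3-token prompt):
+
+            payload["prompt_topk_token_ids"]  # -> [[t0], [t1], [t2]]
+            payload["prompt_topk_logprobs"]   # -> [[1.0], [1.0], [1.0]] (placeholders)
+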
         Returns:
             Dict with:
             - prompt_tokens: List[int]
-            - sequence_token_ids: List[List[int]]
-            - sequence_logprobs: List[List[float]]
-            - sequence_topk_token_ids: List[List[List[int]]]
-            - sequence_topk_logprobs: List[List[List[float]]]
-            - finish_reasons: List[Any]
+            - prompt_topk_token_ids: List[List[int]]
+            - prompt_topk_logprobs: List[List[float]]
         """
         (
             prompt_tokens,
-            output_tokens_list,
-            output_logprobs_list,
-            finish_reasons,
+            _output_tokens_list,
+            _output_logprobs_list,
+            _finish_reasons,
         ) = await self.tokens_and_logprobs_completion(**kwargs)

-        topk_token_ids = [
-            [[token_id] for token_id in seq_tokens] for seq_tokens in output_tokens_list
-        ]
-        topk_logprobs = [
-            [[logprob] for logprob in seq_logprobs]
-            for seq_logprobs in output_logprobs_list
-        ]
+        # The fallback path has no true prompt logprobs, so we provide
+        # interface-compatible singleton placeholders for each prompt token
+        # (1.0 is recognizably synthetic, since real logprobs are always <= 0).
+        prompt_topk_token_ids = [[token_id] for token_id in prompt_tokens]
+        prompt_topk_logprobs = [[1.0] for _ in prompt_tokens]

         return {
             "prompt_tokens": prompt_tokens,
-            "sequence_token_ids": output_tokens_list,
-            "sequence_logprobs": output_logprobs_list,
-            "sequence_topk_token_ids": topk_token_ids,
-            "sequence_topk_logprobs": topk_logprobs,
-            "finish_reasons": finish_reasons,
+            "prompt_topk_token_ids": prompt_topk_token_ids,
+            "prompt_topk_logprobs": prompt_topk_logprobs,
         }


diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py
index 6d0a90d5..4c548aac 100644
--- a/atroposlib/envs/server_handling/server_manager.py
+++ b/atroposlib/envs/server_handling/server_manager.py
@@ -348,19 +348,16 @@ class ServerManager:

     async def get_logprobs(self, **kwargs) -> dict:
         """
-        Route normalized get_logprobs requests to the most available server.
+        Route normalized prompt-logprob requests to the most available server.

        Returns a normalized dict with:
         - prompt_tokens
-        - sequence_token_ids
-        - sequence_logprobs
-        - sequence_topk_token_ids
-        - sequence_topk_logprobs
-        - finish_reasons
+        - prompt_topk_token_ids
+        - prompt_topk_logprobs
         """
         n = kwargs.get("n", 1)
         if n > self.max_n_completions:
-            # Split into multiple requests and merge sequence-level outputs.
+            # Prompt logprobs depend only on the prompt; splitting on n does not change them.
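+            # All sub-requests would return identical prompt arrays for the same
+            # prompt, so only the first gathered result is returned below.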
results = [] total_n = n while total_n > 0: @@ -369,25 +366,7 @@ class ServerManager: results.append(self.get_logprobs(**kwargs)) total_n -= n_to_use results = await asyncio.gather(*results) - merged = { - "prompt_tokens": results[0]["prompt_tokens"], - "sequence_token_ids": [], - "sequence_logprobs": [], - "sequence_topk_token_ids": [], - "sequence_topk_logprobs": [], - "finish_reasons": [], - } - for result in results: - merged["sequence_token_ids"].extend(result["sequence_token_ids"]) - merged["sequence_logprobs"].extend(result["sequence_logprobs"]) - merged["sequence_topk_token_ids"].extend( - result["sequence_topk_token_ids"] - ) - merged["sequence_topk_logprobs"].extend( - result["sequence_topk_logprobs"] - ) - merged["finish_reasons"].extend(result["finish_reasons"]) - return merged + return results[0] is_train = kwargs.pop("split", "train") == "train" most_available_server = 0 diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 11ba87e8..e0091ead 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -262,7 +262,7 @@ class VLLMServer(APIServer): async def get_logprobs(self, **kwargs) -> Dict[str, Any]: """ - Fetch normalized logprobs from vLLM /generate with optional top-k. + Fetch normalized prompt logprobs from vLLM /generate with optional top-k. Args: top_k / top_logprobs: Optional number of logprobs per position. @@ -272,11 +272,8 @@ class VLLMServer(APIServer): Returns: Normalized dict: - prompt_tokens - - sequence_token_ids - - sequence_logprobs - - sequence_topk_token_ids - - sequence_topk_logprobs - - finish_reasons + - prompt_topk_token_ids + - prompt_topk_logprobs """ assert ( kwargs.get("prompt", None) is not None @@ -306,10 +303,8 @@ class VLLMServer(APIServer): kwargs["max_tokens"] = kwargs.pop("max_completion_tokens") kwargs.pop("model", None) - request_data = { - "prompt": {"prompt_token_ids": prompt_tokens}, - "logprobs": top_k, - } + request_data = {"prompt": {"prompt_token_ids": prompt_tokens}} + request_data["prompt_logprobs"] = top_k request_data.update(kwargs) # Keep semaphore behavior consistent with other server calls. @@ -333,40 +328,30 @@ class VLLMServer(APIServer): response.raise_for_status() results = await response.json() - sequence_topk_token_ids: List[List[List[int]]] = [] - sequence_topk_logprobs: List[List[List[float]]] = [] - sequence_token_ids: List[List[int]] = [] - sequence_logprobs: List[List[float]] = [] - finish_reasons: List[Any] = [] + raw_prompt_logprobs = results.get("prompt_logprobs") + if raw_prompt_logprobs is None: + raise ValueError( + "vLLM /generate response missing 'prompt_logprobs'. " + "Ensure backend supports prompt logprobs." + ) - for token_logprobs_seq, finish_reason in zip( - results["logprobs"], results["finish_reasons"] - ): - seq_topk_token_ids: List[List[int]] = [] - seq_topk_logprobs: List[List[float]] = [] - seq_token_ids: List[int] = [] - seq_logprobs: List[float] = [] + # Handle either direct [position] payloads or [sequence][position] payloads. 
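+        # NOTE: vLLM typically reports None at prompt position 0 (it has no
+        # preceding context), so each entry is guarded before normalization below.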
+        if raw_prompt_logprobs and isinstance(raw_prompt_logprobs[0], list):
+            prompt_entries = raw_prompt_logprobs[0]
+        else:
+            prompt_entries = raw_prompt_logprobs

-            for token_logprobs_entry in token_logprobs_seq:
-                topk_ids, topk_lps = self._normalize_topk_entry(token_logprobs_entry)
-                seq_topk_token_ids.append(topk_ids)
-                seq_topk_logprobs.append(topk_lps)
-                seq_token_ids.append(topk_ids[0] if topk_ids else -1)
-                seq_logprobs.append(topk_lps[0] if topk_lps else 0.0)
-
-            sequence_topk_token_ids.append(seq_topk_token_ids)
-            sequence_topk_logprobs.append(seq_topk_logprobs)
-            sequence_token_ids.append(seq_token_ids)
-            sequence_logprobs.append(seq_logprobs)
-            finish_reasons.append(finish_reason)
+        prompt_topk_token_ids: List[List[int]] = []
+        prompt_topk_logprobs: List[List[float]] = []
+        for entry in prompt_entries:
+            topk_ids, topk_lps = self._normalize_topk_entry(entry) if entry else ([], [])
+            prompt_topk_token_ids.append(topk_ids)
+            prompt_topk_logprobs.append(topk_lps)

         return {
             "prompt_tokens": prompt_tokens,
-            "sequence_token_ids": sequence_token_ids,
-            "sequence_logprobs": sequence_logprobs,
-            "sequence_topk_token_ids": sequence_topk_token_ids,
-            "sequence_topk_logprobs": sequence_topk_logprobs,
-            "finish_reasons": finish_reasons,
+            "prompt_topk_token_ids": prompt_topk_token_ids,
+            "prompt_topk_logprobs": prompt_topk_logprobs,
         }


diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py
index d8ea2aed..ce644a05 100644
--- a/atroposlib/tests/test_managed_server.py
+++ b/atroposlib/tests/test_managed_server.py
@@ -267,7 +267,7 @@ async def test_bos_token_handling(mock_server):

 @pytest.mark.asyncio
 async def test_get_logprobs_normalized_schema(mock_server):
-    """ManagedServer.get_logprobs returns normalized schema."""
+    """ManagedServer.get_logprobs returns normalized prompt schema."""
     managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

     prompt = "Hello"
@@ -286,11 +286,8 @@
     payload = await managed.get_logprobs(prompt=prompt, n=1)

     assert payload["prompt_tokens"] == prompt_tokens
-    assert payload["sequence_token_ids"] == output_tokens
-    assert payload["sequence_logprobs"] == output_logprobs
-    assert payload["finish_reasons"] == ["stop"]
-    assert payload["sequence_topk_token_ids"] == [[[ord("!")], [ord("?")]]]
-    assert payload["sequence_topk_logprobs"] == [[[-0.1], [-0.2]]]
+    assert payload["prompt_topk_token_ids"] == [[tok] for tok in prompt_tokens]
+    assert payload["prompt_topk_logprobs"] == [[1.0] for _ in prompt_tokens]  # placeholders


 @pytest.mark.asyncio