diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md
index 692f3a10..148444b1 100644
--- a/atroposlib/envs/server_handling/MANAGED_SERVER.md
+++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md
@@ -2,7 +2,7 @@
 ## Overview
 
-`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
+`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It also exposes a normalized `get_logprobs(...)` API for backend-agnostic logprob access. This eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
 
 **Server Compatibility:** ManagedServer works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
 
@@ -36,7 +36,7 @@
 - ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions
 - ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0
-- ✅ **Perfect Alignment**: Tokens and logprobs are guaranteed to align correctly
+- ✅ **Normalized Alignment Contract**: Tokens/logprobs are shape-normalized for downstream consumers
 - ✅ **Multi-turn Support**: Automatically handles conversation extensions
 - ✅ **Branching Support**: Handles n>1 completions naturally
 - ✅ **Clean API**: Simple context manager pattern
@@ -609,6 +609,31 @@ Intercept completion call and track sequences.
 
 **Side Effects:**
 - Tracks sequences in internal storage
 
+#### `async def get_logprobs(**kwargs) -> Dict[str, Any]`
+Fetch logprobs with a normalized schema that is backend-agnostic.
+
+**Args (common):**
+- `messages` or `prompt` or `input_ids`
+- `n`: Number of sampled sequences
+- `max_tokens`
+- Optional backend kwargs such as `top_k` / `top_logprobs`, `temperature`, `stop`
+
+**Returns (normalized):**
+```python
+{
+    "prompt_tokens": List[int],
+    "sequence_token_ids": List[List[int]],              # [seq][pos]
+    "sequence_logprobs": List[List[float]],             # [seq][pos]
+    "sequence_topk_token_ids": List[List[List[int]]],   # [seq][pos][k]
+    "sequence_topk_logprobs": List[List[List[float]]],  # [seq][pos][k]
+    "finish_reasons": List[Any],
+}
+```
+
+**Notes:**
+- If the backend only returns sampled-token logprobs, ManagedServer synthesizes `k=1` singleton top-k arrays.
+- This method is for transport/interface consistency; richer top-k depends on backend support.
+
 #### `def get_state() -> Dict[str, Any]`
 Get the current state of tracked sequences.
@@ -712,11 +737,9 @@ export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1
 ```
 
 With this flag set, `managed_server()` will return a `DummyManagedServer` that:
 - Provides the same interface as `ManagedServer`
-- Returns **fixed placeholder values** for tokens and logprobs:
-  - `tokens`: `[1, 2, 3]`
-  - `masked_tokens`: `[-100, 2, 3]`
-  - `logprobs`: `[-0.5, -0.5, -0.5]`
+- Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays)
 - Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
+- Implements `get_logprobs(...)` with the same normalized keys as `ManagedServer`, but placeholder values (not suitable for training)
 
 ### When to Use DummyManagedServer
@@ -746,8 +769,11 @@ async with self.server.managed_server() as managed:
     # nodes contain placeholder token data - DO NOT use for training
     for node in nodes:
         print(node.full_text)  # Real completion text
-        print(node.tokens)     # [1, 2, 3] - placeholder!
-        print(node.logprobs)   # [-0.5, -0.5, -0.5] - placeholder!
+        print(node.tokens[:5])    # placeholder values
+        print(node.logprobs[:5])  # placeholder values
+
+        payload = await managed.get_logprobs(messages=messages, n=4)
+        print(payload.keys())  # normalized schema keys
 ```
 
 ### Recommendation
diff --git a/atroposlib/envs/server_handling/README.md b/atroposlib/envs/server_handling/README.md
index 447b1127..23d6b5d0 100644
--- a/atroposlib/envs/server_handling/README.md
+++ b/atroposlib/envs/server_handling/README.md
@@ -8,6 +8,19 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_
 
 > **Note:** OpenAI endpoints do not support token IDs/logprobs required for ManagedServer. Set `ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1` to use a placeholder implementation for testing/evaluation. See [OpenAI Endpoint Limitations](MANAGED_SERVER.md#openai-endpoint-limitations) for details.
 
+### Normalized `get_logprobs` API
+
+`ManagedServer` and server backends now expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends:
+
+- `prompt_tokens`
+- `sequence_token_ids`
+- `sequence_logprobs`
+- `sequence_topk_token_ids`
+- `sequence_topk_logprobs`
+- `finish_reasons`
+
+For backends that only expose sampled-token logprobs, top-k arrays are synthesized with `k=1` for interface compatibility.
+
 ## Reasoning Model Support
 
 The `ReasoningConfig` class enables support for reasoning/thinking models across different providers.
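The `k=1` top-k synthesis this patch documents (for backends that only return sampled-token logprobs) can be sketched as below. This is an illustrative sketch, not the actual ManagedServer implementation; `normalize_logprobs_payload` is a hypothetical helper name, and only the returned keys follow the normalized schema from the diff.

```python
from typing import Any, Dict, List


def normalize_logprobs_payload(
    prompt_tokens: List[int],
    sequence_token_ids: List[List[int]],
    sequence_logprobs: List[List[float]],
    finish_reasons: List[Any],
) -> Dict[str, Any]:
    """Build the normalized schema from a backend that only returns
    sampled-token logprobs, synthesizing k=1 singleton top-k arrays."""
    # Each position's top-k list is just the sampled token / logprob itself.
    topk_ids = [[[tok] for tok in seq] for seq in sequence_token_ids]
    topk_lps = [[[lp] for lp in seq] for seq in sequence_logprobs]
    return {
        "prompt_tokens": prompt_tokens,
        "sequence_token_ids": sequence_token_ids,
        "sequence_logprobs": sequence_logprobs,
        "sequence_topk_token_ids": topk_ids,      # [seq][pos][k], k == 1
        "sequence_topk_logprobs": topk_lps,       # [seq][pos][k], k == 1
        "finish_reasons": finish_reasons,
    }


# Toy data standing in for two sampled sequences (n=2).
payload = normalize_logprobs_payload(
    prompt_tokens=[101, 2009],
    sequence_token_ids=[[7592, 2088], [7592, 999]],
    sequence_logprobs=[[-0.1, -0.7], [-0.2, -1.3]],
    finish_reasons=["stop", "stop"],
)

# Per-position tokens, logprobs, and top-k entries stay aligned by shape.
for seq_ids, seq_lps, tk_ids, tk_lps in zip(
    payload["sequence_token_ids"],
    payload["sequence_logprobs"],
    payload["sequence_topk_token_ids"],
    payload["sequence_topk_logprobs"],
):
    assert len(seq_ids) == len(seq_lps) == len(tk_ids) == len(tk_lps)
    assert all(len(k) == 1 for k in tk_ids)
```

A consumer written against this schema works unchanged whether the backend supplied real top-k data or only sampled-token logprobs, which is the point of the interface-compatibility note above.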