mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fixing comments
This commit is contained in:
parent
51088ac24d
commit
1eeb31065f
6 changed files with 21 additions and 81 deletions
|
|
@@ -36,6 +36,7 @@ async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
|||
|
||||
- ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions
|
||||
- ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0
|
||||
- ✅ **Perfect Alignment**: Tokens and logprobs align positionally for tracked sequences
|
||||
- ✅ **Normalized Alignment Contract**: Tokens/logprobs are shape-normalized for downstream consumers
|
||||
- ✅ **Multi-turn Support**: Automatically handles conversation extensions
|
||||
- ✅ **Branching Support**: Handles n>1 completions naturally
|
||||
|
|
@@ -628,8 +629,8 @@ Fetch logprobs with a normalized schema that is backend-agnostic.
|
|||
```
|
||||
|
||||
**Notes:**
|
||||
- Strict mode: backend must provide real prompt top-k arrays with aligned shapes.
|
||||
- Missing fields or shape mismatches fail fast with explicit errors.
|
||||
- Strict mode: backend must provide real prompt top-k arrays.
|
||||
- Missing keys should be treated as backend contract violations.
|
||||
|
||||
#### `def get_state() -> Dict[str, Any]`
|
||||
Get the current state of tracked sequences.
|
||||
|
|
@@ -736,7 +737,7 @@ With this flag set, `managed_server()` will return a `DummyManagedServer` that:
|
|||
- Provides the same interface as `ManagedServer`
|
||||
- Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays)
|
||||
- Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
|
||||
- Implements `get_logprobs(...)` with the same normalized keys as `ManagedServer`, but placeholder values (not suitable for training)
|
||||
- Raises for `get_logprobs(...)` in strict mode (no fake prompt-logprob payload)
|
||||
|
||||
### When to Use DummyManagedServer
|
||||
|
||||
|
|
@@ -769,8 +770,8 @@ async with self.server.managed_server() as managed:
|
|||
print(node.tokens[:5]) # placeholder values
|
||||
print(node.logprobs[:5]) # placeholder values
|
||||
|
||||
payload = await managed.get_logprobs(messages=messages, n=4)
|
||||
print(payload.keys()) # normalized schema keys
|
||||
# Strict mode: get_logprobs is not available on DummyManagedServer
|
||||
# and will raise NotImplementedError.
|
||||
```
|
||||
|
||||
### Recommendation
|
||||
|
|
|
|||
|
|
@@ -10,13 +10,13 @@ For automatic token and logprob tracking, see the [ManagedServer Guide](MANAGED_
|
|||
|
||||
### Normalized `get_logprobs` API
|
||||
|
||||
`ManagedServer` and server backends now expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends:
|
||||
`ManagedServer` and server backends expose a normalized `get_logprobs(...)` interface so callers can consume a single schema across backends:
|
||||
|
||||
- `prompt_tokens`
|
||||
- `prompt_topk_token_ids`
|
||||
- `prompt_topk_logprobs`
|
||||
|
||||
Strict mode: backends must return real prompt top-k arrays. Missing keys or malformed shapes fail fast.
|
||||
Backends must return real prompt top-k arrays. Missing keys or malformed shapes fail fast.
|
||||
|
||||
## Reasoning Model Support
|
||||
|
||||
|
|
|
|||
|
|
@@ -553,18 +553,6 @@ class ManagedServer:
|
|||
else:
|
||||
prompt = request_kwargs.get("prompt")
|
||||
|
||||
# Reuse tracked context in non-tree mode when possible.
|
||||
if (
|
||||
not self.track_tree
|
||||
and self.tokenizer is not None
|
||||
and "input_ids" not in request_kwargs
|
||||
and prompt is not None
|
||||
):
|
||||
extending_node = self._find_extending_node(prompt)
|
||||
request_kwargs["input_ids"] = self._compute_input_ids(
|
||||
prompt, extending_node
|
||||
)
|
||||
|
||||
if not hasattr(self.server, "get_logprobs"):
|
||||
raise NotImplementedError(
|
||||
f"{self.server.__class__.__name__} does not implement get_logprobs. "
|
||||
|
|
@@ -572,41 +560,8 @@ class ManagedServer:
|
|||
)
|
||||
|
||||
payload = await self.server.get_logprobs(**request_kwargs)
|
||||
self._validate_prompt_logprob_payload(payload)
|
||||
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _validate_prompt_logprob_payload(payload: Dict[str, Any]) -> None:
|
||||
required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs")
|
||||
missing = [k for k in required if k not in payload]
|
||||
if missing:
|
||||
raise ValueError(f"get_logprobs response missing required keys: {missing}")
|
||||
|
||||
prompt_tokens = payload["prompt_tokens"]
|
||||
token_ids = payload["prompt_topk_token_ids"]
|
||||
logprobs = payload["prompt_topk_logprobs"]
|
||||
|
||||
if not isinstance(prompt_tokens, list):
|
||||
raise ValueError("prompt_tokens must be a list[int].")
|
||||
if not isinstance(token_ids, list) or not isinstance(logprobs, list):
|
||||
raise ValueError(
|
||||
"prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
|
||||
)
|
||||
if len(token_ids) != len(prompt_tokens) or len(logprobs) != len(prompt_tokens):
|
||||
raise ValueError("prompt_topk arrays must align with prompt_tokens length.")
|
||||
|
||||
for idx, (tok_row, lp_row) in enumerate(zip(token_ids, logprobs)):
|
||||
if not isinstance(tok_row, list) or not isinstance(lp_row, list):
|
||||
raise ValueError(
|
||||
"prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
|
||||
)
|
||||
if len(tok_row) != len(lp_row):
|
||||
raise ValueError(
|
||||
f"prompt_topk row mismatch at position {idx}: "
|
||||
f"{len(tok_row)} token ids vs {len(lp_row)} logprobs."
|
||||
)
|
||||
|
||||
|
||||
class DummyManagedServer:
|
||||
"""
|
||||
|
|
|
|||
|
|
@@ -355,19 +355,6 @@ class ServerManager:
|
|||
- prompt_topk_token_ids
|
||||
- prompt_topk_logprobs
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
# Prompt logprobs are prompt-level; n-splitting does not change prompt arrays.
|
||||
results = []
|
||||
total_n = n
|
||||
while total_n > 0:
|
||||
n_to_use = min(total_n, self.max_n_completions)
|
||||
kwargs["n"] = n_to_use
|
||||
results.append(self.get_logprobs(**kwargs))
|
||||
total_n -= n_to_use
|
||||
results = await asyncio.gather(*results)
|
||||
return results[0]
|
||||
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
|
|
|
|||
|
|
@@ -284,14 +284,18 @@ class VLLMServer(APIServer):
|
|||
top_k = max(1, top_k)
|
||||
|
||||
# Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
|
||||
from_prompt_text = False
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
kwargs.pop("prompt", None)
|
||||
else:
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
from_prompt_text = True
|
||||
|
||||
# Check for double BOS token.
|
||||
# Only normalize BOS for tokenizer-encoded prompt text.
|
||||
if (
|
||||
from_prompt_text
|
||||
and
|
||||
len(prompt_tokens) >= 2
|
||||
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
|
||||
):
|
||||
|
|
@@ -306,6 +310,11 @@ class VLLMServer(APIServer):
|
|||
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}}
|
||||
request_data["prompt_logprobs"] = top_k
|
||||
request_data.update(kwargs)
|
||||
# This API is prompt-logprobs focused, not generation-focused.
|
||||
request_data["n"] = 1
|
||||
request_data["temperature"] = 0.0
|
||||
request_data["top_p"] = 1.0
|
||||
request_data.setdefault("max_tokens", 1)
|
||||
|
||||
# Keep semaphore behavior consistent with other server calls.
|
||||
split = request_data.pop("split", "train")
|
||||
|
|
|
|||
|
|
@@ -293,24 +293,12 @@ async def test_get_logprobs_normalized_schema(mock_server):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logprobs_strict_mode_rejects_misaligned_payload(mock_server):
|
||||
"""ManagedServer.get_logprobs fails fast on malformed prompt top-k payload."""
|
||||
async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server):
|
||||
"""ManagedServer.get_logprobs requires backend get_logprobs in strict mode."""
|
||||
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
|
||||
|
||||
prompt = "Hello"
|
||||
prompt_tokens = mock_server.tokenizer.encode(prompt)
|
||||
|
||||
async def _mock_get_logprobs(**kwargs):
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
|
||||
# Missing one row on purpose -> misaligned with prompt length
|
||||
"prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens[:-1]],
|
||||
}
|
||||
|
||||
mock_server.get_logprobs = _mock_get_logprobs
|
||||
|
||||
with pytest.raises(ValueError, match="align with prompt_tokens length"):
|
||||
with pytest.raises(NotImplementedError, match="does not implement get_logprobs"):
|
||||
await managed.get_logprobs(prompt=prompt, n=1)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue