mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
init commit
This commit is contained in:
parent
887a94374c
commit
b9291aa29f
5 changed files with 357 additions and 0 deletions
|
|
@ -529,6 +529,78 @@ class ManagedServer:
|
|||
else:
|
||||
self.current_nodes.clear()
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Fetch logprobs via the wrapped server using a normalized, trainer-agnostic schema.

    Supported inputs (via keyword arguments):
        - prompt
        - messages (converted to a prompt string first)
        - input_ids

    Returns:
        Dict containing:
            - prompt_tokens
            - sequence_token_ids
            - sequence_logprobs
            - sequence_topk_token_ids
            - sequence_topk_logprobs
            - finish_reasons
    """
    params = dict(kwargs)
    chat_messages = params.pop("messages", None)

    # Resolve the effective prompt: either derived from chat messages or
    # taken directly from the keyword arguments.
    if chat_messages is None:
        prompt = params.get("prompt")
    else:
        prompt = self._convert_messages_to_prompt(chat_messages)
        params["prompt"] = prompt

    # Reuse tracked context in non-tree mode when possible.
    can_reuse_context = (
        not self.track_tree
        and self.tokenizer is not None
        and "input_ids" not in params
        and prompt is not None
    )
    if can_reuse_context:
        node = self._find_extending_node(prompt)
        params["input_ids"] = self._compute_input_ids(prompt, node)

    if hasattr(self.server, "get_logprobs"):
        payload = await self.server.get_logprobs(**params)
    else:
        # Backwards-compatible fallback for harness/test doubles that only
        # expose the older tokens-and-logprobs interface.
        (
            prompt_tokens,
            token_rows,
            logprob_rows,
            finish_reasons,
        ) = await self.server.tokens_and_logprobs_completion(**params)
        payload = {
            "prompt_tokens": prompt_tokens,
            "sequence_token_ids": token_rows,
            "sequence_logprobs": logprob_rows,
            "sequence_topk_token_ids": [
                [[tok] for tok in row] for row in token_rows
            ],
            "sequence_topk_logprobs": [
                [[lp] for lp in row] for row in logprob_rows
            ],
            "finish_reasons": finish_reasons,
        }

    # Normalize required keys if the provider omitted the top-k arrays:
    # synthesize a degenerate top-1 list per position from the base arrays.
    topk_specs = (
        ("sequence_topk_token_ids", "sequence_token_ids"),
        ("sequence_topk_logprobs", "sequence_logprobs"),
    )
    for topk_key, base_key in topk_specs:
        if topk_key not in payload:
            payload[topk_key] = [
                [[value] for value in row] for row in payload[base_key]
            ]

    return payload
|
||||
|
||||
|
||||
class DummyManagedServer:
|
||||
"""
|
||||
|
|
@ -640,6 +712,25 @@ class DummyManagedServer:
|
|||
else:
|
||||
self.current_nodes.clear()
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Return an interface-compatible dummy logprob payload.

    Keeps interface parity with ManagedServer while making it explicit that
    the results are placeholders and not suitable for training.
    """
    num_sequences = int(kwargs.get("n", 1))
    # Fresh copies per sequence so callers mutating one row never affect another.
    token_rows = [list(self.DUMMY_TOKENS) for _ in range(num_sequences)]
    logprob_rows = [list(self.DUMMY_LOGPROBS) for _ in range(num_sequences)]
    payload = {
        "prompt_tokens": [],
        "sequence_token_ids": token_rows,
        "sequence_logprobs": logprob_rows,
        # Degenerate top-1 lists: each position's "top-k" is just itself.
        "sequence_topk_token_ids": [[[tok] for tok in row] for row in token_rows],
        "sequence_topk_logprobs": [[[lp] for lp in row] for row in logprob_rows],
        "finish_reasons": ["stop"] * num_sequences,
    }
    return payload
|
||||
|
||||
|
||||
class ManagedServerAdapter:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue