mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
prompt logprobs
This commit is contained in:
parent
e98100e5f6
commit
439b9b129b
7 changed files with 73 additions and 138 deletions
|
|
@ -531,7 +531,7 @@ class ManagedServer:
|
|||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch logprobs via wrapped server with a normalized trainer-agnostic schema.
|
||||
Fetch prompt logprobs via wrapped server with a normalized schema.
|
||||
|
||||
Supported inputs:
|
||||
- prompt
|
||||
|
|
@ -541,11 +541,8 @@ class ManagedServer:
|
|||
Returns:
|
||||
Dict with:
|
||||
- prompt_tokens
|
||||
- sequence_token_ids
|
||||
- sequence_logprobs
|
||||
- sequence_topk_token_ids
|
||||
- sequence_topk_logprobs
|
||||
- finish_reasons
|
||||
- prompt_topk_token_ids
|
||||
- prompt_topk_logprobs
|
||||
"""
|
||||
request_kwargs = kwargs.copy()
|
||||
messages = request_kwargs.pop("messages", None)
|
||||
|
|
@ -574,31 +571,24 @@ class ManagedServer:
|
|||
# Backwards-compatible fallback for harness/test doubles.
|
||||
(
|
||||
prompt_tokens,
|
||||
output_tokens_list,
|
||||
output_logprobs_list,
|
||||
finish_reasons,
|
||||
_output_tokens_list,
|
||||
_output_logprobs_list,
|
||||
_finish_reasons,
|
||||
) = await self.server.tokens_and_logprobs_completion(**request_kwargs)
|
||||
payload = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"sequence_token_ids": output_tokens_list,
|
||||
"sequence_logprobs": output_logprobs_list,
|
||||
"sequence_topk_token_ids": [
|
||||
[[tok] for tok in seq] for seq in output_tokens_list
|
||||
],
|
||||
"sequence_topk_logprobs": [
|
||||
[[lp] for lp in seq] for seq in output_logprobs_list
|
||||
],
|
||||
"finish_reasons": finish_reasons,
|
||||
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
|
||||
"prompt_topk_logprobs": [[1.0] for _ in prompt_tokens],
|
||||
}
|
||||
|
||||
# Normalize required keys if provider omitted top-k arrays.
|
||||
if "sequence_topk_token_ids" not in payload:
|
||||
payload["sequence_topk_token_ids"] = [
|
||||
[[tok] for tok in seq] for seq in payload["sequence_token_ids"]
|
||||
# Normalize required keys if provider omitted prompt top-k arrays.
|
||||
if "prompt_topk_token_ids" not in payload:
|
||||
payload["prompt_topk_token_ids"] = [
|
||||
[tok] for tok in payload.get("prompt_tokens", [])
|
||||
]
|
||||
if "sequence_topk_logprobs" not in payload:
|
||||
payload["sequence_topk_logprobs"] = [
|
||||
[[lp] for lp in seq] for seq in payload["sequence_logprobs"]
|
||||
if "prompt_topk_logprobs" not in payload:
|
||||
payload["prompt_topk_logprobs"] = [
|
||||
[1.0] for _ in payload.get("prompt_tokens", [])
|
||||
]
|
||||
|
||||
return payload
|
||||
|
|
@ -722,15 +712,12 @@ class DummyManagedServer:
|
|||
that results are placeholders and not suitable for training.
|
||||
"""
|
||||
n = int(kwargs.get("n", 1))
|
||||
seq_ids = [self.DUMMY_TOKENS[:] for _ in range(n)]
|
||||
seq_lps = [self.DUMMY_LOGPROBS[:] for _ in range(n)]
|
||||
prompt_tokens = self.DUMMY_TOKENS[:]
|
||||
return {
|
||||
"prompt_tokens": [],
|
||||
"sequence_token_ids": seq_ids,
|
||||
"sequence_logprobs": seq_lps,
|
||||
"sequence_topk_token_ids": [[[tok] for tok in seq] for seq in seq_ids],
|
||||
"sequence_topk_logprobs": [[[lp] for lp in seq] for seq in seq_lps],
|
||||
"finish_reasons": ["stop"] * n,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
|
||||
"prompt_topk_logprobs": [[self.DUMMY_LOGPROBS[0]] for _ in prompt_tokens],
|
||||
"finish_reasons": ["stop"] * n, # Retained for compatibility in callers.
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue