Initial commit

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 11:32:09 -05:00
parent 887a94374c
commit b9291aa29f
5 changed files with 357 additions and 0 deletions

View file

@ -529,6 +529,78 @@ class ManagedServer:
else:
self.current_nodes.clear()
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Fetch logprobs via the wrapped server with a normalized trainer-agnostic schema.

    Supported inputs (via **kwargs):
        - prompt: raw prompt string
        - messages: chat messages; converted to a prompt (overrides ``prompt``)
        - input_ids: pre-tokenized prompt (skips local tokenization)

    Returns:
        Dict with keys:
            - prompt_tokens
            - sequence_token_ids
            - sequence_logprobs
            - sequence_topk_token_ids
            - sequence_topk_logprobs
            - finish_reasons
    """
    request_kwargs = kwargs.copy()
    messages = request_kwargs.pop("messages", None)
    if messages is not None:
        prompt = self._convert_messages_to_prompt(messages)
        request_kwargs["prompt"] = prompt
    else:
        prompt = request_kwargs.get("prompt")
    # Reuse tracked context in non-tree mode when possible.
    if (
        not self.track_tree
        and self.tokenizer is not None
        and "input_ids" not in request_kwargs
        and prompt is not None
    ):
        extending_node = self._find_extending_node(prompt)
        request_kwargs["input_ids"] = self._compute_input_ids(prompt, extending_node)
    if hasattr(self.server, "get_logprobs"):
        payload = await self.server.get_logprobs(**request_kwargs)
    else:
        # Backwards-compatible fallback for harness/test doubles. Top-k arrays
        # are intentionally omitted here; the shared normalization step below
        # synthesizes them, so the top-1 construction lives in exactly one place.
        (
            prompt_tokens,
            output_tokens_list,
            output_logprobs_list,
            finish_reasons,
        ) = await self.server.tokens_and_logprobs_completion(**request_kwargs)
        payload = {
            "prompt_tokens": prompt_tokens,
            "sequence_token_ids": output_tokens_list,
            "sequence_logprobs": output_logprobs_list,
            "finish_reasons": finish_reasons,
        }
    # Normalize required keys if the provider omitted top-k arrays: degrade to
    # top-1 (the sampled token and its own logprob) so downstream consumers can
    # always index the top-k structures uniformly.
    if "sequence_topk_token_ids" not in payload:
        payload["sequence_topk_token_ids"] = [
            [[tok] for tok in seq] for seq in payload["sequence_token_ids"]
        ]
    if "sequence_topk_logprobs" not in payload:
        payload["sequence_topk_logprobs"] = [
            [[lp] for lp in seq] for seq in payload["sequence_logprobs"]
        ]
    return payload
class DummyManagedServer:
"""
@ -640,6 +712,25 @@ class DummyManagedServer:
else:
self.current_nodes.clear()
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Produce a placeholder logprob payload with the ManagedServer schema.

    Values are drawn from DUMMY_TOKENS / DUMMY_LOGPROBS and are explicitly
    not suitable for training; this method exists for interface parity only.
    """
    num_sequences = int(kwargs.get("n", 1))
    token_rows = []
    logprob_rows = []
    for _ in range(num_sequences):
        # Copy the class-level dummies so callers may mutate rows safely.
        token_rows.append(list(self.DUMMY_TOKENS))
        logprob_rows.append(list(self.DUMMY_LOGPROBS))
    # Degenerate top-1 "top-k" arrays: each position carries its own token/logprob.
    topk_token_rows = [[[token] for token in row] for row in token_rows]
    topk_logprob_rows = [[[logprob] for logprob in row] for row in logprob_rows]
    return {
        "prompt_tokens": [],
        "sequence_token_ids": token_rows,
        "sequence_logprobs": logprob_rows,
        "sequence_topk_token_ids": topk_token_rows,
        "sequence_topk_logprobs": topk_logprob_rows,
        "finish_reasons": ["stop"] * num_sequences,
    }
class ManagedServerAdapter:
"""