mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge conflict commit
This commit is contained in:
commit
f198c1738e
13 changed files with 579 additions and 14 deletions
|
|
@ -704,6 +704,39 @@ class ManagedServer:
|
|||
else:
|
||||
self.current_nodes.clear()
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch prompt logprobs via wrapped server with a normalized schema.
|
||||
|
||||
Supported inputs:
|
||||
- prompt
|
||||
- messages (converted to prompt)
|
||||
- input_ids
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- prompt_tokens
|
||||
- prompt_topk_token_ids
|
||||
- prompt_topk_logprobs
|
||||
"""
|
||||
request_kwargs = kwargs.copy()
|
||||
messages = request_kwargs.pop("messages", None)
|
||||
|
||||
if messages is not None:
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
request_kwargs["prompt"] = prompt
|
||||
else:
|
||||
prompt = request_kwargs.get("prompt")
|
||||
|
||||
if not hasattr(self.server, "get_logprobs"):
|
||||
raise NotImplementedError(
|
||||
f"{self.server.__class__.__name__} does not implement get_logprobs. "
|
||||
"Strict mode requires backend prompt logprobs."
|
||||
)
|
||||
|
||||
payload = await self.server.get_logprobs(**request_kwargs)
|
||||
return payload
|
||||
|
||||
|
||||
class DummyManagedServer:
|
||||
"""
|
||||
|
|
@ -815,6 +848,15 @@ class DummyManagedServer:
|
|||
else:
|
||||
self.current_nodes.clear()
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Dummy managed server does not provide real prompt logprobs.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"DummyManagedServer does not support get_logprobs in strict mode. "
|
||||
"Use a backend with real prompt logprob support."
|
||||
)
|
||||
|
||||
|
||||
class ManagedServerAdapter:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue