mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
Add dummy openai managed server
This commit is contained in:
parent
462abbebf7
commit
10f651289c
4 changed files with 235 additions and 11 deletions
|
|
@@ -511,6 +511,117 @@ class ManagedServer:
|
|||
self.current_nodes.clear()
|
||||
|
||||
|
||||
class DummyManagedServer:
    """
    A simple managed server wrapper for OpenAI endpoints that don't support token IDs/logprobs.

    Uses fixed placeholder values for tokens and logprobs. NOT suitable for training.

    Each choice returned by the wrapped server is recorded as a ``SequenceNode``
    whose ``full_text`` is built from the request and the response text, and whose
    ``tokens`` / ``masked_tokens`` / ``logprobs`` are copies of the class-level
    placeholders. Every node's metadata carries ``"dummy_tokens": True`` so
    consumers can detect placeholder data.
    """

    # Fixed dummy values. Nodes receive *copies* of these lists, so mutating a
    # node never changes the placeholders or any other node.
    DUMMY_TOKENS = [1, 2, 3]
    # NOTE(review): -100 presumably marks a masked position - confirm against
    # how ManagedServer builds masked_tokens.
    DUMMY_MASKED_TOKENS = [-100, 2, 3]
    DUMMY_LOGPROBS = [-0.5, -0.5, -0.5]

    def __init__(
        self,
        server: APIServer,
        tokenizer: Optional[Any] = None,
        track_tree: bool = False,
    ):
        """
        Args:
            server: Object whose awaitable ``chat_completion`` / ``completion``
                methods perform the actual API calls.
            tokenizer: Accepted but ignored - nothing is tokenized here.
            track_tree: When True, nodes are stored in ``self.sequences`` keyed
                by their full text; otherwise they are appended to
                ``self.current_nodes``.
        """
        self.server = server
        self.track_tree = track_tree
        # tokenizer is accepted but ignored - we don't tokenize anything

        # Only the container matching the tracking mode is created.
        if track_tree:
            self.sequences: Dict[str, SequenceNode] = {}
        else:
            self.current_nodes: List[SequenceNode] = []

    def _messages_to_text(self, messages: List[Dict[str, str]]) -> str:
        """
        Convert messages to simple text format.

        Each message becomes ``"<role>:<content>"`` and messages are joined with
        a blank line. A ``None`` content is rendered as an empty string, matching
        how ``chat_completion`` treats a ``None`` response content.
        """
        parts = []
        for message in messages:
            content = message["content"]
            if content is None:
                content = ""
            parts.append(f"{message['role']}:{content}")
        return "\n\n".join(parts)

    def _create_dummy_node(
        self,
        full_text: str,
        finish_reason: str = "stop",
    ) -> SequenceNode:
        """
        Create a sequence node with fixed dummy values.

        The placeholder lists are copied so that every node owns its own
        ``tokens`` / ``masked_tokens`` / ``logprobs`` lists instead of aliasing
        the shared class attributes.
        """
        return SequenceNode(
            full_text=full_text,
            tokens=list(self.DUMMY_TOKENS),
            masked_tokens=list(self.DUMMY_MASKED_TOKENS),
            logprobs=list(self.DUMMY_LOGPROBS),
            metadata={"finish_reason": finish_reason, "dummy_tokens": True},
        )

    def _track_node(self, node: SequenceNode) -> None:
        """Store ``node`` in the container selected by ``track_tree``."""
        if self.track_tree:
            # Keyed by full text: a later node with identical text replaces
            # the earlier one.
            self.sequences[node.full_text] = node
        else:
            self.current_nodes.append(node)

    async def chat_completion(self, **kwargs) -> ChatCompletion:
        """
        Make a chat completion call and track with dummy tokens.

        All keyword arguments are forwarded unchanged to
        ``self.server.chat_completion``. One node is recorded per returned
        choice; the server's response is returned as-is.
        """
        # Normalise to a fresh list so concatenation below works for any
        # sequence type and never touches the caller's object.
        messages = list(kwargs.get("messages") or [])

        response = await self.server.chat_completion(**kwargs)

        for choice in response.choices:
            completion_content = choice.message.content or ""
            # Append assistant response to messages for full_text
            all_messages = messages + [
                {"role": "assistant", "content": completion_content}
            ]
            node = self._create_dummy_node(
                full_text=self._messages_to_text(all_messages),
                finish_reason=choice.finish_reason or "stop",
            )
            self._track_node(node)

        return response

    async def completion(self, **kwargs) -> Completion:
        """
        Make a completion call and track with dummy tokens.

        All keyword arguments are forwarded unchanged to
        ``self.server.completion``. One node is recorded per returned choice,
        with ``full_text`` being the prompt followed by the choice text; the
        server's response is returned as-is.
        """
        prompt = kwargs.get("prompt", "")

        response = await self.server.completion(**kwargs)

        for choice in response.choices:
            completion_text = choice.text or ""
            node = self._create_dummy_node(
                full_text=f"{prompt}{completion_text}",
                finish_reason=choice.finish_reason or "stop",
            )
            self._track_node(node)

        return response

    def get_state(self) -> Dict[str, Any]:
        """
        Get the current state of tracked sequences.

        Returns:
            With ``track_tree`` enabled, a dict with ``"sequences"`` and
            ``"tree"`` keys, each a shallow copy of the tracked mapping;
            otherwise a dict with a ``"nodes"`` key holding a shallow copy of
            the tracked list.
        """
        if self.track_tree:
            return {
                "sequences": self.sequences.copy(),
                "tree": self.sequences.copy(),
            }
        return {"nodes": self.current_nodes.copy()}

    def reset(self) -> None:
        """Clear all tracked sequences."""
        if self.track_tree:
            self.sequences.clear()
        else:
            self.current_nodes.clear()
|
||||
|
||||
|
||||
class ManagedServerAdapter:
|
||||
"""
|
||||
Adapter that makes ManagedServer look like AsyncOpenAI for external libraries.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue