Add dummy openai managed server

This commit is contained in:
Dakota 2026-02-04 15:16:36 -06:00
parent 462abbebf7
commit 10f651289c
4 changed files with 235 additions and 11 deletions

View file

@ -511,6 +511,117 @@ class ManagedServer:
self.current_nodes.clear()
class DummyManagedServer:
    """
    A simple managed server wrapper for OpenAI endpoints that don't support
    token IDs/logprobs.

    Every tracked sequence is given the same fixed placeholder tokens, masked
    tokens and logprobs. NOT suitable for training.

    Tracking has two modes, selected by ``track_tree`` at construction time:

    * ``track_tree=False`` (default): nodes are appended to
      ``self.current_nodes`` in call order.
    * ``track_tree=True``: nodes are stored in ``self.sequences`` keyed by
      their ``full_text``; a later node with identical text replaces the
      earlier one.

    Only the attribute belonging to the selected mode is created.
    """

    # Fixed dummy values. These are templates: each node receives its own
    # copy so that mutating one node never alters another node or the class.
    DUMMY_TOKENS = [1, 2, 3]
    DUMMY_MASKED_TOKENS = [-100, 2, 3]
    DUMMY_LOGPROBS = [-0.5, -0.5, -0.5]

    def __init__(
        self,
        server: APIServer,
        tokenizer: Optional[Any] = None,
        track_tree: bool = False,
    ):
        """
        Args:
            server: Object whose ``chat_completion`` / ``completion``
                coroutines are awaited to perform the actual requests.
            tokenizer: Accepted but ignored - nothing is tokenized here.
            track_tree: Selects dict-based (True) or list-based (False)
                tracking; see the class docstring.
        """
        self.server = server
        self.track_tree = track_tree
        # tokenizer is accepted but ignored - we don't tokenize anything
        if track_tree:
            self.sequences: Dict[str, SequenceNode] = {}
        else:
            self.current_nodes: List[SequenceNode] = []

    def _messages_to_text(self, messages: List[Dict[str, Any]]) -> str:
        """
        Convert messages to simple text format.

        Each message becomes ``"<role>:<content>"`` and messages are joined
        with a blank line. A missing or ``None`` content is rendered as an
        empty string, mirroring how response content is handled in
        ``chat_completion``.
        """
        rendered = []
        for message in messages:
            content = message.get("content")
            if content is None:
                # Avoid emitting the literal text "None" (or raising) for
                # messages that carry no textual content.
                content = ""
            rendered.append(f"{message['role']}:{content}")
        return "\n\n".join(rendered)

    def _create_dummy_node(
        self,
        full_text: str,
        finish_reason: str = "stop",
    ) -> SequenceNode:
        """
        Create a sequence node with fixed dummy values.

        The placeholder lists are copied so nodes never share list objects
        with each other or with the class-level constants.
        """
        return SequenceNode(
            full_text=full_text,
            tokens=list(self.DUMMY_TOKENS),
            masked_tokens=list(self.DUMMY_MASKED_TOKENS),
            logprobs=list(self.DUMMY_LOGPROBS),
            metadata={"finish_reason": finish_reason, "dummy_tokens": True},
        )

    def _track_node(self, node: SequenceNode) -> None:
        """Record ``node`` in whichever container the tracking mode uses."""
        if self.track_tree:
            self.sequences[node.full_text] = node
        else:
            self.current_nodes.append(node)

    async def chat_completion(self, **kwargs) -> ChatCompletion:
        """
        Make a chat completion call and track with dummy tokens.

        All keyword arguments are forwarded unchanged to
        ``self.server.chat_completion``. One node is tracked per returned
        choice; the response itself is returned unmodified.
        """
        # Normalise to a list so ``None`` or a tuple of messages still works
        # when the assistant reply is appended below.
        messages = list(kwargs.get("messages") or [])
        response = await self.server.chat_completion(**kwargs)
        for choice in response.choices:
            completion_content = choice.message.content or ""
            # Append assistant response to messages for full_text
            all_messages = messages + [
                {"role": "assistant", "content": completion_content}
            ]
            self._track_node(
                self._create_dummy_node(
                    full_text=self._messages_to_text(all_messages),
                    finish_reason=choice.finish_reason or "stop",
                )
            )
        return response

    async def completion(self, **kwargs) -> Completion:
        """
        Make a completion call and track with dummy tokens.

        All keyword arguments are forwarded unchanged to
        ``self.server.completion``. One node is tracked per returned choice,
        whose ``full_text`` is the prompt followed by the completion text.
        """
        prompt = kwargs.get("prompt", "")
        if prompt is None:
            # Keep the literal text "None" out of the tracked full_text.
            prompt = ""
        response = await self.server.completion(**kwargs)
        for choice in response.choices:
            completion_text = choice.text or ""
            self._track_node(
                self._create_dummy_node(
                    full_text=f"{prompt}{completion_text}",
                    finish_reason=choice.finish_reason or "stop",
                )
            )
        return response

    def get_state(self) -> Dict[str, Any]:
        """
        Get the current state of tracked sequences.

        Returns shallow copies of the tracking container: in tree mode a
        dict with ``"sequences"`` and ``"tree"`` (two separate copies of the
        same mapping), otherwise a dict with ``"nodes"``.
        """
        if self.track_tree:
            return {
                "sequences": self.sequences.copy(),
                "tree": self.sequences.copy(),
            }
        return {"nodes": self.current_nodes.copy()}

    def reset(self):
        """Clear all tracked sequences."""
        if self.track_tree:
            self.sequences.clear()
        else:
            self.current_nodes.clear()
class ManagedServerAdapter:
"""
Adapter that makes ManagedServer look like AsyncOpenAI for external libraries.