diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index c1358dc6..5941576e 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -62,6 +62,7 @@ class ManagedServer: server: APIServer, tokenizer: Optional[Any] = None, track_tree: bool = False, + tool_parser: Optional[str] = None, ): """ Initialize the managed server. @@ -73,10 +74,17 @@ class ManagedServer: track_tree: If True, maintains a tree structure with parent-child links (for multi-turn RL with per-step advantages). If False (default), maintains a simple list of current nodes that updates in-place. + tool_parser: Optional vLLM tool parser name (e.g. "hermes", "llama3_json", + "mistral", etc.). If provided, enables tool call support in + chat_completion(). The parser handles extraction of structured + tool calls from raw model output. See + ToolParserManager.list_registered() for available parsers. """ self.server = server self.tokenizer = tokenizer self.track_tree = track_tree + self._tool_parser_name = tool_parser + self._translator = None # Lazy init # Initialize storage based on mode if track_tree: @@ -107,19 +115,57 @@ class ManagedServer: ) self.tokenizer = None - def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: + def _get_translator(self): + """Lazily create the ToolCallTranslator when first needed. + + Returns None if tool_parser was not specified or if vLLM is not + installed (the translator will warn on creation in that case). + """ + if self._translator is None and self._tool_parser_name and self.tokenizer: + try: + from atroposlib.envs.server_handling.tool_call_translator import ( + ToolCallTranslator, + ) + + self._translator = ToolCallTranslator( + tokenizer=self.tokenizer, + parser_name=self._tool_parser_name, + ) + except Exception as e: + warnings.warn( + f"Failed to create ToolCallTranslator: {e}. 
" + "Tool call parsing will be disabled.", + stacklevel=2, + ) + self._tool_parser_name = None # Don't retry + return None + return self._translator + + def _convert_messages_to_prompt( + self, + messages: List[Dict[str, str]], + tools: Optional[List[dict]] = None, + ) -> str: """ Convert chat messages to prompt text using tokenizer's chat template. Args: messages: List of message dicts with 'role' and 'content' + tools: Optional list of tool definitions (OpenAI format). Passed to + apply_chat_template() so the template can inject tool defs + into the system prompt. Returns: Formatted prompt string """ + # If tools are active and we have a translator, convert any assistant + # messages with tool_calls back to raw text first + if tools and self._get_translator(): + messages = self._get_translator().convert_messages_for_template(messages) + if self.tokenizer is None: # Fallback: simple concatenation - return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages]) if hasattr(self.tokenizer, "apply_chat_template"): # Only add generation prompt if last message is not from assistant @@ -127,13 +173,19 @@ class ManagedServer: len(messages) == 0 or messages[-1].get("role") != "assistant" ) + # Build kwargs + template_kwargs = { + "tokenize": False, + "add_generation_prompt": add_generation_prompt, + } + if tools: + template_kwargs["tools"] = tools + # Use the tokenizer's chat template - return self.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=add_generation_prompt - ) + return self.tokenizer.apply_chat_template(messages, **template_kwargs) else: # Fallback for tokenizers without chat template - return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages]) def _debug_requests_enabled(self) -> bool: """Enable verbose request construction logs with 
ATROPOS_DEBUG_REQUESTS=1.""" @@ -268,15 +320,35 @@ class ManagedServer: Internally converts to prompt, calls tokens_and_logprobs_completion, tracks the sequence, and reconstructs a ChatCompletion response. + Supports tool calling when a tool_parser was provided at init: + - Accepts `tools` and `tool_choice` kwargs + - Converts inbound assistant tool_call messages to raw text + - Parses outbound model output for tool calls + - Returns ChatCompletion with proper tool_calls in choices + - Preserves raw text in tracked nodes (tool parsing is response-only) + Args: - **kwargs: Standard chat completion kwargs (messages, n, etc.) + **kwargs: Standard chat completion kwargs (messages, n, max_tokens, + temperature, tools, tool_choice, etc.) Returns: - ChatCompletion response + ChatCompletion response (with tool_calls if detected) """ - # Get input text + # Extract tool-related kwargs + tools = kwargs.pop("tools", None) + tool_choice = kwargs.pop("tool_choice", None) + has_tools = bool(tools) and self._get_translator() is not None + + # Default tool_choice to "auto" if tools provided + if has_tools and tool_choice is None: + tool_choice = "auto" + + # Get input text — passes tools for template rendering and + # handles reconstruction of inbound tool_call messages messages = kwargs.get("messages", []) - prompt = self._convert_messages_to_prompt(messages) + prompt = self._convert_messages_to_prompt( + messages, tools=tools if has_tools else None + ) # Handle parent node and extending logic based on mode if self.track_tree: @@ -296,11 +368,12 @@ class ManagedServer: msg_count = len(messages) prompt_preview = prompt.replace("\n", "\\n")[:600] logger.debug( - "[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s", + "[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s tools=%s", msg_count, completion_kwargs.get("n"), completion_kwargs.get("max_tokens"), completion_kwargs.get("temperature"), + bool(tools), ) 
logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview) @@ -336,15 +409,18 @@ class ManagedServer: else: finish_reason = finish_reason_raw - # Decode completion text + # Decode completion text — use skip_special_tokens=False when + # tools are active so tags aren't stripped if self.tokenizer is not None: completion_text = self.tokenizer.decode( - output_tokens, skip_special_tokens=True + output_tokens, + skip_special_tokens=not has_tools, ) else: completion_text = "".join([chr(t) for t in output_tokens if t > 31]) - # Create and store sequence node + # Create and store sequence node — always uses the raw text, + # tool parsing only affects the ChatCompletion response node = self._create_sequence_node( input_text=prompt, parent_node=parent_node, @@ -373,14 +449,50 @@ class ManagedServer: # New context - append to list self.current_nodes.append(node) + # Parse tool calls from raw output if tools are active + tool_calls_parsed = None + content_for_response = completion_text + if has_tools and tool_choice != "none": + translator = self._get_translator() + content_for_response, tool_calls_parsed, finish_reason = ( + translator.parse_model_output( + raw_text=completion_text, + tool_choice=( + tool_choice if isinstance(tool_choice, str) else "auto" + ), + tools=tools, + ) + ) + # Build choice + message_kwargs = { + "content": content_for_response, + "role": "assistant", + } + # Note: openai's ChatCompletionMessage model handles tool_calls + # but we can't pass them through the constructor easily. We'll + # attach them after construction if needed. 
choice = Choice( finish_reason=finish_reason, index=i, - message=ChatCompletionMessage( - content=completion_text, role="assistant" - ), + message=ChatCompletionMessage(**message_kwargs), ) + + # Attach tool_calls to the message if present + if tool_calls_parsed: + choice.message.tool_calls = [ + # Convert vLLM ToolCall to openai ToolCall format + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls_parsed + ] + choices.append(choice) # Construct ChatCompletion response diff --git a/atroposlib/envs/server_handling/managed_server_proxy.py b/atroposlib/envs/server_handling/managed_server_proxy.py new file mode 100644 index 00000000..0b7bbf0f --- /dev/null +++ b/atroposlib/envs/server_handling/managed_server_proxy.py @@ -0,0 +1,623 @@ +""" +OpenAI-compatible chat completions proxy over ManagedServer. + +Exposes /{uuid}/v1/chat/completions and related endpoints so that external +environment microservices can interact with ManagedServer via standard +OpenAI API during multi-step rollouts. + +Each UUID maps to a session containing a ManagedServer instance. Tool call +parsing uses vLLM's parsers directly. The ManagedServer always stores raw +text — tool call translation only affects the HTTP wire format. + +Uses ServerManager for routing across multiple backend servers (load balancing, +health checks, etc.) — same infrastructure as the rest of atropos. 
+ +Usage: + # Standalone with JSON config + python -m atroposlib.envs.server_handling.managed_server_proxy \\ + --config servers.json \\ + --port 9100 + + # Or mount into existing FastAPI app + from atroposlib.envs.server_handling.managed_server_proxy import create_app + app = create_app(server_manager, tokenizer, model_name) + +servers.json example: + { + "model_name": "Qwen/Qwen3-4B", + "servers": [ + {"base_url": "http://gpu1:8000/v1", "server_type": "vllm", "api_key": ""}, + {"base_url": "http://gpu2:8000/v1", "server_type": "vllm", "api_key": ""} + ] + } +""" + +import logging +import time +import uuid as uuid_lib +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from atroposlib.envs.server_handling.managed_server import ( + DummyManagedServer, + ManagedServer, +) +from atroposlib.envs.server_handling.openai_server import OpenAIServer +from atroposlib.envs.server_handling.server_baseline import APIServerConfig +from atroposlib.envs.server_handling.server_manager import ServerManager +from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Request / Response models +# --------------------------------------------------------------------------- + + +class ChatMessage(BaseModel): + role: str + content: Optional[str] = None + tool_calls: Optional[List[dict]] = None + tool_call_id: Optional[str] = None + name: Optional[str] = None + + +class ChatCompletionProxyRequest(BaseModel): + messages: List[ChatMessage] + max_tokens: int = 1024 + temperature: float = 1.0 + n: int = 1 + stop: Optional[List[str]] = None + tools: Optional[List[dict]] = None + tool_choice: Optional[Any] = None # "auto", "none", "required", or dict + # TODO: top_p, 
frequency_penalty, presence_penalty, seed, response_format, + # logprobs, top_logprobs — pass through to backend when implemented + + +class SessionCreateRequest(BaseModel): + tool_parser: str = "hermes" + track_tree: bool = False + # Pin to a specific backend server by its base_url. In production, + # the caller gets their assigned server from the atropos API and + # passes it here. If omitted, falls back to picking the server + # with the most open semaphore slots (fine for dev/testing). + base_url: Optional[str] = None + + +class SessionCreateResponse(BaseModel): + uuid: str + model_name: str + tool_parser: str + base_url: Optional[str] = None # Which backend was selected + created_at: float + + +class RenderResponse(BaseModel): + prompt_text: str + token_ids: List[int] + num_tokens: int + + +# --------------------------------------------------------------------------- +# Session state +# --------------------------------------------------------------------------- + + +@dataclass +class SessionState: + uuid: str + managed_server: ManagedServer + translator: ToolCallTranslator + model_name: str + base_url: Optional[str] = None # Which backend server this session is pinned to + created_at: float = field(default_factory=time.time) + last_accessed: float = field(default_factory=time.time) + + +# --------------------------------------------------------------------------- +# OpenAI error format +# --------------------------------------------------------------------------- + + +def openai_error( + status_code: int, message: str, error_type: str = "invalid_request_error" +) -> JSONResponse: + return JSONResponse( + status_code=status_code, + content={ + "error": { + "message": message, + "type": error_type, + "code": status_code, + } + }, + ) + + +# --------------------------------------------------------------------------- +# App factory +# --------------------------------------------------------------------------- + + +def create_app( + server_manager: ServerManager, + 
tokenizer: Any, + model_name: str = "unknown", +) -> FastAPI: + """Create the proxy FastAPI app. + + Args: + server_manager: ServerManager instance managing one or more backend + servers (VLLMServer, SGLangServer, etc.). Used to pick the most + available server when creating sessions. + tokenizer: HuggingFace tokenizer for the model. + model_name: Model name to report in responses. + + Returns: + FastAPI app with all endpoints registered. + """ + + app = FastAPI(title="ManagedServer OpenAI Proxy") + sessions: Dict[str, SessionState] = {} + + # -- helpers -- + + def _get_session(session_uuid: str) -> SessionState: + session = sessions.get(session_uuid) + if session is None: + raise HTTPException( + status_code=404, detail=f"Session {session_uuid} not found" + ) + session.last_accessed = time.time() + return session + + def _get_server_base_url(server) -> Optional[str]: + """Get the base_url from a server's config, if available.""" + if hasattr(server, "config") and hasattr(server.config, "base_url"): + return server.config.base_url + return None + + def _select_server(base_url: Optional[str] = None): + """Pick a server from the manager. + + Args: + base_url: If provided, pin to the server with this base_url. + Raises 404 if no server matches. + If None, picks the most available server (mirrors + ServerManager.managed_server() logic). + + Returns: + Selected APIServer instance. + """ + if base_url is not None: + # Pin to specific server by base_url + for server in server_manager.servers: + server_url = _get_server_base_url(server) + if server_url and server_url.rstrip("/") == base_url.rstrip("/"): + return server + # No match — list available URLs in error + available = [ + _get_server_base_url(s) + for s in server_manager.servers + if _get_server_base_url(s) + ] + raise HTTPException( + status_code=404, + detail=f"No server with base_url '{base_url}'. 
Available: {available}", + ) + + # Auto-select most available + most_available_idx = 0 + most_available_slots = -1 + for i, server in enumerate(server_manager.servers): + if not server.server_healthy: + continue + if server.sem._value > most_available_slots: + most_available_idx = i + most_available_slots = server.sem._value + return server_manager.servers[most_available_idx] + + def _render_prompt( + messages: List[Dict[str, Any]], + tools: Optional[List[dict]] = None, + translator: Optional[ToolCallTranslator] = None, + ) -> str: + """Render messages to prompt text via chat template. + + If a translator is provided, converts OpenAI tool_call messages + back to raw text first. + """ + # Convert messages to dicts + msg_dicts = [] + for msg in messages: + if isinstance(msg, BaseModel): + msg_dicts.append(msg.model_dump(exclude_none=True)) + else: + msg_dicts.append(msg) + + # Reconstruct raw text for assistant tool_call messages + if translator: + msg_dicts = translator.convert_messages_for_template(msg_dicts) + + # Build kwargs for apply_chat_template + template_kwargs = { + "tokenize": False, + "add_generation_prompt": True, + } + if tools: + template_kwargs["tools"] = tools + + return tokenizer.apply_chat_template(msg_dicts, **template_kwargs) + + def _build_openai_response( + choices_data: List[dict], + model: str, + ) -> dict: + """Build an OpenAI ChatCompletion response dict.""" + return { + "id": f"chatcmpl-{uuid_lib.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": choices_data, + } + + # -- endpoints -- + + @app.get("/health") + async def health(): + healthy_servers = sum(1 for s in server_manager.servers if s.server_healthy) + return { + "status": "ok", + "model": model_name, + "sessions": len(sessions), + "servers": len(server_manager.servers), + "healthy_servers": healthy_servers, + } + + @app.get("/v1/models") + async def list_models(): + return { + "object": "list", + "data": [ + { + "id": 
model_name, + "object": "model", + "created": int(time.time()), + "owned_by": "atropos", + } + ], + } + + @app.post("/setup") + async def setup(request: Request): + """Receive server configuration from ServerManager. + + Accepts the same JSON format as the standalone --config file. + Replaces the current server_manager's servers with the new config. + Called by ServerManager at startup to push its config to the proxy. + + Body: {"model_name": "...", "servers": [{"base_url": "...", "server_type": "vllm"}, ...]} + """ + config = await request.json() + + new_configs = [] + new_model = config.get("model_name", model_name) + for srv in config.get("servers", []): + new_configs.append( + APIServerConfig( + model_name=new_model, + base_url=srv["base_url"], + api_key=srv.get("api_key", ""), + server_type=srv.get("server_type", "vllm"), + num_max_requests_at_once=srv.get("num_max_requests_at_once", 512), + num_requests_for_eval=srv.get("num_requests_for_eval", 64), + timeout=srv.get("timeout", 1200), + tokenizer_name=config.get("tokenizer_name", "none"), + ) + ) + + if new_configs: + new_manager = ServerManager(configs=new_configs) + server_manager.servers = new_manager.servers + logger.info("Setup: replaced servers with %d new configs", len(new_configs)) + + return { + "status": "ok", + "servers": len(server_manager.servers), + "model_name": new_model, + } + + @app.get("/servers") + async def list_servers(): + """List available backend servers. + + Useful for discovery/debugging. In production, server allocation + is managed by the atropos API — the environment gets told which + server to use and passes that base_url to POST /sessions/create. 
+ """ + server_list = [] + for i, server in enumerate(server_manager.servers): + url = _get_server_base_url(server) + healthy = getattr(server, "server_healthy", True) + server_list.append( + { + "index": i, + "base_url": url, + "healthy": healthy, + "model_name": ( + getattr(server.config, "model_name", model_name) + if hasattr(server, "config") + else model_name + ), + "server_type": ( + getattr(server.config, "server_type", "unknown") + if hasattr(server, "config") + else "unknown" + ), + } + ) + return {"servers": server_list} + + @app.post("/sessions/create", response_model=SessionCreateResponse) + async def create_session(request: SessionCreateRequest): + session_uuid = str(uuid_lib.uuid4()) + + # Pick server — pinned to base_url if specified, otherwise most available + selected_server = _select_server(base_url=request.base_url) + selected_url = _get_server_base_url(selected_server) + + # Use DummyManagedServer for OpenAI endpoints (no logprobs support) + if isinstance(selected_server, OpenAIServer): + logger.info( + "Session %s using DummyManagedServer (OpenAI endpoint). 
" + "Token IDs and logprobs will be placeholders.", + session_uuid, + ) + managed = DummyManagedServer( + server=selected_server, + tokenizer=tokenizer, + track_tree=request.track_tree, + ) + else: + managed = ManagedServer( + server=selected_server, + tokenizer=tokenizer, + track_tree=request.track_tree, + tool_parser=request.tool_parser, + ) + + # Translator kept for the render endpoint (prompt preview) + translator = ToolCallTranslator( + tokenizer=tokenizer, + parser_name=request.tool_parser, + ) + session = SessionState( + uuid=session_uuid, + managed_server=managed, + translator=translator, + model_name=model_name, + base_url=selected_url, + ) + sessions[session_uuid] = session + + return SessionCreateResponse( + uuid=session_uuid, + model_name=model_name, + tool_parser=request.tool_parser, + base_url=selected_url, + created_at=session.created_at, + ) + + @app.get("/sessions") + async def list_sessions(): + return { + "sessions": [ + { + "uuid": s.uuid, + "model_name": s.model_name, + "base_url": s.base_url, + "created_at": s.created_at, + "last_accessed": s.last_accessed, + "num_nodes": len( + s.managed_server.current_nodes + if hasattr(s.managed_server, "current_nodes") + else s.managed_server.sequences + ), + } + for s in sessions.values() + ] + } + + @app.post("/{session_uuid}/v1/chat/completions") + async def chat_completions(session_uuid: str, request: ChatCompletionProxyRequest): + session = _get_session(session_uuid) + managed = session.managed_server + + if not request.messages: + return openai_error(400, "messages must not be empty") + + # Convert pydantic messages to dicts + messages = [msg.model_dump(exclude_none=True) for msg in request.messages] + + # Build kwargs — ManagedServer.chat_completion() handles all tool + # call logic internally (template rendering, inbound reconstruction, + # outbound parsing, skip_special_tokens) + completion_kwargs = { + "messages": messages, + "n": request.n, + "max_tokens": request.max_tokens, + "temperature": 
request.temperature, + } + if request.stop: + completion_kwargs["stop"] = request.stop + if request.tools: + completion_kwargs["tools"] = request.tools + if request.tool_choice is not None: + completion_kwargs["tool_choice"] = request.tool_choice + + try: + result = await managed.chat_completion(**completion_kwargs) + except Exception as e: + logger.exception("Completion failed") + return openai_error( + 500, f"Completion failed: {e}", error_type="server_error" + ) + + # Convert ChatCompletion to JSON-serializable response + choices = [] + for choice in result.choices: + choice_data = { + "index": choice.index, + "message": { + "role": choice.message.role, + "content": choice.message.content, + }, + "finish_reason": choice.finish_reason, + } + if choice.message.tool_calls: + choice_data["message"]["tool_calls"] = choice.message.tool_calls + + choices.append(choice_data) + + return _build_openai_response(choices, model_name) + + @app.post("/{session_uuid}/v1/chat/completions/render") + async def render_prompt(session_uuid: str, request: ChatCompletionProxyRequest): + session = _get_session(session_uuid) + + try: + prompt_text = _render_prompt( + messages=request.messages, + tools=request.tools, + translator=session.translator, + ) + except Exception as e: + return openai_error(400, f"Failed to render prompt: {e}") + + token_ids = tokenizer.encode(prompt_text, add_special_tokens=False) + + return RenderResponse( + prompt_text=prompt_text, + token_ids=token_ids, + num_tokens=len(token_ids), + ) + + @app.get("/{session_uuid}/nodes") + async def get_nodes(session_uuid: str): + session = _get_session(session_uuid) + state = session.managed_server.get_state() + + if session.managed_server.track_tree: + nodes = list(state.get("sequences", {}).values()) + else: + nodes = state.get("nodes", []) + + return { + "nodes": [node.model_dump() for node in nodes], + } + + @app.delete("/{session_uuid}") + async def delete_session(session_uuid: str): + session = 
sessions.pop(session_uuid, None) + if session is None: + raise HTTPException( + status_code=404, detail=f"Session {session_uuid} not found" + ) + session.managed_server.reset() + return {"status": "deleted", "uuid": session_uuid} + + # -- exception handler for OpenAI-style errors -- + + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, exc: HTTPException): + return openai_error(exc.status_code, exc.detail) + + # NOTE on concurrent access: each UUID is meant to represent a single + # rollout session used by one caller at a time. If you send concurrent + # requests to the same UUID, the ManagedServer's node extension logic + # may get confused because it does prefix matching on current_nodes. + # Don't do that. UUIDs are cheap, make a new one. + + return app + + +# --------------------------------------------------------------------------- +# Standalone entrypoint +# --------------------------------------------------------------------------- + + +def main(): + import argparse + import json + + import uvicorn + from transformers import AutoTokenizer + + parser = argparse.ArgumentParser(description="ManagedServer OpenAI Proxy") + parser.add_argument( + "--config", + required=True, + help="Path to JSON config file with server definitions", + ) + parser.add_argument("--port", type=int, default=9100, help="Proxy port") + parser.add_argument("--host", default="0.0.0.0", help="Proxy host") + args = parser.parse_args() + + # Load config + with open(args.config) as f: + config = json.load(f) + + model_name = config["model_name"] + tokenizer = AutoTokenizer.from_pretrained(config.get("tokenizer_name", model_name)) + + # Build APIServerConfigs from the JSON + server_configs = [] + for srv in config["servers"]: + server_configs.append( + APIServerConfig( + model_name=model_name, + base_url=srv["base_url"], + api_key=srv.get("api_key", ""), + server_type=srv.get("server_type", "vllm"), + 
num_max_requests_at_once=srv.get("num_max_requests_at_once", 512), + num_requests_for_eval=srv.get("num_requests_for_eval", 64), + timeout=srv.get("timeout", 1200), + tokenizer_name=config.get("tokenizer_name", "none"), + ) + ) + + server_manager = ServerManager(configs=server_configs) + + app = create_app( + server_manager=server_manager, + tokenizer=tokenizer, + model_name=model_name, + ) + + print(f"Starting ManagedServer OpenAI Proxy on {args.host}:{args.port}") + print(f" Model: {model_name}") + print(f" Backends: {len(server_configs)} server(s)") + for i, cfg in enumerate(server_configs): + print(f" [{i}] {cfg.base_url} ({cfg.server_type})") + print() + print("Endpoints:") + print(" POST /sessions/create") + print(" POST /{uuid}/v1/chat/completions") + print(" POST /{uuid}/v1/chat/completions/render") + print(" GET /{uuid}/nodes") + print(" DELETE /{uuid}") + print(" GET /sessions") + print(" GET /v1/models") + print(" GET /health") + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/atroposlib/envs/server_handling/proxy_client.py b/atroposlib/envs/server_handling/proxy_client.py new file mode 100644 index 00000000..51e99122 --- /dev/null +++ b/atroposlib/envs/server_handling/proxy_client.py @@ -0,0 +1,292 @@ +""" +Client that talks to the ManagedServer OpenAI proxy over HTTP. + +Implements the same interface as ManagedServer so it can be used as a +drop-in replacement via ServerManager.managed_server(use_proxy=True). + +The proxy handles all the token tracking, tool call parsing, and sequence +management. This client just ferries requests/responses over HTTP and +reconstructs the SequenceNode objects from the JSON. 
+""" + +import logging +import time +import uuid as uuid_lib +from typing import Any, Dict, List, Optional + +import aiohttp +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from openai.types.completion import Completion # noqa: F401 — used in type hint + +from atroposlib.envs.server_handling.managed_server import SequenceNode + +logger = logging.getLogger(__name__) + + +class ProxyManagedServer: + """Client that talks to the ManagedServer OpenAI proxy. + + Same interface as ManagedServer — chat_completion(), completion(), + get_state(), reset(). But instead of doing token tracking in-process, + delegates everything to the proxy over HTTP. + + Created by ServerManager.managed_server(use_proxy=True). + + Example: + async with server_manager.managed_server(use_proxy=True) as managed: + # Same API as regular ManagedServer + resp = await managed.chat_completion( + messages=[{"role": "user", "content": "Hello"}], + n=4, max_tokens=100, temperature=1.0, + ) + state = managed.get_state() + nodes = state["nodes"] + + # Extra: get URL for external apps to use directly + url = managed.get_url() + # → "http://proxy:9100/{uuid}/v1" + """ + + def __init__( + self, + proxy_url: str, + session_uuid: str, + model_name: str = "unknown", + base_url: Optional[str] = None, + ): + """ + Args: + proxy_url: Base URL of the proxy (e.g. "http://localhost:9100") + session_uuid: UUID of the session on the proxy. + model_name: Model name (for response objects). + base_url: The backend server this session is pinned to. + """ + self.proxy_url = proxy_url.rstrip("/") + self.session_uuid = session_uuid + self.model_name = model_name + self.base_url = base_url + + # Cache for nodes (populated by get_state) + self._cached_nodes: Optional[List[SequenceNode]] = None + + def get_url(self) -> str: + """Get the OpenAI-compatible API URL for this session. 
+ + External apps can use this URL with any OpenAI client: + client = openai.OpenAI(base_url=managed.get_url()) + client.chat.completions.create(messages=..., tools=...) + + Returns: + URL like "http://proxy:9100/{uuid}/v1" + """ + return f"{self.proxy_url}/{self.session_uuid}/v1" + + async def _post(self, path: str, json: dict, timeout: int = 300) -> dict: + """Make a POST request to the proxy.""" + url = f"{self.proxy_url}{path}" + async with aiohttp.ClientSession() as session: + async with session.post( + url, json=json, timeout=aiohttp.ClientTimeout(total=timeout) + ) as resp: + data = await resp.json() + if resp.status != 200: + error_msg = data.get("error", {}).get("message", str(data)) + raise RuntimeError( + f"Proxy request failed ({resp.status}): {error_msg}" + ) + return data + + async def _get(self, path: str, timeout: int = 30) -> dict: + """Make a GET request to the proxy.""" + url = f"{self.proxy_url}{path}" + async with aiohttp.ClientSession() as session: + async with session.get( + url, timeout=aiohttp.ClientTimeout(total=timeout) + ) as resp: + data = await resp.json() + if resp.status != 200: + error_msg = data.get("error", {}).get("message", str(data)) + raise RuntimeError( + f"Proxy request failed ({resp.status}): {error_msg}" + ) + return data + + async def _delete(self, path: str, timeout: int = 30) -> dict: + """Make a DELETE request to the proxy.""" + url = f"{self.proxy_url}{path}" + async with aiohttp.ClientSession() as session: + async with session.delete( + url, timeout=aiohttp.ClientTimeout(total=timeout) + ) as resp: + data = await resp.json() + return data + + async def chat_completion(self, **kwargs) -> ChatCompletion: + """Send a chat completion request through the proxy. + + Same interface as ManagedServer.chat_completion(). + The proxy handles template rendering, tool call parsing, + and token/logprob tracking. 
+ """ + # Convert messages to serializable format + messages = kwargs.get("messages", []) + serialized_messages = [] + for msg in messages: + if isinstance(msg, dict): + serialized_messages.append(msg) + else: + serialized_messages.append(dict(msg)) + + body = { + "messages": serialized_messages, + "max_tokens": kwargs.get("max_tokens", 1024), + "temperature": kwargs.get("temperature", 1.0), + "n": kwargs.get("n", 1), + } + if kwargs.get("stop"): + body["stop"] = kwargs["stop"] + if kwargs.get("tools"): + body["tools"] = kwargs["tools"] + if kwargs.get("tool_choice") is not None: + body["tool_choice"] = kwargs["tool_choice"] + + data = await self._post(f"/{self.session_uuid}/v1/chat/completions", json=body) + + # Reconstruct ChatCompletion from proxy response + choices = [] + for choice_data in data.get("choices", []): + msg = choice_data.get("message", {}) + choice = Choice( + finish_reason=choice_data.get("finish_reason", "stop"), + index=choice_data.get("index", 0), + message=ChatCompletionMessage( + content=msg.get("content"), + role=msg.get("role", "assistant"), + ), + ) + choices.append(choice) + + return ChatCompletion( + id=data.get("id", str(uuid_lib.uuid4())), + created=data.get("created", int(time.time())), + model=data.get("model", self.model_name), + object="chat.completion", + choices=choices, + ) + + async def completion(self, **kwargs) -> Completion: + """Send a completion request through the proxy. + + Note: the proxy's chat/completions endpoint is the primary interface. + For raw completions, the proxy renders the prompt via chat template + internally. If you're calling this, you probably want chat_completion() + instead. + """ + # For completion() calls, we'd need a /completions endpoint on the proxy. + # Currently the proxy only exposes chat/completions. For now, raise + # a clear error. + raise NotImplementedError( + "ProxyManagedServer.completion() is not supported. 
" + "Use chat_completion() instead — the proxy handles template " + "rendering internally." + ) + + def get_state(self) -> Dict[str, Any]: + """Get the current state synchronously from cache. + + Call fetch_state() first to populate from the proxy, or use + the nodes returned by this method after a chat_completion() call. + + Returns: + Dict with 'nodes': List[SequenceNode] + """ + if self._cached_nodes is not None: + return {"nodes": self._cached_nodes} + return {"nodes": []} + + async def fetch_state(self) -> Dict[str, Any]: + """Fetch current state from the proxy (async). + + Returns: + Dict with 'nodes': List[SequenceNode] + """ + data = await self._get(f"/{self.session_uuid}/nodes") + nodes = [] + for node_data in data.get("nodes", []): + nodes.append(SequenceNode(**node_data)) + self._cached_nodes = nodes + return {"nodes": nodes} + + def reset(self): + """Clear cached state. The actual cleanup happens in __aexit__.""" + self._cached_nodes = None + + async def cleanup(self): + """Delete the session on the proxy.""" + try: + await self._delete(f"/{self.session_uuid}") + except Exception as e: + logger.warning(f"Failed to cleanup proxy session {self.session_uuid}: {e}") + + # -- context manager support -- + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Fetch final state before cleanup so callers can still access nodes + try: + await self.fetch_state() + except Exception: + pass + await self.cleanup() + + +async def create_proxy_session( + proxy_url: str, + base_url: Optional[str] = None, + tool_parser: str = "hermes", + track_tree: bool = False, + model_name: str = "unknown", +) -> ProxyManagedServer: + """Create a new session on the proxy and return a ProxyManagedServer. + + Args: + proxy_url: Base URL of the proxy (e.g. "http://localhost:9100"). + base_url: Pin to a specific backend server. In production, this + comes from the atropos API's server allocation. 
+ tool_parser: vLLM tool parser name (default: "hermes"). + track_tree: Whether to use tree mode for tracking. + model_name: Model name for response objects. + + Returns: + ProxyManagedServer instance ready to use. + """ + body = { + "tool_parser": tool_parser, + "track_tree": track_tree, + } + if base_url: + body["base_url"] = base_url + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{proxy_url.rstrip('/')}/sessions/create", + json=body, + timeout=aiohttp.ClientTimeout(total=30), + ) as resp: + data = await resp.json() + if resp.status != 200: + error_msg = data.get("error", {}).get("message", str(data)) + raise RuntimeError(f"Failed to create proxy session: {error_msg}") + + return ProxyManagedServer( + proxy_url=proxy_url, + session_uuid=data["uuid"], + model_name=data.get("model_name", model_name), + base_url=data.get("base_url"), + ) diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index e76dea32..f3d6ff3b 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -41,6 +41,14 @@ class ServerManagerConfig(BaseModel): "This is to help load balance servers." ), ) + proxy_url: Optional[str] = Field( + default=None, + description=( + "URL of the ManagedServer OpenAI proxy (e.g. 'http://localhost:9100'). " + "When set, managed_server(use_proxy=True) routes through this proxy. " + "Can also be set via ATROPOS_PROXY_URL environment variable." 
+ ), + ) class ServerManager: @@ -52,9 +60,18 @@ class ServerManager: testing=False, max_n_completions=8, reasoning_config: Optional[ReasoningConfig] = None, + proxy_url: Optional[str] = None, + use_proxy: bool = False, + tool_parser: Optional[str] = None, ): self.max_n_completions = max_n_completions self.reasoning_config = reasoning_config + # Proxy config — when use_proxy=True, managed_server() routes + # through the proxy HTTP API instead of creating in-process instances + self.proxy_url = proxy_url or os.environ.get("ATROPOS_PROXY_URL") + self.use_proxy = use_proxy or bool(self.proxy_url) + # Tool parser — passed to ManagedServer for tool call support + self.tool_parser = tool_parser # First we check to see if it's the base server class, and if so, we need to select the appropriate server class # You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as # an ABCMeta, not what you're expecting. @@ -364,8 +381,10 @@ class ServerManager: @asynccontextmanager async def managed_server( - self, tokenizer=None - ) -> AsyncGenerator[Union[ManagedServer, DummyManagedServer], None]: + self, + tokenizer=None, + base_url: Optional[str] = None, + ): """ Context manager that provides a ManagedServer instance. @@ -379,25 +398,63 @@ class ServerManager: Args: tokenizer: Optional tokenizer to use. If not provided, will attempt to extract from server or create from model name. + base_url: Pin the session to a specific backend server by its base_url. + In production, this comes from the atropos API's server allocation. Yields: - ManagedServer (or DummyManagedServer for OpenAI) instance wrapping - the selected server + ManagedServer, DummyManagedServer, or ProxyManagedServer instance Raises: NotImplementedError: If using OpenAI server without the ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set. 
Example: + # In-process (default): async with server_manager.managed_server() as managed: response = await managed.chat_completion( messages=[{"role": "user", "content": "Hello"}], - n=2 + n=2, tools=[...], tool_choice="auto", ) state = managed.get_state() - # Process state... - # State is automatically cleared when exiting context + + # Via proxy (configured at init with proxy_url= or ATROPOS_PROXY_URL): + # server_manager = ServerManager(configs, proxy_url="http://proxy:9100") + async with server_manager.managed_server() as managed: + response = await managed.chat_completion(...) + api_url = managed.get_url() # for external apps """ + # -- Proxy path -- + if self.use_proxy: + resolved_proxy_url = self.proxy_url + if not resolved_proxy_url: + raise ValueError( + "use_proxy=True requires proxy_url or ATROPOS_PROXY_URL env var " + "to be set at ServerManager init" + ) + + from atroposlib.envs.server_handling.proxy_client import ( + create_proxy_session, + ) + + model_name = ( + self.servers[0].config.model_name + if self.servers and hasattr(self.servers[0], "config") + else "unknown" + ) + + proxy_managed = await create_proxy_session( + proxy_url=resolved_proxy_url, + base_url=base_url, + tool_parser=self.tool_parser or "hermes", + model_name=model_name, + ) + try: + yield proxy_managed + finally: + await proxy_managed.cleanup() + return + + # -- In-process path (existing logic) -- most_available_server = 0 most_available_server_num_slots = -1 for i, server in enumerate(self.servers): @@ -441,7 +498,11 @@ class ServerManager: finally: managed.reset() else: - managed = ManagedServer(server=selected_server, tokenizer=tokenizer) + managed = ManagedServer( + server=selected_server, + tokenizer=tokenizer, + tool_parser=self.tool_parser, + ) try: yield managed diff --git a/atroposlib/envs/server_handling/tool_call_translator.py b/atroposlib/envs/server_handling/tool_call_translator.py new file mode 100644 index 00000000..a7963649 --- /dev/null +++ 
b/atroposlib/envs/server_handling/tool_call_translator.py @@ -0,0 +1,251 @@ +""" +Bidirectional translation between OpenAI tool_calls format and raw model text. + +Uses vLLM's tool parsers directly — same parsing logic as vLLM's chat +completions endpoint. Supports 30+ model-specific parsers (hermes, llama, +mistral, deepseek, qwen3, etc.) via ToolParserManager. + +Outbound (model → client): raw text with tags → structured OpenAI tool_calls +Inbound (client → model): OpenAI messages with tool roles → raw text for chat template +""" + +import json +import logging +import warnings +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +# vLLM is optional — tool call parsing degrades gracefully without it +try: + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + from vllm.entrypoints.openai.engine.protocol import ( + ExtractedToolCallInformation, + FunctionCall, + ToolCall, + ) + from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager + + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False + ChatCompletionRequest = None + ExtractedToolCallInformation = None + FunctionCall = None + ToolCall = None + ToolParser = None + ToolParserManager = None + + +class ToolCallTranslator: + """Bidirectional translation between OpenAI tool_calls and raw model text. + + Uses vLLM's tool parsers directly for outbound parsing (model output → + OpenAI format). Maintains a lookup table mapping tool_call IDs back to + the raw text that produced them, for reconstructing messages when the + caller sends tool results back. + + The ManagedServer always stores raw text — this translator only + transforms what goes over the HTTP wire. + """ + + def __init__(self, tokenizer: Any, parser_name: str = "hermes"): + """ + Args: + tokenizer: HuggingFace tokenizer instance. + parser_name: Name of the vLLM tool parser to use. 
+ Available: hermes, llama3_json, llama4_json, mistral, + deepseek_v3, qwen3_coder, granite, internlm, etc. + See ToolParserManager.list_registered() for full list. + + Raises: + Warning if vLLM is not installed — tool call parsing will be + disabled but the translator can still handle message conversion + and decoding. + """ + self.tokenizer = tokenizer + self.parser_name = parser_name + self.parser = None + + if not VLLM_AVAILABLE: + warnings.warn( + "vLLM is not installed — tool call parsing is disabled. " + "Install vllm to enable structured tool call extraction from " + "model output (pip install vllm). The translator will still " + "handle message conversion and template rendering, but " + "parse_model_output() will return raw text without parsing.", + stacklevel=2, + ) + else: + ParserClass = ToolParserManager.get_tool_parser(parser_name) + self.parser = ParserClass(tokenizer) + + # tool_call_id → raw text segment that produced it. + # Used to reconstruct assistant messages when the caller sends + # follow-up messages with tool results. + self.call_id_to_raw_text: Dict[str, str] = {} + + def parse_model_output( + self, + raw_text: str, + tool_choice: Optional[str] = None, + tools: Optional[List[dict]] = None, + ) -> Tuple[Optional[str], Optional[List[ToolCall]], str]: + """Parse raw model output into OpenAI response fields. + + Args: + raw_text: Raw model output text (may contain tags etc.) + tool_choice: The tool_choice from the request. If "none", skip + parsing entirely. If "required", force tool_calls interpretation. + tools: Tool definitions from the request (needed for vLLM request obj). 
+ + Returns: + Tuple of (content, tool_calls, finish_reason): + content: Text content (before tool calls, or full text if no tools) + tool_calls: List of ToolCall objects, or None + finish_reason: "stop", "tool_calls", or "length" + """ + # If tool_choice is "none" or no tools defined, don't even try parsing + if tool_choice == "none" or not tools: + return raw_text, None, "stop" + + # If vLLM isn't available, can't parse — return raw text + if self.parser is None: + return raw_text, None, "stop" + + # Build a minimal ChatCompletionRequest for the parser + request = ChatCompletionRequest( + messages=[{"role": "user", "content": ""}], + model="proxy", + tools=tools, + tool_choice=tool_choice, + ) + + result: ExtractedToolCallInformation = self.parser.extract_tool_calls( + raw_text, request + ) + + if result.tools_called and result.tool_calls: + # Store mapping for reverse direction + for tc in result.tool_calls: + self.call_id_to_raw_text[tc.id] = raw_text + + return result.content, result.tool_calls, "tool_calls" + else: + return raw_text, None, "stop" + + def reconstruct_raw_text_from_tool_calls(self, tool_calls: List[dict]) -> str: + """Reconstruct raw model text from OpenAI-format tool_calls. + + When a caller sends an assistant message with tool_calls (e.g. in a + multi-turn conversation), we need to convert it back to the raw text + the model actually generated so the chat template produces the right + tokens. + + First tries the lookup table (exact reconstruction). Falls back to + rebuilding from the structured data (best-effort). + + Args: + tool_calls: List of tool call dicts from OpenAI format, each with + 'id', 'type', 'function' (containing 'name' and 'arguments'). + + Returns: + Raw text with tags (or whatever format the parser uses). 
+ """ + if not tool_calls: + return "" + + # Try lookup table first — if the first call's ID is in the table, + # we can return the exact raw text + first_id = tool_calls[0].get("id", "") + if first_id in self.call_id_to_raw_text: + return self.call_id_to_raw_text[first_id] + + # Fallback: reconstruct from structured data + # This is best-effort — the exact formatting may differ from what + # the model originally generated, but it's close enough for the + # chat template to handle. + # TODO: make the tag format configurable per parser (hermes uses + # , others may differ) + parts = [] + for tc in tool_calls: + func = tc.get("function", {}) + name = func.get("name", "") + arguments = func.get("arguments", "{}") + # Parse arguments string back to dict for clean formatting + try: + args_dict = ( + json.loads(arguments) if isinstance(arguments, str) else arguments + ) + except (json.JSONDecodeError, TypeError): + args_dict = arguments + call_obj = {"name": name, "arguments": args_dict} + parts.append(f"{json.dumps(call_obj)}") + + return "\n".join(parts) + + def convert_messages_for_template( + self, messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Convert OpenAI messages to raw format suitable for apply_chat_template. + + Handles three cases: + 1. Regular messages (user, system): pass through unchanged + 2. Assistant messages with tool_calls: replace with raw text content + 3. Tool result messages (role=tool): pass through (chat template handles them) + + Args: + messages: OpenAI-format messages list. + + Returns: + Messages list with tool_call assistant messages reconstructed to raw text. 
+ """ + converted = [] + for msg in messages: + role = msg.get("role", "") + + if role == "assistant" and msg.get("tool_calls"): + # Reconstruct raw text from tool_calls + raw_text = self.reconstruct_raw_text_from_tool_calls(msg["tool_calls"]) + # Prepend any content that came before the tool calls + content = msg.get("content") or "" + if content: + raw_text = content + "\n" + raw_text + + converted.append( + { + "role": "assistant", + "content": raw_text, + } + ) + else: + # Pass through as-is (user, system, tool, regular assistant) + converted.append(msg) + + return converted + + def decode_with_tool_awareness( + self, + token_ids: List[int], + has_tools: bool = False, + ) -> str: + """Decode token IDs, preserving tool call tags when tools are active. + + Some tokenizers mark as a special token. If we decode with + skip_special_tokens=True (the default), the tags vanish before the + parser ever sees them. When tools are in play, we decode with + skip_special_tokens=False to preserve them. + + Args: + token_ids: Token IDs to decode. + has_tools: Whether tools are active for this request. + + Returns: + Decoded text string. 
+ """ + return self.tokenizer.decode( + token_ids, + skip_special_tokens=not has_tools, + ) diff --git a/atroposlib/tests/conftest.py b/atroposlib/tests/conftest.py index d122e39d..f2a2d234 100644 --- a/atroposlib/tests/conftest.py +++ b/atroposlib/tests/conftest.py @@ -5,19 +5,32 @@ def pytest_addoption(parser): parser.addoption( "--runproviders", action="store_true", default=False, help="run provider tests" ) + parser.addoption( + "--run-gpu", + action="store_true", + default=False, + help="run GPU integration tests", + ) def pytest_configure(config): config.addinivalue_line( "markers", "providers: mark test as requires providers api keys to run" ) + config.addinivalue_line( + "markers", "gpu: mark test as requiring GPU (skipped unless --run-gpu)" + ) def pytest_collection_modifyitems(config, items): - if config.getoption("--runproviders"): - # --runproviders given in cli: do not skip slow tests - return - skip_providers = pytest.mark.skip(reason="need --runproviders option to run") - for item in items: - if "providers" in item.keywords: - item.add_marker(skip_providers) + if not config.getoption("--runproviders"): + skip_providers = pytest.mark.skip(reason="need --runproviders option to run") + for item in items: + if "providers" in item.keywords: + item.add_marker(skip_providers) + + if not config.getoption("--run-gpu"): + skip_gpu = pytest.mark.skip(reason="need --run-gpu option to run") + for item in items: + if "gpu" in item.keywords: + item.add_marker(skip_gpu) diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index 75f8d48c..4da52bb0 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -493,6 +493,295 @@ async def test_multi_turn_chat_with_branching(mock_server): assert f"More{actual_i}" in node.full_text # Has third turn +# --------------------------------------------------------------------------- +# Tool call support in ManagedServer.chat_completion() +# 
--------------------------------------------------------------------------- + + +class MockTokenizerWithTools(MockTokenizer): + """Extended mock tokenizer that supports tools kwarg in apply_chat_template.""" + + def apply_chat_template( + self, messages, tokenize=False, add_generation_prompt=True, tools=None + ): + result = "" + if tools: + import json + + result += f"{json.dumps(tools)}\n" + for msg in messages: + content = msg.get("content", "") or "" + result += f"<{msg['role']}>{content}" + if add_generation_prompt: + result += "" + if tokenize: + return self.encode(result) + return result + + +@pytest.fixture +def mock_server_with_tools(): + """Mock server with tool-aware tokenizer.""" + server = ServerHarness() + server.tokenizer = MockTokenizerWithTools() + + class Config: + model_name = "test_model" + + server.config = Config() + return server + + +def _setup_chat_completion(server, tokenizer, messages, output_texts, tools=None): + """Helper: set up mock tokens_and_logprobs for a chat_completion call.""" + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + prompt_tokens = tokenizer.encode(prompt) + output_tokens_list = [[ord(c) for c in text] for text in output_texts] + output_logprobs_list = [[-0.1] * len(tokens) for tokens in output_tokens_list] + finish_reasons = ["stop"] * len(output_texts) + + server.set_tokens_and_logprobs_response( + prompt=prompt, + prompt_tokens=prompt_tokens, + output_tokens_list=output_tokens_list, + output_logprobs_list=output_logprobs_list, + finish_reasons=finish_reasons, + ) + return prompt + + +@pytest.mark.asyncio +async def test_tool_call_parsing_outbound(mock_server_with_tools): + """Model generates → chat_completion returns structured tool_calls.""" + managed = ManagedServer( + mock_server_with_tools, + tokenizer=mock_server_with_tools.tokenizer, + tool_parser="hermes", + ) + + tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + 
messages = [{"role": "user", "content": "Search cats"}] + raw_output = ( + '{"name": "search", "arguments": {"query": "cats"}}' + ) + + _setup_chat_completion( + mock_server_with_tools, + mock_server_with_tools.tokenizer, + messages, + [raw_output], + tools=tools, + ) + + result = await managed.chat_completion( + messages=messages, tools=tools, tool_choice="auto" + ) + + assert len(result.choices) == 1 + choice = result.choices[0] + assert choice.finish_reason == "tool_calls" + assert choice.message.tool_calls is not None + assert len(choice.message.tool_calls) == 1 + tc = choice.message.tool_calls[0] + assert tc["function"]["name"] == "search" + + # Node should have raw text (not parsed) + state = managed.get_state() + assert len(state["nodes"]) == 1 + + +@pytest.mark.asyncio +async def test_tool_choice_none_skips(mock_server_with_tools): + """tool_choice='none' returns raw text, no parsing.""" + managed = ManagedServer( + mock_server_with_tools, + tokenizer=mock_server_with_tools.tokenizer, + tool_parser="hermes", + ) + + tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + messages = [{"role": "user", "content": "Hi"}] + raw_output = '{"name": "search", "arguments": {"q": "x"}}' + + _setup_chat_completion( + mock_server_with_tools, + mock_server_with_tools.tokenizer, + messages, + [raw_output], + tools=tools, + ) + + result = await managed.chat_completion( + messages=messages, tools=tools, tool_choice="none" + ) + + assert result.choices[0].message.tool_calls is None + assert result.choices[0].finish_reason == "stop" + # Raw text should be content + assert "" in result.choices[0].message.content + + +@pytest.mark.asyncio +async def test_no_tool_parser_passes_through(mock_server_with_tools): + """Without tool_parser, tools kwarg is ignored — no parsing.""" + managed = ManagedServer( + mock_server_with_tools, + tokenizer=mock_server_with_tools.tokenizer, + # No tool_parser + ) + + messages = [{"role": "user", "content": "Hi"}] + 
raw_output = '{"name": "search", "arguments": {"q": "x"}}' + + _setup_chat_completion( + mock_server_with_tools, mock_server_with_tools.tokenizer, messages, [raw_output] + ) + + result = await managed.chat_completion(messages=messages) + + # No tool parsing — raw text as content + assert result.choices[0].message.tool_calls is None + assert result.choices[0].finish_reason == "stop" + + +@pytest.mark.asyncio +async def test_tool_call_multi_turn_extends_node(mock_server_with_tools): + """Multi-turn with tool calls should extend to 1 node.""" + managed = ManagedServer( + mock_server_with_tools, + tokenizer=mock_server_with_tools.tokenizer, + tool_parser="hermes", + ) + tok = mock_server_with_tools.tokenizer + tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + + # Step 1: user → tool_call + messages_1 = [{"role": "user", "content": "Search cats"}] + output_1 = '{"name": "search", "arguments": {"q": "cats"}}' + _setup_chat_completion( + mock_server_with_tools, tok, messages_1, [output_1], tools=tools + ) + + result_1 = await managed.chat_completion( + messages=messages_1, tools=tools, tool_choice="auto" + ) + tc_1 = result_1.choices[0].message.tool_calls + + assert len(managed.get_state()["nodes"]) == 1 + + # Step 2: include tool result → plain response + # Reconstruct the assistant message with tool_calls for the translator + messages_2 = [ + {"role": "user", "content": "Search cats"}, + {"role": "assistant", "content": None, "tool_calls": tc_1}, + {"role": "tool", "tool_call_id": tc_1[0]["id"], "content": "Found 5 cats"}, + ] + + # The translator will reconstruct the tool_call to raw text, + # so we need the prompt to match what it produces + output_2 = "Here are 5 cats!" 
+ prompt_2 = tok.apply_chat_template( + managed._get_translator().convert_messages_for_template(messages_2), + tokenize=False, + add_generation_prompt=True, + tools=tools, + ) + prompt_tokens_2 = tok.encode(prompt_2) + output_tokens_2 = [ord(c) for c in output_2] + mock_server_with_tools.set_tokens_and_logprobs_response( + prompt=prompt_2, + prompt_tokens=prompt_tokens_2, + output_tokens_list=[output_tokens_2], + output_logprobs_list=[[-0.1] * len(output_tokens_2)], + finish_reasons=["stop"], + ) + + result_2 = await managed.chat_completion( + messages=messages_2, tools=tools, tool_choice="auto" + ) + assert result_2.choices[0].message.content == output_2 + + # Still 1 node — step 2 extended step 1 + assert len(managed.get_state()["nodes"]) == 1 + + +@pytest.mark.asyncio +async def test_tool_call_multiple_tools_parsed(mock_server_with_tools): + """Multiple tool calls in one response are all parsed.""" + managed = ManagedServer( + mock_server_with_tools, + tokenizer=mock_server_with_tools.tokenizer, + tool_parser="hermes", + ) + + tools = [ + {"type": "function", "function": {"name": "get_weather", "parameters": {}}}, + {"type": "function", "function": {"name": "get_time", "parameters": {}}}, + ] + messages = [{"role": "user", "content": "Weather and time?"}] + raw_output = ( + '{"name": "get_weather", "arguments": {"city": "SF"}}\n' + '{"name": "get_time", "arguments": {"tz": "PST"}}' + ) + _setup_chat_completion( + mock_server_with_tools, + mock_server_with_tools.tokenizer, + messages, + [raw_output], + tools=tools, + ) + + result = await managed.chat_completion( + messages=messages, tools=tools, tool_choice="auto" + ) + + assert result.choices[0].finish_reason == "tool_calls" + assert len(result.choices[0].message.tool_calls) == 2 + names = {tc["function"]["name"] for tc in result.choices[0].message.tool_calls} + assert names == {"get_weather", "get_time"} + + +@pytest.mark.asyncio +async def test_tool_call_node_masking(mock_server_with_tools): + """Nodes have 
proper masking even with tool parsing active.""" + managed = ManagedServer( + mock_server_with_tools, + tokenizer=mock_server_with_tools.tokenizer, + tool_parser="hermes", + ) + + tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + messages = [{"role": "user", "content": "Hi"}] + raw_output = '{"name": "search", "arguments": {"q": "x"}}' + + _setup_chat_completion( + mock_server_with_tools, + mock_server_with_tools.tokenizer, + messages, + [raw_output], + tools=tools, + ) + + await managed.chat_completion(messages=messages, tools=tools) + + node = managed.get_state()["nodes"][0] + + # Lengths must match + assert len(node.tokens) == len(node.masked_tokens) == len(node.logprobs) + + # Should have masked prompt tokens and actual completion tokens + num_masked = sum(1 for t in node.masked_tokens if t == -100) + num_actual = sum(1 for t in node.masked_tokens if t != -100) + assert num_masked > 0 + assert num_actual > 0 + + # Prompt logprobs = 1.0, completion logprobs < 0 + assert all(lp == 1.0 for lp in node.logprobs[:num_masked]) + assert all(lp < 0 for lp in node.logprobs[num_masked:]) + + if __name__ == "__main__": # Run tests pytest.main([__file__, "-v"]) diff --git a/atroposlib/tests/test_managed_server_proxy.py b/atroposlib/tests/test_managed_server_proxy.py new file mode 100644 index 00000000..f210a2e2 --- /dev/null +++ b/atroposlib/tests/test_managed_server_proxy.py @@ -0,0 +1,852 @@ +"""Mock-based tests for the ManagedServer OpenAI proxy. + +Uses ServerHarness as the backend — no real model or GPU needed. +Tests the full HTTP layer: session management, chat completions, +tool call translation, render endpoint, nodes, cleanup. 
+""" + +import json + +import pytest +from fastapi.testclient import TestClient + +from atroposlib.envs.server_handling.managed_server_proxy import create_app +from atroposlib.envs.server_handling.server_harness import ServerHarness +from atroposlib.envs.server_handling.server_manager import ServerManager + +# --------------------------------------------------------------------------- +# Mock tokenizer (same as test_managed_server.py / test_tool_call_translator.py) +# --------------------------------------------------------------------------- + + +class MockTokenizer: + def __init__(self): + self.eos_token_id = 2 + self.bos_token_id = 1 + + def encode(self, text, add_special_tokens=True): + tokens = [ord(c) for c in text] + if add_special_tokens: + tokens = [self.bos_token_id] + tokens + return tokens + + def decode(self, tokens, skip_special_tokens=False): + if skip_special_tokens: + tokens = [ + t for t in tokens if t not in [self.bos_token_id, self.eos_token_id] + ] + return "".join([chr(t) if t > 31 else "" for t in tokens]) + + def get_vocab(self): + return {chr(i): i for i in range(128)} + + def apply_chat_template( + self, messages, tokenize=False, add_generation_prompt=True, tools=None + ): + result = "" + if tools: + result += f"{json.dumps(tools)}\n" + for msg in messages: + content = msg.get("content", "") or "" + result += f"<{msg['role']}>{content}" + if add_generation_prompt: + result += "" + if tokenize: + return self.encode(result) + return result + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_backend(): + """Create a mock server backend with tokenizer.""" + server = ServerHarness() + server.tokenizer = MockTokenizer() + # ServerManager's _select_server checks these attributes + server.server_healthy = True + + class Config: + model_name = "test_model" + + server.config = Config() + return 
    server


@pytest.fixture
def server_manager(mock_backend):
    """Create a ServerManager wrapping the mock backend.

    Bypasses __init__ (object.__new__) so no config validation or network
    setup runs; only the attributes the proxy code reads are populated.
    """
    # Can't use ServerManager constructor with empty configs, so build manually
    mgr = object.__new__(ServerManager)
    mgr.max_n_completions = 8
    mgr.reasoning_config = None
    mgr.servers = [mock_backend]
    return mgr


@pytest.fixture
def client(server_manager):
    """Create a test client for the proxy app."""
    tokenizer = MockTokenizer()
    app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer,
        model_name="test_model",
    )
    return TestClient(app)


@pytest.fixture
def client_and_backend(mock_backend, server_manager):
    """Return both client and backend for tests that need to set up mock responses."""
    tokenizer = MockTokenizer()
    app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer,
        model_name="test_model",
    )
    return TestClient(app), mock_backend, tokenizer


def _setup_completion(
    backend, tokenizer, prompt_text, output_texts, finish_reasons=None
):
    """Helper to set up a mock tokens_and_logprobs response.

    Output token ids are the raw ord() codepoints of each output string
    (matching MockTokenizer.encode without special tokens); every output
    token gets a constant -0.1 logprob. finish_reasons defaults to "stop"
    for each output.
    """
    prompt_tokens = tokenizer.encode(prompt_text)
    output_tokens_list = [[ord(c) for c in text] for text in output_texts]
    output_logprobs_list = [[-0.1] * len(tokens) for tokens in output_tokens_list]
    if finish_reasons is None:
        finish_reasons = ["stop"] * len(output_texts)

    backend.set_tokens_and_logprobs_response(
        prompt=prompt_text,
        prompt_tokens=prompt_tokens,
        output_tokens_list=output_tokens_list,
        output_logprobs_list=output_logprobs_list,
        finish_reasons=finish_reasons,
    )


# ---------------------------------------------------------------------------
# Health / Models
# ---------------------------------------------------------------------------


class TestHealth:
    def test_health(self, client):
        resp = client.get("/health")
        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "ok"
        assert data["model"] == "test_model"
        assert data["sessions"] == 0

    def test_models(self, client):
        resp = client.get("/v1/models")
        assert resp.status_code == 200
        data = resp.json()
        assert data["object"] == "list"
        assert len(data["data"]) == 1
        assert data["data"][0]["id"] == "test_model"


# ---------------------------------------------------------------------------
# Session Management
# ---------------------------------------------------------------------------


class TestSessionManagement:
    def test_create_session(self, client):
        resp = client.post("/sessions/create", json={})
        assert resp.status_code == 200
        data = resp.json()
        assert "uuid" in data
        assert data["model_name"] == "test_model"
        # NOTE(review): asserts "hermes" is the server-side default parser —
        # confirm against create_app's session defaults.
        assert data["tool_parser"] == "hermes"

    def test_create_session_custom_parser(self, client):
        # NOTE(review): this passes the same value as the default above, so it
        # only verifies the field is echoed back — it does not exercise a
        # genuinely different parser. Consider a non-default parser name.
        resp = client.post("/sessions/create", json={"tool_parser": "hermes"})
        assert resp.status_code == 200
        assert resp.json()["tool_parser"] == "hermes"

    def test_list_sessions(self, client):
        # Create 3 sessions
        uuids = []
        for _ in range(3):
            resp = client.post("/sessions/create", json={})
            uuids.append(resp.json()["uuid"])

        resp = client.get("/sessions")
        assert resp.status_code == 200
        sessions = resp.json()["sessions"]
        assert len(sessions) == 3
        listed_uuids = {s["uuid"] for s in sessions}
        assert listed_uuids == set(uuids)

    def test_delete_session(self, client):
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        resp = client.delete(f"/{uuid}")
        assert resp.status_code == 200
        assert resp.json()["status"] == "deleted"

        # Should be gone
        resp = client.get(f"/{uuid}/nodes")
        assert resp.status_code == 404

    def test_delete_nonexistent_session(self, client):
        resp = client.delete("/nonexistent-uuid")
        assert resp.status_code == 404

    def test_session_not_found(self, client):
        resp = client.post(
            "/nonexistent-uuid/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}]},
        )
        assert resp.status_code == 404


# ---------------------------------------------------------------------------
# Chat Completions
# ---------------------------------------------------------------------------


class TestChatCompletions:
    def test_basic_completion(self, client_and_backend):
        client, backend, tokenizer = client_and_backend

        # Create session
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        # Set up mock response
        messages = [{"role": "user", "content": "Hello"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        _setup_completion(backend, tokenizer, prompt_text, ["Hi there!"])

        # Make request
        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages, "max_tokens": 100},
        )
        assert resp.status_code == 200
        data = resp.json()

        assert data["object"] == "chat.completion"
        assert data["model"] == "test_model"
        assert len(data["choices"]) == 1
        assert data["choices"][0]["message"]["role"] == "assistant"
        assert data["choices"][0]["message"]["content"] == "Hi there!"
        assert data["choices"][0]["finish_reason"] == "stop"
        assert data["id"].startswith("chatcmpl-")

    def test_completion_with_n(self, client_and_backend):
        client, backend, tokenizer = client_and_backend

        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        messages = [{"role": "user", "content": "Pick a number"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        _setup_completion(backend, tokenizer, prompt_text, ["One", "Two", "Three"])

        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages, "n": 3, "max_tokens": 50},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["choices"]) == 3

    def test_empty_messages_error(self, client):
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": []},
        )
        assert resp.status_code == 400

    def test_completion_with_system_prompt(self, client_and_backend):
        client, backend, tokenizer = client_and_backend

        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi"},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        _setup_completion(backend, tokenizer, prompt_text, ["Hello!"])

        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages},
        )
        assert resp.status_code == 200
        assert resp.json()["choices"][0]["message"]["content"] == "Hello!"
+ + +# --------------------------------------------------------------------------- +# Tool Call Handling +# --------------------------------------------------------------------------- + + +class TestToolCalls: + def test_tool_call_outbound(self, client_and_backend): + """Model generates tags → response has structured tool_calls.""" + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + messages = [{"role": "user", "content": "Search cats"}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + + # Model output includes tool call tags + raw_output = ( + '{"name": "search", "arguments": {"query": "cats"}}' + ) + _setup_completion(backend, tokenizer, prompt_text, [raw_output]) + + resp = client.post( + f"/{uuid}/v1/chat/completions", + json={"messages": messages, "tools": tools}, + ) + assert resp.status_code == 200 + data = resp.json() + choice = data["choices"][0] + + assert choice["finish_reason"] == "tool_calls" + assert "tool_calls" in choice["message"] + assert len(choice["message"]["tool_calls"]) == 1 + tc = choice["message"]["tool_calls"][0] + assert tc["function"]["name"] == "search" + assert json.loads(tc["function"]["arguments"]) == {"query": "cats"} + + def test_tool_choice_none(self, client_and_backend): + """tool_choice=none → no parsing, raw text returned.""" + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + messages = [{"role": "user", "content": "Search cats"}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + + raw_output = ( + '{"name": "search", "arguments": {"query": "cats"}}' + ) + 
_setup_completion(backend, tokenizer, prompt_text, [raw_output]) + + resp = client.post( + f"/{uuid}/v1/chat/completions", + json={"messages": messages, "tools": tools, "tool_choice": "none"}, + ) + assert resp.status_code == 200 + choice = resp.json()["choices"][0] + + # Should NOT have tool_calls since tool_choice is "none" + assert choice["finish_reason"] == "stop" + assert ( + "tool_calls" not in choice["message"] + or choice["message"].get("tool_calls") is None + ) + + def test_nodes_preserve_raw_text(self, client_and_backend): + """ManagedServer nodes should have raw text, not parsed tool_calls.""" + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + messages = [{"role": "user", "content": "Search cats"}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + + raw_output = ( + '{"name": "search", "arguments": {"query": "cats"}}' + ) + _setup_completion(backend, tokenizer, prompt_text, [raw_output]) + + client.post( + f"/{uuid}/v1/chat/completions", + json={"messages": messages, "tools": tools}, + ) + + # Check nodes — should have the raw tokens, not parsed + resp = client.get(f"/{uuid}/nodes") + assert resp.status_code == 200 + nodes = resp.json()["nodes"] + assert len(nodes) == 1 + + +# --------------------------------------------------------------------------- +# Render Endpoint +# --------------------------------------------------------------------------- + + +class TestRender: + def test_render_basic(self, client): + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + resp = client.post( + f"/{uuid}/v1/chat/completions/render", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "prompt_text" in data + assert "token_ids" in 
data + assert "num_tokens" in data + assert data["num_tokens"] == len(data["token_ids"]) + assert "Hello" in data["prompt_text"] + + def test_render_with_tools(self, client): + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + tools = [{"type": "function", "function": {"name": "search"}}] + resp = client.post( + f"/{uuid}/v1/chat/completions/render", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "tools": tools, + }, + ) + assert resp.status_code == 200 + data = resp.json() + # Tool definitions should appear in the rendered prompt + assert "search" in data["prompt_text"] + + def test_render_does_not_create_nodes(self, client): + """Render should not cause any generation or node creation.""" + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + client.post( + f"/{uuid}/v1/chat/completions/render", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + resp = client.get(f"/{uuid}/nodes") + assert resp.json()["nodes"] == [] + + +# --------------------------------------------------------------------------- +# Nodes +# --------------------------------------------------------------------------- + + +class TestNodes: + def test_get_nodes_empty(self, client): + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + resp = client.get(f"/{uuid}/nodes") + assert resp.status_code == 200 + assert resp.json()["nodes"] == [] + + def test_get_nodes_after_completion(self, client_and_backend): + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + messages = [{"role": "user", "content": "Hi"}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + _setup_completion(backend, tokenizer, prompt_text, ["Hello!"]) + + client.post( + f"/{uuid}/v1/chat/completions", + json={"messages": messages}, + ) + + resp = client.get(f"/{uuid}/nodes") + assert 
resp.status_code == 200 + nodes = resp.json()["nodes"] + assert len(nodes) == 1 + + node = nodes[0] + assert "tokens" in node + assert "masked_tokens" in node + assert "logprobs" in node + assert "full_text" in node + assert ( + len(node["tokens"]) == len(node["masked_tokens"]) == len(node["logprobs"]) + ) + + def test_nodes_have_proper_masking(self, client_and_backend): + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + messages = [{"role": "user", "content": "Hi"}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_tokens = tokenizer.encode(prompt_text) + prompt_len = len(prompt_tokens) + + _setup_completion(backend, tokenizer, prompt_text, ["Hello!"]) + + client.post( + f"/{uuid}/v1/chat/completions", + json={"messages": messages}, + ) + + resp = client.get(f"/{uuid}/nodes") + node = resp.json()["nodes"][0] + + # Prompt tokens should be masked with -100 + assert all(t == -100 for t in node["masked_tokens"][:prompt_len]) + # Prompt logprobs should be 1.0 + assert all(lp == 1.0 for lp in node["logprobs"][:prompt_len]) + # Completion logprobs should be actual values (negative) + assert all(lp < 0 for lp in node["logprobs"][prompt_len:]) + + +# --------------------------------------------------------------------------- +# Deep multi-step node handling +# --------------------------------------------------------------------------- + + +class TestMultiStepNodeHandling: + """Test that multi-step conversations with tool calls produce exactly 1 node. + + Simulates a realistic 10+ message agentic conversation: + user → assistant(tool_call) → tool_result → assistant(text) → + user → assistant(tool_call) → tool_result → assistant(tool_call) → + tool_result → assistant(text) → user → assistant(text) + + Each step extends the previous node, so we should end up with exactly + 1 node containing the full tokenized conversation. 
+ """ + + def _do_step( + self, + client, + backend, + tokenizer, + uuid, + messages, + output_text, + tools=None, + expect_tool_calls=False, + ): + """Helper: use render endpoint to get exact prompt, set up mock, call endpoint.""" + body = {"messages": messages, "max_tokens": 200} + if tools: + body["tools"] = tools + + # Use the render endpoint to get the exact prompt the proxy will generate + # (this includes tool_call reconstruction through the translator) + render_resp = client.post(f"/{uuid}/v1/chat/completions/render", json=body) + assert render_resp.status_code == 200, f"Render failed: {render_resp.json()}" + prompt_text = render_resp.json()["prompt_text"] + + _setup_completion(backend, tokenizer, prompt_text, [output_text]) + + resp = client.post(f"/{uuid}/v1/chat/completions", json=body) + assert resp.status_code == 200, f"Step failed: {resp.json()}" + return resp.json() + + def test_10_message_conversation_one_node(self, client_and_backend): + """Full 10-message conversation with tool calls → exactly 1 node.""" + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + tools = [ + {"type": "function", "function": {"name": "get_weather", "parameters": {}}}, + { + "type": "function", + "function": {"name": "get_forecast", "parameters": {}}, + }, + ] + + # -- Step 1: user asks about weather -- + messages = [{"role": "user", "content": "What's the weather in SF?"}] + output_1 = '{"name": "get_weather", "arguments": {"city": "SF"}}' + data = self._do_step( + client, backend, tokenizer, uuid, messages, output_1, tools=tools + ) + assert data["choices"][0]["finish_reason"] == "tool_calls" + tc_1 = data["choices"][0]["message"]["tool_calls"] + + # Check: 1 node so far + nodes = client.get(f"/{uuid}/nodes").json()["nodes"] + assert len(nodes) == 1, f"Expected 1 node after step 1, got {len(nodes)}" + + # -- Step 2: tool result -- + messages = [ + {"role": "user", "content": "What's the 
weather in SF?"}, + {"role": "assistant", "content": None, "tool_calls": tc_1}, + { + "role": "tool", + "tool_call_id": tc_1[0]["id"], + "content": "72°F and sunny", + }, + ] + output_2 = "The weather in SF is 72°F and sunny! Want the forecast too?" + self._do_step(client, backend, tokenizer, uuid, messages, output_2, tools=tools) + + nodes = client.get(f"/{uuid}/nodes").json()["nodes"] + assert len(nodes) == 1, f"Expected 1 node after step 2, got {len(nodes)}" + + # -- Step 3: user says yes -- + messages.extend( + [ + {"role": "assistant", "content": output_2}, + {"role": "user", "content": "Yes please, get the forecast"}, + ] + ) + output_3 = '{"name": "get_forecast", "arguments": {"city": "SF"}}' + data = self._do_step( + client, backend, tokenizer, uuid, messages, output_3, tools=tools + ) + tc_3 = data["choices"][0]["message"]["tool_calls"] + + nodes = client.get(f"/{uuid}/nodes").json()["nodes"] + assert len(nodes) == 1, f"Expected 1 node after step 3, got {len(nodes)}" + + # -- Step 4: forecast tool result -- + messages.extend( + [ + {"role": "assistant", "content": None, "tool_calls": tc_3}, + { + "role": "tool", + "tool_call_id": tc_3[0]["id"], + "content": "Rain expected tomorrow", + }, + ] + ) + output_4 = "The forecast says rain is expected tomorrow in SF." 
+ self._do_step(client, backend, tokenizer, uuid, messages, output_4, tools=tools) + + nodes = client.get(f"/{uuid}/nodes").json()["nodes"] + assert len(nodes) == 1, f"Expected 1 node after step 4, got {len(nodes)}" + + # -- Step 5: user asks about another city -- + messages.extend( + [ + {"role": "assistant", "content": output_4}, + {"role": "user", "content": "What about NYC?"}, + ] + ) + output_5 = '{"name": "get_weather", "arguments": {"city": "NYC"}}' + data = self._do_step( + client, backend, tokenizer, uuid, messages, output_5, tools=tools + ) + tc_5 = data["choices"][0]["message"]["tool_calls"] + + nodes = client.get(f"/{uuid}/nodes").json()["nodes"] + assert len(nodes) == 1, f"Expected 1 node after step 5, got {len(nodes)}" + + # -- Step 6: NYC tool result -- + messages.extend( + [ + {"role": "assistant", "content": None, "tool_calls": tc_5}, + { + "role": "tool", + "tool_call_id": tc_5[0]["id"], + "content": "55°F and cloudy", + }, + ] + ) + output_6 = "NYC is 55°F and cloudy. Quite different from SF!" 
+ self._do_step(client, backend, tokenizer, uuid, messages, output_6, tools=tools) + + # -- FINAL CHECK: still exactly 1 node after 6 completions / 12+ messages -- + nodes = client.get(f"/{uuid}/nodes").json()["nodes"] + assert ( + len(nodes) == 1 + ), f"Expected 1 node after full conversation, got {len(nodes)}" + + # Verify the node has proper structure + node = nodes[0] + assert ( + len(node["tokens"]) == len(node["masked_tokens"]) == len(node["logprobs"]) + ) + assert len(node["tokens"]) > 0 + + # Verify masking: there should be SOME -100 (prompt) and SOME actual tokens + num_masked = sum(1 for t in node["masked_tokens"] if t == -100) + num_actual = sum(1 for t in node["masked_tokens"] if t != -100) + assert num_masked > 0, "Should have masked prompt tokens" + assert num_actual > 0, "Should have unmasked completion tokens" + + def test_plain_multi_turn_no_tools_one_node(self, client_and_backend): + """5-turn conversation without tools → exactly 1 node.""" + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + conversation = [] + + for i in range(5): + # Add user message + conversation.append({"role": "user", "content": f"Turn {i+1} question"}) + + prompt_text = tokenizer.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=True + ) + output = f"Response to turn {i+1}" + _setup_completion(backend, tokenizer, prompt_text, [output]) + + resp = client.post( + f"/{uuid}/v1/chat/completions", + json={"messages": conversation}, + ) + assert resp.status_code == 200 + + # Add assistant response for next turn + conversation.append({"role": "assistant", "content": output}) + + # After 5 turns (10 messages), should still be 1 node + nodes = client.get(f"/{uuid}/nodes").json()["nodes"] + assert len(nodes) == 1, f"Expected 1 node after 5 turns, got {len(nodes)}" + + def test_tool_then_plain_then_tool_one_node(self, client_and_backend): + """Mixed: tool call → plain text → 
tool call → plain → exactly 1 node.""" + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + + # Step 1: tool call + messages = [{"role": "user", "content": "Search for cats"}] + output = '{"name": "search", "arguments": {"q": "cats"}}' + data = self._do_step( + client, backend, tokenizer, uuid, messages, output, tools=tools + ) + tc = data["choices"][0]["message"]["tool_calls"] + + # Step 2: tool result → plain response + messages = [ + {"role": "user", "content": "Search for cats"}, + {"role": "assistant", "content": None, "tool_calls": tc}, + {"role": "tool", "tool_call_id": tc[0]["id"], "content": "Found 10 cats"}, + ] + self._do_step( + client, backend, tokenizer, uuid, messages, "Here are 10 cats!", tools=tools + ) + + # Step 3: user asks for more → another tool call + messages.extend( + [ + {"role": "assistant", "content": "Here are 10 cats!"}, + {"role": "user", "content": "Search for dogs too"}, + ] + ) + output = '{"name": "search", "arguments": {"q": "dogs"}}' + data = self._do_step( + client, backend, tokenizer, uuid, messages, output, tools=tools + ) + tc2 = data["choices"][0]["message"]["tool_calls"] + + # Step 4: tool result → plain response + messages.extend( + [ + {"role": "assistant", "content": None, "tool_calls": tc2}, + { + "role": "tool", + "tool_call_id": tc2[0]["id"], + "content": "Found 5 dogs", + }, + ] + ) + self._do_step( + client, backend, tokenizer, uuid, messages, "Found 5 dogs too!", tools=tools + ) + + # Step 5: plain follow-up, no tools + messages.extend( + [ + {"role": "assistant", "content": "Found 5 dogs too!"}, + {"role": "user", "content": "Thanks!"}, + ] + ) + self._do_step( + client, backend, tokenizer, uuid, messages, "You're welcome!", tools=tools + ) + + # 5 completion steps, 11 messages — still 1 node + nodes = client.get(f"/{uuid}/nodes").json()["nodes"] + 
assert len(nodes) == 1, f"Expected 1 node, got {len(nodes)}" + + +# --------------------------------------------------------------------------- +# Cleanup +# --------------------------------------------------------------------------- + + +class TestCleanup: + def test_delete_resets_nodes(self, client_and_backend): + client, backend, tokenizer = client_and_backend + + resp = client.post("/sessions/create", json={}) + uuid = resp.json()["uuid"] + + messages = [{"role": "user", "content": "Hi"}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + _setup_completion(backend, tokenizer, prompt_text, ["Hello!"]) + + client.post( + f"/{uuid}/v1/chat/completions", + json={"messages": messages}, + ) + + # Delete + resp = client.delete(f"/{uuid}") + assert resp.status_code == 200 + + # Session gone + resp = client.get(f"/{uuid}/nodes") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Error format +# --------------------------------------------------------------------------- + + +class TestErrorFormat: + def test_404_is_openai_format(self, client): + resp = client.get("/nonexistent-uuid/nodes") + assert resp.status_code == 404 + data = resp.json() + assert "error" in data + assert "message" in data["error"] + assert "type" in data["error"] + assert "code" in data["error"] diff --git a/atroposlib/tests/test_managed_server_proxy_integration.py b/atroposlib/tests/test_managed_server_proxy_integration.py new file mode 100644 index 00000000..a9cd1d96 --- /dev/null +++ b/atroposlib/tests/test_managed_server_proxy_integration.py @@ -0,0 +1,363 @@ +"""Integration tests for ManagedServer OpenAI proxy against real vLLM backend. + +Spins up example_trainer/vllm_api_server.py with Qwen3-4B as a subprocess. +Requires GPU — skipped by default. 
Run with:

    pytest --run-gpu atroposlib/tests/test_managed_server_proxy_integration.py -v -s

"""

import os
import signal
import subprocess
import sys
import time

import pytest
import requests
from transformers import AutoTokenizer

from atroposlib.envs.server_handling.managed_server_proxy import create_app
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

VLLM_PORT = 8123
VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
PROXY_MODEL = VLLM_MODEL
VLLM_BASE_URL = f"http://localhost:{VLLM_PORT}/v1"
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
VLLM_SCRIPT = os.path.join(REPO_ROOT, "example_trainer", "vllm_api_server.py")
VENV_PYTHON = sys.executable  # use the current interpreter


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def vllm_backend():
    """Start vLLM api server as a subprocess. Module-scoped so it's shared.

    Polls the /health endpoint every 3s for up to 180s; fails the test run
    if the process exits early or never becomes healthy. Yields the Popen
    handle, then SIGTERMs (and kills after 15s) on teardown.
    """
    cmd = [
        VENV_PYTHON,
        VLLM_SCRIPT,
        "--model",
        VLLM_MODEL,
        "--port",
        str(VLLM_PORT),
        "--max-model-len",
        "32000",
        "--max-num-seqs",
        "32",
    ]

    # NOTE(review): stdout is a PIPE that is only drained on failure — if the
    # server logs heavily while healthy, the OS pipe buffer can fill and block
    # the child. Consider a log file or DEVNULL; confirm vLLM's log volume.
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        cwd=REPO_ROOT,
    )

    # Wait for health
    deadline = time.time() + 180  # 3 min for model loading
    healthy = False
    while time.time() < deadline:
        try:
            resp = requests.get(f"http://localhost:{VLLM_PORT}/health", timeout=2)
            if resp.status_code == 200:
                healthy = True
                break
        except (requests.ConnectionError, requests.Timeout):
            pass

        # Bail out immediately if the server process died during startup
        if proc.poll() is not None:
            stdout = proc.stdout.read().decode() if proc.stdout else ""
            pytest.fail(
                f"vLLM server exited early (code={proc.returncode}):\n{stdout[-3000:]}"
            )

        time.sleep(3)

    if not healthy:
        proc.kill()
        stdout = proc.stdout.read().decode() if proc.stdout else ""
        pytest.fail(f"vLLM server didn't become healthy within 180s:\n{stdout[-3000:]}")

    yield proc

    # Teardown: graceful SIGTERM, hard kill if it doesn't exit in 15s
    proc.send_signal(signal.SIGTERM)
    try:
        proc.wait(timeout=15)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


@pytest.fixture(scope="module")
def tokenizer_instance():
    """Real HF tokenizer for the model under test (downloads on first use)."""
    return AutoTokenizer.from_pretrained(VLLM_MODEL)


@pytest.fixture(scope="module")
def proxy_client(vllm_backend, tokenizer_instance):
    """Create a test client for the proxy backed by the real vLLM server."""
    from fastapi.testclient import TestClient

    config = APIServerConfig(
        model_name=VLLM_MODEL,
        base_url=VLLM_BASE_URL,
        api_key="",
        server_type="vllm",
        health_check=False,
    )
    server_manager = ServerManager(configs=[config])

    app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer_instance,
        model_name=VLLM_MODEL,
    )
    return TestClient(app)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.gpu
class TestRealChatCompletion:
    def test_basic_completion(self, proxy_client):
        resp = proxy_client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say hello in one word."}],
                "max_tokens": 30,
                "temperature": 0.0,
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["choices"]) == 1
        content = data["choices"][0]["message"]["content"]
        assert content is not None
        assert len(content) > 0
        assert data["model"] == VLLM_MODEL

    def test_n_completions(self, proxy_client):
        resp = proxy_client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Pick a random number"}],
                "max_tokens": 20,
                "temperature": 1.0,
                "n": 4,
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["choices"]) == 4

        # Check nodes
        resp = proxy_client.get(f"/{uuid}/nodes")
        assert len(resp.json()["nodes"]) == 4


@pytest.mark.gpu
class TestRealLogprobs:
    def test_logprobs_are_valid(self, proxy_client):
        resp = proxy_client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 20,
                "temperature": 0.0,
            },
        )

        resp = proxy_client.get(f"/{uuid}/nodes")
        nodes = resp.json()["nodes"]
        assert len(nodes) == 1

        node = nodes[0]
        # Find where completion starts (logprobs transition from 1.0 to negative)
        prompt_end = 0
        for i, lp in enumerate(node["logprobs"]):
            if lp != 1.0:
                prompt_end = i
                break

        # Prompt logprobs should be 1.0
        assert all(lp == 1.0 for lp in node["logprobs"][:prompt_end])
        # Completion logprobs should be negative
        completion_lps = [lp for lp in node["logprobs"][prompt_end:] if lp != 1.0]
        assert len(completion_lps) > 0
        assert all(lp < 0 for lp in completion_lps)


@pytest.mark.gpu
class TestRealTokenAlignment:
    def test_tokens_decode_to_full_text(self, proxy_client, tokenizer_instance):
        resp = proxy_client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say exactly: test123"}],
                "max_tokens": 30,
                "temperature": 0.0,
            },
        )

        resp = proxy_client.get(f"/{uuid}/nodes")
        node = resp.json()["nodes"][0]

        # Lengths must match
        assert len(node["tokens"]) == len(node["masked_tokens"])
        assert len(node["tokens"]) == len(node["logprobs"])

        # Decode tokens and check they match full_text
        decoded = tokenizer_instance.decode(node["tokens"])
        # The decoded text should be close to (or contain) the full_text
        # Exact match may differ due to special token handling, but content should match
        assert len(decoded) > 0


@pytest.mark.gpu
class TestRealRender:
    def test_render_matches_tokenizer(self, proxy_client, tokenizer_instance):
        resp = proxy_client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello!"},
        ]

        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions/render",
            json={"messages": messages},
        )
        assert resp.status_code == 200
        data = resp.json()

        # Compare with direct tokenizer rendering
        expected_text = tokenizer_instance.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        assert data["prompt_text"] == expected_text


@pytest.mark.gpu
class TestRealSequenceExtension:
    def test_multi_turn_extends(self, proxy_client):
        resp = proxy_client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        # Turn 1
        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say hello"}],
                "max_tokens": 20,
                "temperature": 0.0,
            },
        )
        assert resp.status_code == 200
        turn1_content = resp.json()["choices"][0]["message"]["content"]

        # Turn 2 — extends turn 1
        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [
                    {"role": "user", "content": "Say hello"},
                    {"role": "assistant", "content": turn1_content},
                    {"role": "user", "content": "Now say goodbye"},
                ],
                "max_tokens": 20,
                "temperature": 0.0,
            },
        )
        assert resp.status_code == 200

        # Should have nodes (extension behavior depends on prefix matching)
        resp = proxy_client.get(f"/{uuid}/nodes")
        nodes = resp.json()["nodes"]
        assert len(nodes) >= 1


@pytest.mark.gpu
class TestRealConcurrentSessions:
    def test_sessions_independent(self, proxy_client):
        """Multiple sessions should not contaminate each other."""
        uuids = []
        for _ in range(3):
            resp = proxy_client.post("/sessions/create", json={})
            uuids.append(resp.json()["uuid"])

        # Complete on each
        for i, uuid in enumerate(uuids):
            resp = proxy_client.post(
                f"/{uuid}/v1/chat/completions",
                json={
                    "messages": [{"role": "user", "content": f"Count to {i+1}"}],
                    "max_tokens": 30,
                    "temperature": 0.0,
                },
            )
            assert resp.status_code == 200

        # Each should have exactly 1 node
        for uuid in uuids:
            resp = proxy_client.get(f"/{uuid}/nodes")
            assert len(resp.json()["nodes"]) == 1


@pytest.mark.gpu
class TestRealOpenAIClientCompat:
    def test_openai_client_works(self, proxy_client):
        """Verify the standard openai Python client can talk to our proxy."""
        resp = proxy_client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]

        # The TestClient doesn't expose a real port, so we test the
        # response format is compatible by checking structure manually
        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 10,
                "temperature": 0.0,
            },
        )
        data = resp.json()

        # Verify all fields the openai client expects
        assert "id" in data
        assert "object" in data
        assert data["object"] == "chat.completion"
        assert "created" in data
        assert "model" in data
        assert "choices" in data
        assert isinstance(data["choices"], list)
        for choice in data["choices"]:
            assert "index" in choice
            assert "message" in choice
            assert "finish_reason" in choice
            assert "role" in choice["message"]
            assert "content" in choice["message"]
diff --git a/atroposlib/tests/test_tool_call_translator.py b/atroposlib/tests/test_tool_call_translator.py
new file mode 100644
index 00000000..ba17a9b3
--- /dev/null
+++ b/atroposlib/tests/test_tool_call_translator.py
@@ -0,0 +1,478 @@
"""Unit tests for ToolCallTranslator — vLLM parser wrapper and lookup table.

These are pure logic tests, no server or model needed. Uses a mock tokenizer.
"""

import json

import pytest

from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator

# ---------------------------------------------------------------------------
# Mock tokenizer (same one from test_managed_server.py)
# ---------------------------------------------------------------------------


class MockTokenizer:
    # Character-level tokenizer: token id == ord(char); ids 1/2 are BOS/EOS.
    def __init__(self):
        self.eos_token_id = 2
        self.bos_token_id = 1

    def encode(self, text, add_special_tokens=True):
        tokens = [ord(c) for c in text]
        if add_special_tokens:
            tokens = [self.bos_token_id] + tokens
        return tokens

    def decode(self, tokens, skip_special_tokens=False):
        if skip_special_tokens:
            tokens = [
                t for t in tokens if t not in [self.bos_token_id, self.eos_token_id]
            ]
        # NOTE(review): `t > 31` drops ALL control characters, including
        # '\n' and '\t' — not just BOS/EOS — so encode/decode is lossy for
        # multi-line text. Confirm this is intended.
        return "".join([chr(t) if t > 31 else "" for t in tokens])

    def get_vocab(self):
        # Minimal vocab for the parser — hermes parser calls this
        return {chr(i): i for i in range(128)}

    def apply_chat_template(
        self, messages, tokenize=False, add_generation_prompt=True, tools=None
    ):
        result = ""
        if tools:
            result += 
f"{json.dumps(tools)}\n" + for msg in messages: + result += f"<{msg['role']}>{msg.get('content', '')}" + if add_generation_prompt: + result += "" + if tokenize: + return self.encode(result) + return result + + +@pytest.fixture +def translator(): + tok = MockTokenizer() + return ToolCallTranslator(tokenizer=tok, parser_name="hermes") + + +# --------------------------------------------------------------------------- +# Outbound: model output → OpenAI format +# --------------------------------------------------------------------------- + + +class TestParseModelOutput: + def test_single_tool_call(self, translator): + raw = ( + '{"name": "search", "arguments": {"query": "cats"}}' + ) + content, tool_calls, finish_reason = translator.parse_model_output( + raw, + tool_choice="auto", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + assert finish_reason == "tool_calls" + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "search" + assert json.loads(tool_calls[0].function.arguments) == {"query": "cats"} + # content is None or empty when full text is a tool call + assert content is None or content.strip() == "" + + def test_multiple_tool_calls(self, translator): + raw = ( + '{"name": "get_weather", "arguments": {"city": "SF"}}\n' + '{"name": "get_time", "arguments": {"tz": "PST"}}' + ) + tools = [ + {"type": "function", "function": {"name": "get_weather"}}, + {"type": "function", "function": {"name": "get_time"}}, + ] + content, tool_calls, finish_reason = translator.parse_model_output( + raw, tool_choice="auto", tools=tools + ) + + assert finish_reason == "tool_calls" + assert len(tool_calls) == 2 + names = {tc.function.name for tc in tool_calls} + assert names == {"get_weather", "get_time"} + + def test_no_tool_calls(self, translator): + raw = "The weather in SF is 72 degrees." 
+ content, tool_calls, finish_reason = translator.parse_model_output( + raw, + tool_choice="auto", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + assert finish_reason == "stop" + assert tool_calls is None + assert content == raw + + def test_content_before_tool_call(self, translator): + raw = 'Let me search for that.\n<tool_call>\n{"name": "search", "arguments": {"query": "cats"}}\n</tool_call>' + content, tool_calls, finish_reason = translator.parse_model_output( + raw, + tool_choice="auto", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + assert finish_reason == "tool_calls" + assert tool_calls is not None + assert len(tool_calls) == 1 + # Content before the tool call tag should be preserved + assert content is not None + assert "search for that" in content + + def test_tool_choice_none_skips_parsing(self, translator): + raw = ( + '<tool_call>\n{"name": "search", "arguments": {"query": "cats"}}\n</tool_call>' + ) + content, tool_calls, finish_reason = translator.parse_model_output( + raw, + tool_choice="none", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + assert finish_reason == "stop" + assert tool_calls is None + assert content == raw # Raw text returned as-is + + def test_no_tools_skips_parsing(self, translator): + raw = ( + '<tool_call>\n{"name": "search", "arguments": {"query": "cats"}}\n</tool_call>' + ) + content, tool_calls, finish_reason = translator.parse_model_output( + raw, tool_choice="auto", tools=None + ) + + assert finish_reason == "stop" + assert tool_calls is None + assert content == raw + + def test_malformed_json_graceful_fallback(self, translator): + raw = "not valid json at all" + content, tool_calls, finish_reason = translator.parse_model_output( + raw, + tool_choice="auto", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + # Parser should handle gracefully — either no tools or raw content + assert finish_reason == "stop" + assert tool_calls is None + + def test_unclosed_tool_call(self, translator): + raw = '<tool_call>\n{"name": 
"search", "arguments": {"query": "cats"}}' + content, tool_calls, finish_reason = translator.parse_model_output( + raw, + tool_choice="auto", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + # The hermes regex has a branch for unclosed tags + assert finish_reason == "tool_calls" + assert tool_calls is not None + assert len(tool_calls) == 1 + + def test_nested_json_arguments(self, translator): + args = { + "filter": { + "type": "date", + "range": {"start": "2024-01-01", "end": "2024-12-31"}, + } + } + raw = f'{{"name": "search", "arguments": {json.dumps(args)}}}' + content, tool_calls, finish_reason = translator.parse_model_output( + raw, + tool_choice="auto", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + assert finish_reason == "tool_calls" + assert json.loads(tool_calls[0].function.arguments) == args + + def test_tool_call_with_think_tags(self, translator): + raw = ( + "I should search for this information.\n" + '{"name": "search", "arguments": {"query": "cats"}}' + ) + content, tool_calls, finish_reason = translator.parse_model_output( + raw, + tool_choice="auto", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + assert finish_reason == "tool_calls" + assert tool_calls is not None + # Think content should be in the content field + if content: + assert "think" in content or "search for this" in content + + +# --------------------------------------------------------------------------- +# Lookup table +# --------------------------------------------------------------------------- + + +class TestLookupTable: + def test_parse_populates_lookup(self, translator): + raw = ( + '{"name": "search", "arguments": {"query": "cats"}}' + ) + _, tool_calls, _ = translator.parse_model_output( + raw, + tool_choice="auto", + tools=[{"type": "function", "function": {"name": "search"}}], + ) + + assert len(translator.call_id_to_raw_text) == 1 + tc_id = tool_calls[0].id + assert tc_id in 
translator.call_id_to_raw_text + assert translator.call_id_to_raw_text[tc_id] == raw + + def test_lookup_accumulates(self, translator): + tools = [{"type": "function", "function": {"name": "search"}}] + + raw1 = ( + '<tool_call>\n{"name": "search", "arguments": {"query": "cats"}}\n</tool_call>' + ) + _, tc1, _ = translator.parse_model_output(raw1, tool_choice="auto", tools=tools) + + raw2 = ( + '<tool_call>\n{"name": "search", "arguments": {"query": "dogs"}}\n</tool_call>' + ) + _, tc2, _ = translator.parse_model_output(raw2, tool_choice="auto", tools=tools) + + assert len(translator.call_id_to_raw_text) == 2 + assert tc1[0].id in translator.call_id_to_raw_text + assert tc2[0].id in translator.call_id_to_raw_text + + +# --------------------------------------------------------------------------- +# Inbound: OpenAI messages → raw text +# --------------------------------------------------------------------------- + + +class TestReconstructRawText: + def test_reconstruct_from_lookup(self, translator): + # First, parse to populate lookup + raw = ( + '<tool_call>\n{"name": "search", "arguments": {"query": "cats"}}\n</tool_call>' + ) + tools = [{"type": "function", "function": {"name": "search"}}] + _, tool_calls, _ = translator.parse_model_output( + raw, tool_choice="auto", tools=tools + ) + + # Now reconstruct + tc_dicts = [tc.model_dump() for tc in tool_calls] + reconstructed = translator.reconstruct_raw_text_from_tool_calls(tc_dicts) + + assert reconstructed == raw + + def test_reconstruct_fallback_without_lookup(self, translator): + # Reconstruct without having parsed first — uses fallback + tc_dicts = [ + { + "id": "fake-id-123", + "type": "function", + "function": {"name": "search", "arguments": '{"query": "cats"}'}, + } + ] + reconstructed = translator.reconstruct_raw_text_from_tool_calls(tc_dicts) + + assert "<tool_call>" in reconstructed + assert "search" in reconstructed + assert "cats" in reconstructed + + def test_reconstruct_empty_list(self, translator): + assert translator.reconstruct_raw_text_from_tool_calls([]) == "" + + def 
test_reconstruct_multiple_tool_calls(self, translator): + tc_dicts = [ + { + "id": "id-1", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city": "SF"}'}, + }, + { + "id": "id-2", + "type": "function", + "function": {"name": "get_time", "arguments": '{"tz": "PST"}'}, + }, + ] + reconstructed = translator.reconstruct_raw_text_from_tool_calls(tc_dicts) + + assert reconstructed.count("<tool_call>") == 2 + assert "get_weather" in reconstructed + assert "get_time" in reconstructed + + +# --------------------------------------------------------------------------- +# Message conversion +# --------------------------------------------------------------------------- + + +class TestConvertMessages: + def test_regular_messages_pass_through(self, translator): + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi there"}, + ] + result = translator.convert_messages_for_template(messages) + + assert result == messages + + def test_assistant_with_tool_calls_reconstructed(self, translator): + # Parse first to populate lookup + raw = ( + '<tool_call>\n{"name": "search", "arguments": {"query": "cats"}}\n</tool_call>' + ) + tools = [{"type": "function", "function": {"name": "search"}}] + _, tool_calls, _ = translator.parse_model_output( + raw, tool_choice="auto", tools=tools + ) + + messages = [ + {"role": "user", "content": "Search for cats"}, + { + "role": "assistant", + "content": None, + "tool_calls": [tc.model_dump() for tc in tool_calls], + }, + { + "role": "tool", + "tool_call_id": tool_calls[0].id, + "content": "Found 5 cats", + }, + ] + + result = translator.convert_messages_for_template(messages) + + # User message unchanged + assert result[0] == messages[0] + # Assistant message reconstructed to raw text + assert result[1]["role"] == "assistant" + assert "<tool_call>" in result[1]["content"] + assert "tool_calls" not in result[1] + # Tool message passed through + assert result[2] == messages[2] + + def 
test_assistant_with_content_and_tool_calls(self, translator): + messages = [ + { + "role": "assistant", + "content": "Let me search.", + "tool_calls": [ + { + "id": "fake-id", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "x"}'}, + } + ], + }, + ] + + result = translator.convert_messages_for_template(messages) + + assert result[0]["role"] == "assistant" + assert "Let me search." in result[0]["content"] + assert "<tool_call>" in result[0]["content"] + + def test_mixed_message_types(self, translator): + """Only tool_call assistant messages are reconstructed.""" + messages = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, # regular, no tool_calls + {"role": "user", "content": "Search cats"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "tc-1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "cats"}'}, + } + ], + }, + {"role": "tool", "tool_call_id": "tc-1", "content": "5 results"}, + {"role": "assistant", "content": "Found 5 cats!"}, # regular again + ] + + result = translator.convert_messages_for_template(messages) + + # Messages at indices 0, 1, 2, 4, 5 should be unchanged + assert result[0] == messages[0] + assert result[1] == messages[1] + assert result[2] == messages[2] + assert result[4] == messages[4] + assert result[5] == messages[5] + # Message at index 3 should be reconstructed + assert "<tool_call>" in result[3]["content"] + + +# --------------------------------------------------------------------------- +# Roundtrip +# --------------------------------------------------------------------------- + + +class TestRoundtrip: + def test_single_tool_call_roundtrip(self, translator): + raw = ( + '<tool_call>\n{"name": "search", "arguments": {"query": "cats"}}\n</tool_call>' + ) + tools = [{"type": "function", "function": {"name": "search"}}] + + # Parse + _, tool_calls, _ = translator.parse_model_output( + raw, tool_choice="auto", tools=tools + ) + # Reconstruct + tc_dicts = [tc.model_dump() for tc in 
tool_calls] + reconstructed = translator.reconstruct_raw_text_from_tool_calls(tc_dicts) + + assert reconstructed == raw + + def test_tool_call_empty_arguments(self, translator): + raw = '{"name": "list_all", "arguments": {}}' + tools = [{"type": "function", "function": {"name": "list_all"}}] + + _, tool_calls, _ = translator.parse_model_output( + raw, tool_choice="auto", tools=tools + ) + assert tool_calls is not None + assert json.loads(tool_calls[0].function.arguments) == {} + + +# --------------------------------------------------------------------------- +# Decode with tool awareness +# --------------------------------------------------------------------------- + + +class TestDecodeToolAwareness: + def test_decode_without_tools(self, translator): + tokens = [72, 101, 108, 108, 111] # "Hello" + text = translator.decode_with_tool_awareness(tokens, has_tools=False) + assert text == "Hello" + + def test_decode_with_tools_preserves_special(self, translator): + # With the mock tokenizer there are no "special" tokens to strip, + # but verify the flag is passed correctly + tokens = [72, 101, 108, 108, 111] + text = translator.decode_with_tool_awareness(tokens, has_tools=True) + assert text == "Hello" + + def test_decode_strips_bos_without_tools(self, translator): + tokens = [1, 72, 101, 108, 108, 111] # BOS + "Hello" + text = translator.decode_with_tool_awareness(tokens, has_tools=False) + assert text == "Hello" # BOS stripped + + def test_decode_keeps_bos_with_tools(self, translator): + tokens = [1, 72, 101, 108, 108, 111] # BOS + "Hello" + text = translator.decode_with_tool_awareness(tokens, has_tools=True) + # BOS (chr(1)) is not printable so mock tokenizer returns "" for it + # but the flag skip_special_tokens=False is passed + assert "Hello" in text diff --git a/pyproject.toml b/pyproject.toml index d291021b..64586f24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,12 +46,14 @@ rewardfns = [ "torch" ] example_trainer = [ - "atroposlib[rewardfns]", - 
"vllm", + "atroposlib[rewardfns, openai_endpoint]", "accelerate", "peft", "requests", ] +openai_endpoint = [ + "vllm", +] dev = [ "pytest", "pytest-asyncio",