diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index d5e9b57d..ce77ed59 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -259,6 +259,305 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     return JSONResponse(ret)
 
 
+# =============================================================================
+# OpenAI-Compatible Completions Endpoint
+# =============================================================================
+
+
+@app.post("/v1/completions")
+async def openai_completions(request: Request) -> Response:
+    """
+    OpenAI-compatible completions endpoint.
+
+    This endpoint translates the OpenAI request format into the server's
+    internal generation call.
+
+    Request JSON fields (OpenAI format):
+    - model: str - Model name (ignored, uses loaded model)
+    - prompt: str or List[str] - The input text(s) to complete
+    - max_tokens: int - Maximum tokens to generate
+    - temperature: float - Sampling temperature
+    - top_p: float - Nucleus sampling threshold
+    - n: int - Number of completions per prompt
+    - stream: bool - Whether to stream results (not supported; ignored)
+    - logprobs: int - Number of logprobs to return
+    - echo: bool - Whether to echo the prompt
+    - stop: str or List[str] - Stop sequences
+
+    Returns a response in the OpenAI completions format.
+    """
+    request_dict = await request.json()
+
+    # Extract OpenAI-specific fields
+    prompt = request_dict.get("prompt", "")
+    model = request_dict.get("model", "")
+    max_tokens = request_dict.get("max_tokens", 16)
+    temperature = request_dict.get("temperature", 1.0)
+    top_p = request_dict.get("top_p", 1.0)
+    n = request_dict.get("n", 1)
+    stream = request_dict.get("stream", False)
+    logprobs_count = request_dict.get("logprobs")
+    echo = request_dict.get("echo", False)
+    stop = request_dict.get("stop")
+
+    # Handle prompt as string or list
+    if isinstance(prompt, list):
+        # For simplicity, just use the first prompt.
+        # A full implementation would handle batched prompts.
+        prompt = prompt[0] if prompt else ""
+
+    # Build sampling params
+    sampling_kwargs = {
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "n": n,
+    }
+
+    if logprobs_count is not None:
+        sampling_kwargs["logprobs"] = logprobs_count
+
+    if stop is not None:
+        if isinstance(stop, str):
+            stop = [stop]
+        sampling_kwargs["stop"] = stop
+
+    sampling_params = SamplingParams(**sampling_kwargs)
+    sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
+    request_id = random_uuid()
+
+    assert engine is not None
+    results_generator = engine.generate(prompt, sampling_params, request_id)
+
+    # Non-streaming response
+    final_output = None
+    try:
+        async for request_output in results_generator:
+            final_output = request_output
+    except asyncio.CancelledError:
+        return Response(status_code=499)
+
+    if final_output is None:
+        return JSONResponse(
+            {"error": {"message": "No output generated", "type": "server_error"}},
+            status_code=500,
+        )
+
+    # Build OpenAI-compatible response
+    choices = []
+    for i, output in enumerate(final_output.outputs):
+        text = output.text
+        if echo:
+            text = prompt + text
+
+        choice = {
+            "text": text,
+            "index": i,
+            "logprobs": None,
+            "finish_reason": output.finish_reason or "stop",
+        }
+
+        # Add logprobs if requested. output.logprobs is a list of
+        # {token_id: Logprob} dicts; report decoded token strings rather
+        # than raw token ids.
+        if logprobs_count is not None and output.logprobs:
+            choice["logprobs"] = {
+                "tokens": [
+                    (list(lp.values())[0].decoded_token or "") if lp else ""
+                    for lp in output.logprobs
+                ],
+                "token_logprobs": [
+                    list(lp.values())[0].logprob if lp else None
+                    for lp in output.logprobs
+                ],
+                "top_logprobs": [
+                    {v.decoded_token: v.logprob for v in lp.values()} if lp else {}
+                    for lp in output.logprobs
+                ],
+                "text_offset": [],  # Not implemented
+            }
+
+        choices.append(choice)
+
+    prompt_tokens = len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0
+    completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)
+
+    response = {
+        "id": f"cmpl-{request_id}",
+        "object": "text_completion",
+        "created": int(time.time()),
+        "model": model or "vllm-model",
+        "choices": choices,
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        },
+    }
+
+    return JSONResponse(response)
+
+
+@app.post("/v1/chat/completions")
+async def openai_chat_completions(request: Request) -> Response:
+    """
+    OpenAI-compatible chat completions endpoint.
+
+    Request JSON fields:
+    - model: str - Model name (ignored, uses loaded model)
+    - messages: List[dict] - Chat messages with 'role' and 'content'
+    - max_tokens: int - Maximum tokens to generate
+    - temperature: float - Sampling temperature
+    - top_p: float - Nucleus sampling threshold
+    - n: int - Number of completions
+    - stream: bool - Whether to stream results (not supported; ignored)
+    - stop: str or List[str] - Stop sequences
+
+    Returns an OpenAI-compatible chat completion response.
+    """
+    request_dict = await request.json()
+
+    # Extract fields
+    messages = request_dict.get("messages", [])
+    model = request_dict.get("model", "")
+    max_tokens = request_dict.get("max_tokens", 512)
+    temperature = request_dict.get("temperature", 1.0)
+    top_p = request_dict.get("top_p", 1.0)
+    n = request_dict.get("n", 1)
+    stream = request_dict.get("stream", False)
+    stop = request_dict.get("stop")
+
+    # Convert messages to a prompt using the chat template
+    assert engine is not None
+
+    # Try to use the tokenizer's chat template
+    try:
+        tokenizer = engine.tokenizer.tokenizer
+        if hasattr(tokenizer, "apply_chat_template"):
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        else:
+            # Fallback: simple ChatML-style concatenation
+            prompt = ""
+            for msg in messages:
+                role = msg.get("role", "user")
+                content = msg.get("content", "")
+                prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+            prompt += "<|im_start|>assistant\n"
+    except Exception:
+        # Simple fallback
+        prompt = "\n".join(
+            f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages
+        )
+        prompt += "\nassistant:"
+
+    # Build sampling params
+    sampling_kwargs = {
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "n": n,
+    }
+
+    if stop is not None:
+        if isinstance(stop, str):
+            stop = [stop]
+        sampling_kwargs["stop"] = stop
+
+    sampling_params = SamplingParams(**sampling_kwargs)
+    sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
+    request_id = random_uuid()
+
+    results_generator = engine.generate(prompt, sampling_params, request_id)
+
+    # Non-streaming response
+    final_output = None
+    try:
+        async for request_output in results_generator:
+            final_output = request_output
+    except asyncio.CancelledError:
+        return Response(status_code=499)
+
+    if final_output is None:
+        return JSONResponse(
+            {"error": {"message": "No output generated", "type": "server_error"}},
+            status_code=500,
+        )
+
+    # Build OpenAI-compatible chat response
+    choices = []
+    for i, output in enumerate(final_output.outputs):
+        choice = {
+            "index": i,
+            "message": {
+                "role": "assistant",
+                "content": output.text,
+            },
+            "finish_reason": output.finish_reason or "stop",
+        }
+        choices.append(choice)
+
+    prompt_tokens = len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0
+    completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)
+
+    response = {
+        "id": f"chatcmpl-{request_id}",
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": model or "vllm-model",
+        "choices": choices,
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        },
+    }
+
+    return JSONResponse(response)
+
+
+@app.get("/v1/models")
+async def list_models() -> JSONResponse:
+    """
+    List available models (OpenAI-compatible).
+
+    Returns the currently loaded model.
+    """
+    assert engine is not None
+
+    model_name = str(engine.engine.model_config.model) if hasattr(engine, "engine") else "unknown"
+
+    return JSONResponse({
+        "object": "list",
+        "data": [
+            {
+                "id": model_name,
+                "object": "model",
+                "created": int(time.time()),
+                "owned_by": "vllm",
+                "permission": [],
+                "root": model_name,
+                "parent": None,
+            }
+        ],
+    })
+
+
+@app.get("/v1/models/{model_id}")
+async def get_model(model_id: str) -> JSONResponse:
+    """
+    Get model info (OpenAI-compatible).
+
+    The requested model_id is ignored; info for the loaded model is returned.
+    """
+    assert engine is not None
+
+    model_name = str(engine.engine.model_config.model) if hasattr(engine, "engine") else "unknown"
+
+    return JSONResponse({
+        "id": model_name,
+        "object": "model",
+        "created": int(time.time()),
+        "owned_by": "vllm",
+        "permission": [],
+        "root": model_name,
+        "parent": None,
+    })
+
+
 # =============================================================================
 # Bridge Endpoints (for shared-weight training)
 # =============================================================================
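
A quick usage sketch for the endpoints added above, using plain HTTP. The base URL is an assumption; point it at wherever vllm_api_server.py is actually listening, and note the request/response fields are exactly those documented in the docstrings.

    import requests

    BASE_URL = "http://localhost:8000"  # assumed host/port; adjust for your deployment

    # List the loaded model via the OpenAI-compatible /v1/models endpoint.
    print(requests.get(f"{BASE_URL}/v1/models").json())

    # Request a plain completion via /v1/completions.
    resp = requests.post(
        f"{BASE_URL}/v1/completions",
        json={"prompt": "The capital of France is", "max_tokens": 16, "temperature": 0.0},
    )
    body = resp.json()
    print(body["choices"][0]["text"], body["usage"])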
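
Because the response shapes mirror the OpenAI API, the official openai Python client (v1.x) should also work against /v1/chat/completions for non-streaming requests. This is a sketch under that assumption; the base_url and api_key placeholder are illustrative, and the model string is ignored by the server per its docstring.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed host/port

    completion = client.chat.completions.create(
        model="vllm-model",  # ignored by the server; any string works
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        max_tokens=32,
    )
    print(completion.choices[0].message.content)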