diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index 2923098a..79325ea7 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -287,6 +287,235 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     return JSONResponse(ret)
 
 
+# =============================================================================
+# OpenAI-Compatible Chat Completions Endpoint
+# =============================================================================
+
+
+@app.post("/v1/chat/completions")
+async def openai_chat_completions(request: Request) -> Response:
+    """
+    OpenAI-compatible chat completions endpoint.
+
+    This is a thin wrapper around our /generate endpoint that formats
+    the request/response to match OpenAI's chat completions API.
+
+    Used by the atroposlib GSM8k environment for rollout generation.
+    """
+    if engine is None:
+        raise HTTPException(status_code=503, detail="Engine not initialized")
+
+    import time as time_module
+    request_dict = await request.json()
+
+    # Streaming is not supported by this thin wrapper; reject it explicitly
+    # rather than returning a non-streamed body the client cannot parse.
+    if request_dict.get("stream", False):
+        raise HTTPException(status_code=400, detail="Streaming is not supported")
+
+    # Extract parameters
+    model = request_dict.get("model", "")
+    messages = request_dict.get("messages", [])
+    max_tokens = request_dict.get("max_tokens", 256)
+    temperature = request_dict.get("temperature", 1.0)
+    top_p = request_dict.get("top_p", 1.0)
+    n = request_dict.get("n", 1)
+    stop = request_dict.get("stop", None)
+    presence_penalty = request_dict.get("presence_penalty", 0.0)
+    frequency_penalty = request_dict.get("frequency_penalty", 0.0)
+
+    # Convert messages to a prompt using the tokenizer's chat template
+    try:
+        prompt = engine.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+    except Exception:
+        # Fallback: simple concatenation if no chat template is available
+        prompt = ""
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            prompt += f"{role}: {content}\n"
+        prompt += "assistant: "
+
+    # Build sampling params (reusing our existing logic)
+    sampling_params = SamplingParams(
+        max_tokens=max_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        n=n,
+        stop=stop,
+        presence_penalty=presence_penalty,
+        frequency_penalty=frequency_penalty,
+    )
+
+    request_id = random_uuid()
+
+    final_output = None
+    async for request_output in engine.generate(prompt, sampling_params, request_id):
+        final_output = request_output
+
+    if final_output is None:
+        raise HTTPException(status_code=500, detail="Generation failed")
+
+    # Build choices in OpenAI chat format
+    choices = []
+    for idx, output in enumerate(final_output.outputs):
+        choices.append({
+            "index": idx,
+            "message": {
+                "role": "assistant",
+                "content": output.text,
+            },
+            "finish_reason": output.finish_reason or "stop",
+        })
+
+    # Build response
+    prompt_tokens = len(final_output.prompt_token_ids)
+    completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)
+
+    response = {
+        "id": f"chatcmpl-{random_uuid()}",
+        "object": "chat.completion",
+        "created": int(time_module.time()),
+        "model": model,
+        "choices": choices,
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        },
+    }
+
+    return JSONResponse(response)
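+
+
+# Illustrative request/response shape for the endpoint above. Values are
+# hypothetical; the id, model, content, and token counts are filled in at
+# runtime.
+#
+#   POST /v1/chat/completions
+#   {"model": "my-model", "messages": [{"role": "user", "content": "Hi"}]}
+#
+#   -> {"id": "chatcmpl-<uuid>", "object": "chat.completion", "created": ...,
+#       "model": "my-model",
+#       "choices": [{"index": 0,
+#                    "message": {"role": "assistant", "content": "..."},
+#                    "finish_reason": "stop"}],
+#       "usage": {"prompt_tokens": ..., "completion_tokens": ...,
+#                 "total_tokens": ...}}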
+ """ + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + import time as time_module + request_dict = await request.json() + + # Extract parameters + model = request_dict.get("model", "") + prompt = request_dict.get("prompt", "") + max_tokens = request_dict.get("max_tokens", 256) + temperature = request_dict.get("temperature", 1.0) + top_p = request_dict.get("top_p", 1.0) + n = request_dict.get("n", 1) + stop = request_dict.get("stop", None) + presence_penalty = request_dict.get("presence_penalty", 0.0) + frequency_penalty = request_dict.get("frequency_penalty", 0.0) + logprobs_requested = request_dict.get("logprobs", None) + + # Handle single prompt or list of prompts + prompts = [prompt] if isinstance(prompt, str) else prompt + + # Build sampling params + sampling_params = SamplingParams( + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + n=n, + stop=stop, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logprobs=logprobs_requested, + ) + + # Generate for each prompt + all_choices = [] + total_prompt_tokens = 0 + total_completion_tokens = 0 + + for prompt_text in prompts: + request_id = random_uuid() + + final_output = None + async for request_output in engine.generate(prompt_text, sampling_params, request_id): + final_output = request_output + + if final_output is None: + raise HTTPException(status_code=500, detail="Generation failed") + + # Count tokens + total_prompt_tokens += len(final_output.prompt_token_ids) + + # Build choices + for output in final_output.outputs: + total_completion_tokens += len(output.token_ids) + + choice = { + "text": output.text, + "index": len(all_choices), + "logprobs": None, + "finish_reason": output.finish_reason or "stop", + } + + # Add logprobs if requested + if logprobs_requested is not None and output.logprobs: + choice["logprobs"] = { + "tokens": [engine.tokenizer.decode([tid]) for tid in output.token_ids], + "token_logprobs": [ + list(lp.values())[0].logprob if lp else None + for lp in output.logprobs + ], + "top_logprobs": None, + "text_offset": [], + } + + all_choices.append(choice) + + # Build response in OpenAI format + response = { + "id": f"cmpl-{random_uuid()}", + "object": "text_completion", + "created": int(time_module.time()), + "model": model, + "choices": all_choices, + "usage": { + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_prompt_tokens + total_completion_tokens, + }, + } + + return JSONResponse(response) + + # ============================================================================= # Bridge Endpoints (Weight Synchronization) # =============================================================================