Jai Suphavadeeprasit 2025-12-08 10:42:39 -05:00
parent 79842edba7
commit 2b240bbd2e


@@ -259,6 +259,305 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
return JSONResponse(ret)
# =============================================================================
# OpenAI-Compatible Completions Endpoint
# =============================================================================
@app.post("/v1/completions")
async def openai_completions(request: Request) -> Response:
"""
OpenAI-compatible completions endpoint.
This translates OpenAI API format to our internal format.
Request JSON fields (OpenAI format):
- model: str - Model name (ignored, uses loaded model)
- prompt: str or List[str] - The input text(s) to complete
- max_tokens: int - Maximum tokens to generate
- temperature: float - Sampling temperature
- top_p: float - Nucleus sampling threshold
- n: int - Number of completions per prompt
    - stream: bool - Whether to stream results (not implemented; ignored)
- logprobs: int - Number of logprobs to return
- echo: bool - Whether to echo the prompt
- stop: str or List[str] - Stop sequences
Returns OpenAI-compatible response format.
"""
request_dict = await request.json()
# Extract OpenAI-specific fields
prompt = request_dict.get("prompt", "")
model = request_dict.get("model", "")
max_tokens = request_dict.get("max_tokens", 16)
temperature = request_dict.get("temperature", 1.0)
top_p = request_dict.get("top_p", 1.0)
n = request_dict.get("n", 1)
    stream = request_dict.get("stream", False)  # parsed but ignored; streaming is not implemented
logprobs_count = request_dict.get("logprobs")
echo = request_dict.get("echo", False)
stop = request_dict.get("stop")
# Handle prompt as string or list
if isinstance(prompt, list):
# For simplicity, just use the first prompt
# Full implementation would handle batches
prompt = prompt[0] if prompt else ""
# Build sampling params
sampling_kwargs = {
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"n": n,
}
if logprobs_count is not None:
sampling_kwargs["logprobs"] = logprobs_count
if stop is not None:
if isinstance(stop, str):
stop = [stop]
sampling_kwargs["stop"] = stop
sampling_params = SamplingParams(**sampling_kwargs)
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
# Non-streaming response
final_output = None
try:
async for request_output in results_generator:
final_output = request_output
except asyncio.CancelledError:
return Response(status_code=499)
if final_output is None:
return JSONResponse(
{"error": {"message": "No output generated", "type": "server_error"}},
status_code=500,
)
# Build OpenAI-compatible response
choices = []
for i, output in enumerate(final_output.outputs):
text = output.text
if echo:
text = prompt + text
choice = {
"text": text,
"index": i,
"logprobs": None,
"finish_reason": output.finish_reason or "stop",
}
        # Add logprobs if requested. vLLM returns a list of dicts mapping
        # token IDs to Logprob objects (with .logprob and .decoded_token);
        # look up each sampled token by its ID instead of assuming it is
        # the first entry in its dict.
        if logprobs_count is not None and output.logprobs:
            sampled_ids = list(output.token_ids)
            choice["logprobs"] = {
                "tokens": [
                    (lp[tid].decoded_token or "") if lp and tid in lp else ""
                    for tid, lp in zip(sampled_ids, output.logprobs)
                ],
                "token_logprobs": [
                    lp[tid].logprob if lp and tid in lp else None
                    for tid, lp in zip(sampled_ids, output.logprobs)
                ],
                "top_logprobs": [
                    {(v.decoded_token or str(k)): v.logprob for k, v in lp.items()}
                    if lp
                    else {}
                    for lp in output.logprobs
                ],
                "text_offset": [],  # Not implemented
            }
choices.append(choice)
response = {
"id": f"cmpl-{request_id}",
"object": "text_completion",
"created": int(asyncio.get_event_loop().time()),
"model": model or "vllm-model",
"choices": choices,
"usage": {
"prompt_tokens": len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0,
"completion_tokens": sum(len(o.token_ids) for o in final_output.outputs),
"total_tokens": (len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0)
+ sum(len(o.token_ids) for o in final_output.outputs),
},
}
return JSONResponse(response)
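
# A minimal sketch of exercising the completions endpoint above, assuming the
# server is reachable at http://localhost:8000 (host/port are assumptions, not
# part of this diff). Any OpenAI-compatible client should work the same way:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/completions",
#       json={"prompt": "Once upon a time", "max_tokens": 32,
#             "temperature": 0.7, "logprobs": 5},
#   )
#   print(resp.json()["choices"][0]["text"])
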
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request) -> Response:
"""
OpenAI-compatible chat completions endpoint.
Request JSON fields:
- model: str - Model name (ignored, uses loaded model)
- messages: List[dict] - Chat messages with 'role' and 'content'
- max_tokens: int - Maximum tokens to generate
- temperature: float - Sampling temperature
- top_p: float - Nucleus sampling threshold
- n: int - Number of completions
    - stream: bool - Whether to stream results (not implemented; ignored)
- stop: str or List[str] - Stop sequences
Returns OpenAI-compatible chat completion response.
"""
request_dict = await request.json()
# Extract fields
messages = request_dict.get("messages", [])
model = request_dict.get("model", "")
max_tokens = request_dict.get("max_tokens", 512)
temperature = request_dict.get("temperature", 1.0)
top_p = request_dict.get("top_p", 1.0)
n = request_dict.get("n", 1)
    stream = request_dict.get("stream", False)  # parsed but ignored; streaming is not implemented
stop = request_dict.get("stop")
# Convert messages to prompt using chat template
assert engine is not None
# Try to use the tokenizer's chat template
try:
tokenizer = engine.tokenizer.tokenizer
if hasattr(tokenizer, "apply_chat_template"):
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
# Fallback: simple concatenation
prompt = ""
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
prompt += "<|im_start|>assistant\n"
except Exception:
# Simple fallback
prompt = "\n".join(
f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages
)
prompt += "\nassistant:"
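    # For example, with the plain-text fallback above,
    # messages=[{"role": "user", "content": "Hi"}] yields the prompt
    # "user: Hi\nassistant:".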
# Build sampling params
sampling_kwargs = {
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"n": n,
}
if stop is not None:
if isinstance(stop, str):
stop = [stop]
sampling_kwargs["stop"] = stop
sampling_params = SamplingParams(**sampling_kwargs)
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)
# Non-streaming response
final_output = None
try:
async for request_output in results_generator:
final_output = request_output
except asyncio.CancelledError:
return Response(status_code=499)
if final_output is None:
return JSONResponse(
{"error": {"message": "No output generated", "type": "server_error"}},
status_code=500,
)
# Build OpenAI-compatible chat response
choices = []
for i, output in enumerate(final_output.outputs):
choice = {
"index": i,
"message": {
"role": "assistant",
"content": output.text,
},
"finish_reason": output.finish_reason or "stop",
}
choices.append(choice)
prompt_tokens = len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0
completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)
response = {
"id": f"chatcmpl-{request_id}",
"object": "chat.completion",
"created": int(time.time()),
"model": model or "vllm-model",
"choices": choices,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
return JSONResponse(response)
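
# A minimal sketch of a chat request against the endpoint above, again
# assuming the server is reachable at http://localhost:8000 (an assumption,
# not part of this diff):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "messages": [{"role": "user", "content": "Say hello."}],
#           "max_tokens": 64,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
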
@app.get("/v1/models")
async def list_models() -> JSONResponse:
"""
List available models (OpenAI-compatible).
Returns the currently loaded model.
"""
assert engine is not None
model_name = str(engine.engine.model_config.model) if hasattr(engine, "engine") else "unknown"
return JSONResponse({
"object": "list",
"data": [
{
"id": model_name,
"object": "model",
"created": int(time.time()),
"owned_by": "vllm",
"permission": [],
"root": model_name,
"parent": None,
}
],
})
@app.get("/v1/models/{model_id}")
async def get_model(model_id: str) -> JSONResponse:
"""
Get model info (OpenAI-compatible).
"""
assert engine is not None
model_name = str(engine.engine.model_config.model) if hasattr(engine, "engine") else "unknown"
return JSONResponse({
"id": model_name,
"object": "model",
"created": int(time.time()),
"owned_by": "vllm",
"permission": [],
"root": model_name,
"parent": None,
})
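
# Both model endpoints can be exercised with plain GETs (localhost:8000 is an
# assumption):
#
#   curl http://localhost:8000/v1/models
#   curl http://localhost:8000/v1/models/any-id
#
# Note that get_model returns the loaded model regardless of model_id.
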
# =============================================================================
# Bridge Endpoints (for shared-weight training)
# =============================================================================