remove bloated endpoints

This commit is contained in:
Jai Suphavadeeprasit 2026-02-03 09:46:55 -05:00
parent 672cdbaea8
commit 4f6f2faac8
5 changed files with 37 additions and 701 deletions

View file

@ -114,8 +114,20 @@ from vllm.usage.usage_lib import UsageContext # noqa: E402
from vllm.utils import random_uuid # noqa: E402
from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402
# Handle vLLM version differences - FlexibleArgumentParser was removed/renamed
try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
# Use standard argparse for newer vLLM versions
from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, set_ulimit # noqa: E402
# set_ulimit might not exist in all vLLM versions
try:
from vllm.utils import set_ulimit
except ImportError:
def set_ulimit() -> None:
"""No-op fallback for set_ulimit."""
pass
from vllm.outputs import RequestOutput # noqa: F401, E402
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
@ -312,220 +324,6 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
return JSONResponse(ret)
# =============================================================================
# OpenAI-Compatible Chat Completions Endpoint
# =============================================================================
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request) -> Response:
"""
OpenAI-compatible chat completions endpoint.
This is a thin wrapper around our /generate endpoint that formats
the request/response to match OpenAI's chat completions API.
Used by atroposlib/GSM8k environment for rollout generation.
"""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
import time as time_module
request_dict = await request.json()
# Extract parameters
model = request_dict.get("model", "")
messages = request_dict.get("messages", [])
max_tokens = request_dict.get("max_tokens", 256)
temperature = request_dict.get("temperature", 1.0)
top_p = request_dict.get("top_p", 1.0)
n = request_dict.get("n", 1)
stop = request_dict.get("stop", None)
presence_penalty = request_dict.get("presence_penalty", 0.0)
frequency_penalty = request_dict.get("frequency_penalty", 0.0)
# Convert messages to prompt using tokenizer's chat template
try:
prompt = engine.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
# Fallback: simple concatenation if no chat template
prompt = ""
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
prompt += f"{role}: {content}\n"
prompt += "assistant: "
# Build sampling params (reusing our existing logic)
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
request_id = random_uuid()
# Get active LoRA adapter if any
lora_request = _get_lora_request()
final_output = None
async for request_output in engine.generate(
prompt, sampling_params, request_id, lora_request=lora_request
):
final_output = request_output
if final_output is None:
raise HTTPException(status_code=500, detail="Generation failed")
# Build choices in OpenAI chat format
choices = []
for idx, output in enumerate(final_output.outputs):
choices.append(
{
"index": idx,
"message": {
"role": "assistant",
"content": output.text,
},
"finish_reason": output.finish_reason or "stop",
}
)
# Build response
prompt_tokens = len(final_output.prompt_token_ids)
completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)
response = {
"id": f"chatcmpl-{random_uuid()}",
"object": "chat.completion",
"created": int(time_module.time()),
"model": model,
"choices": choices,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
return JSONResponse(response)
@app.post("/v1/completions")
async def openai_completions(request: Request) -> Response:
"""
OpenAI-compatible text completions endpoint.
This is the non-chat version of completions (raw text in, text out).
"""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
import time as time_module
request_dict = await request.json()
# Extract parameters
model = request_dict.get("model", "")
prompt = request_dict.get("prompt", "")
max_tokens = request_dict.get("max_tokens", 256)
temperature = request_dict.get("temperature", 1.0)
top_p = request_dict.get("top_p", 1.0)
n = request_dict.get("n", 1)
stop = request_dict.get("stop", None)
presence_penalty = request_dict.get("presence_penalty", 0.0)
frequency_penalty = request_dict.get("frequency_penalty", 0.0)
logprobs_requested = request_dict.get("logprobs", None)
# Handle single prompt or list of prompts
prompts = [prompt] if isinstance(prompt, str) else prompt
# Build sampling params
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logprobs=logprobs_requested,
)
# Generate for each prompt
all_choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
# Get active LoRA adapter if any
lora_request = _get_lora_request()
for prompt_text in prompts:
request_id = random_uuid()
final_output = None
async for request_output in engine.generate(
prompt_text, sampling_params, request_id, lora_request=lora_request
):
final_output = request_output
if final_output is None:
raise HTTPException(status_code=500, detail="Generation failed")
# Count tokens
total_prompt_tokens += len(final_output.prompt_token_ids)
# Build choices
for output in final_output.outputs:
total_completion_tokens += len(output.token_ids)
choice = {
"text": output.text,
"index": len(all_choices),
"logprobs": None,
"finish_reason": output.finish_reason or "stop",
}
# Add logprobs if requested
if logprobs_requested is not None and output.logprobs:
choice["logprobs"] = {
"tokens": [
engine.tokenizer.decode([tid]) for tid in output.token_ids
],
"token_logprobs": [
list(lp.values())[0].logprob if lp else None
for lp in output.logprobs
],
"top_logprobs": None,
"text_offset": [],
}
all_choices.append(choice)
# Build response in OpenAI format
response = {
"id": f"cmpl-{random_uuid()}",
"object": "text_completion",
"created": int(time_module.time()),
"model": model,
"choices": all_choices,
"usage": {
"prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens,
"total_tokens": total_prompt_tokens + total_completion_tokens,
},
}
return JSONResponse(response)
# =============================================================================
# Bridge Endpoints (Weight Synchronization)
# =============================================================================
@ -911,12 +709,16 @@ async def run_server(
# Log available endpoints
logger.info("=" * 60)
logger.info("Streamlined vLLM Server - Training-Focused API")
logger.info("Available endpoints:")
logger.info(" POST /generate - Generate completions")
logger.info(" GET /bridge/info - Bridge status")
logger.info(" POST /bridge/pause - Pause generation")
logger.info(" POST /bridge/resume - Resume generation")
logger.info(" GET /lora/status - LoRA adapter status")
logger.info(" POST /generate - Generate with logprobs (primary endpoint)")
logger.info(" GET /health - Health check")
logger.info(" GET /bridge/info - Bridge status")
logger.info(" POST /bridge/pause - Pause generation")
logger.info(" POST /bridge/resume - Resume generation")
logger.info(" GET /lora/status - LoRA adapter status")
logger.info(" POST /lora/load - Load LoRA adapter")
logger.info(" POST /lora/unload - Unload LoRA adapter")
logger.info("=" * 60)
shutdown_task = await serve_http(