Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
remove bloated endpoints
parent 672cdbaea8
commit 4f6f2faac8
5 changed files with 37 additions and 701 deletions
@@ -114,8 +114,20 @@ from vllm.usage.usage_lib import UsageContext # noqa: E402
from vllm.utils import random_uuid # noqa: E402
from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402

# Handle vLLM version differences - FlexibleArgumentParser was removed/renamed
try:
    from vllm.utils import FlexibleArgumentParser
except ImportError:
    # Use standard argparse for newer vLLM versions
    from argparse import ArgumentParser as FlexibleArgumentParser

from vllm.utils import FlexibleArgumentParser, set_ulimit # noqa: E402
# set_ulimit might not exist in all vLLM versions
try:
    from vllm.utils import set_ulimit
except ImportError:
    def set_ulimit() -> None:
        """No-op fallback for set_ulimit."""
        pass

from vllm.outputs import RequestOutput # noqa: F401, E402
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
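With these shims in place, the rest of the server can call FlexibleArgumentParser and set_ulimit unconditionally, whichever vLLM version is installed. A minimal sketch of that usage (the flag below is illustrative, not taken from this file):

    parser = FlexibleArgumentParser(description="streamlined vLLM server")
    parser.add_argument("--port", type=int, default=8000)  # hypothetical flag, for illustration only
    args = parser.parse_args()
    set_ulimit()  # real vllm.utils.set_ulimit when available, otherwise the no-op fallback defined above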
@@ -312,220 +324,6 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
    return JSONResponse(ret)


# =============================================================================
# OpenAI-Compatible Chat Completions Endpoint
# =============================================================================


@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request) -> Response:
    """
    OpenAI-compatible chat completions endpoint.

    This is a thin wrapper around our /generate endpoint that formats
    the request/response to match OpenAI's chat completions API.

    Used by atroposlib/GSM8k environment for rollout generation.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    import time as time_module

    request_dict = await request.json()

    # Extract parameters
    model = request_dict.get("model", "")
    messages = request_dict.get("messages", [])
    max_tokens = request_dict.get("max_tokens", 256)
    temperature = request_dict.get("temperature", 1.0)
    top_p = request_dict.get("top_p", 1.0)
    n = request_dict.get("n", 1)
    stop = request_dict.get("stop", None)
    presence_penalty = request_dict.get("presence_penalty", 0.0)
    frequency_penalty = request_dict.get("frequency_penalty", 0.0)

    # Convert messages to prompt using tokenizer's chat template
    try:
        prompt = engine.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Fallback: simple concatenation if no chat template
        prompt = ""
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            prompt += f"{role}: {content}\n"
        prompt += "assistant: "

    # Build sampling params (reusing our existing logic)
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        n=n,
        stop=stop,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
    )

    request_id = random_uuid()

    # Get active LoRA adapter if any
    lora_request = _get_lora_request()

    final_output = None
    async for request_output in engine.generate(
        prompt, sampling_params, request_id, lora_request=lora_request
    ):
        final_output = request_output

    if final_output is None:
        raise HTTPException(status_code=500, detail="Generation failed")

    # Build choices in OpenAI chat format
    choices = []
    for idx, output in enumerate(final_output.outputs):
        choices.append(
            {
                "index": idx,
                "message": {
                    "role": "assistant",
                    "content": output.text,
                },
                "finish_reason": output.finish_reason or "stop",
            }
        )

    # Build response
    prompt_tokens = len(final_output.prompt_token_ids)
    completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)

    response = {
        "id": f"chatcmpl-{random_uuid()}",
        "object": "chat.completion",
        "created": int(time_module.time()),
        "model": model,
        "choices": choices,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }

    return JSONResponse(response)


@app.post("/v1/completions")
async def openai_completions(request: Request) -> Response:
    """
    OpenAI-compatible text completions endpoint.

    This is the non-chat version of completions (raw text in, text out).
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    import time as time_module

    request_dict = await request.json()

    # Extract parameters
    model = request_dict.get("model", "")
    prompt = request_dict.get("prompt", "")
    max_tokens = request_dict.get("max_tokens", 256)
    temperature = request_dict.get("temperature", 1.0)
    top_p = request_dict.get("top_p", 1.0)
    n = request_dict.get("n", 1)
    stop = request_dict.get("stop", None)
    presence_penalty = request_dict.get("presence_penalty", 0.0)
    frequency_penalty = request_dict.get("frequency_penalty", 0.0)
    logprobs_requested = request_dict.get("logprobs", None)

    # Handle single prompt or list of prompts
    prompts = [prompt] if isinstance(prompt, str) else prompt

    # Build sampling params
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        n=n,
        stop=stop,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        logprobs=logprobs_requested,
    )

    # Generate for each prompt
    all_choices = []
    total_prompt_tokens = 0
    total_completion_tokens = 0

    # Get active LoRA adapter if any
    lora_request = _get_lora_request()

    for prompt_text in prompts:
        request_id = random_uuid()

        final_output = None
        async for request_output in engine.generate(
            prompt_text, sampling_params, request_id, lora_request=lora_request
        ):
            final_output = request_output

        if final_output is None:
            raise HTTPException(status_code=500, detail="Generation failed")

        # Count tokens
        total_prompt_tokens += len(final_output.prompt_token_ids)

        # Build choices
        for output in final_output.outputs:
            total_completion_tokens += len(output.token_ids)

            choice = {
                "text": output.text,
                "index": len(all_choices),
                "logprobs": None,
                "finish_reason": output.finish_reason or "stop",
            }

            # Add logprobs if requested
            if logprobs_requested is not None and output.logprobs:
                choice["logprobs"] = {
                    "tokens": [
                        engine.tokenizer.decode([tid]) for tid in output.token_ids
                    ],
                    "token_logprobs": [
                        list(lp.values())[0].logprob if lp else None
                        for lp in output.logprobs
                    ],
                    "top_logprobs": None,
                    "text_offset": [],
                }

            all_choices.append(choice)

    # Build response in OpenAI format
    response = {
        "id": f"cmpl-{random_uuid()}",
        "object": "text_completion",
        "created": int(time_module.time()),
        "model": model,
        "choices": all_choices,
        "usage": {
            "prompt_tokens": total_prompt_tokens,
            "completion_tokens": total_completion_tokens,
            "total_tokens": total_prompt_tokens + total_completion_tokens,
        },
    }

    return JSONResponse(response)


# =============================================================================
# Bridge Endpoints (Weight Synchronization)
# =============================================================================
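For reference, a minimal sketch of a request the removed /v1/chat/completions handler accepted, using only fields it read via request_dict.get; the host and port are assumptions, not part of this commit:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # hypothetical server address
        json={
            "model": "example-model",  # echoed back in the response
            "messages": [{"role": "user", "content": "What is 2 + 2?"}],
            "max_tokens": 256,
            "temperature": 1.0,
            "n": 1,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])  # OpenAI-style choice shape built above

After this commit, callers use the endpoints listed in the hunk below (with /generate as the primary generation route) instead.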
@@ -911,12 +709,16 @@ async def run_server(

    # Log available endpoints
    logger.info("=" * 60)
    logger.info("Streamlined vLLM Server - Training-Focused API")
    logger.info("Available endpoints:")
    logger.info(" POST /generate - Generate completions")
    logger.info(" GET /bridge/info - Bridge status")
    logger.info(" POST /bridge/pause - Pause generation")
    logger.info(" POST /bridge/resume - Resume generation")
    logger.info(" GET /lora/status - LoRA adapter status")
    logger.info(" POST /generate - Generate with logprobs (primary endpoint)")
    logger.info(" GET /health - Health check")
    logger.info(" GET /bridge/info - Bridge status")
    logger.info(" POST /bridge/pause - Pause generation")
    logger.info(" POST /bridge/resume - Resume generation")
    logger.info(" GET /lora/status - LoRA adapter status")
    logger.info(" POST /lora/load - Load LoRA adapter")
    logger.info(" POST /lora/unload - Unload LoRA adapter")
    logger.info("=" * 60)

    shutdown_task = await serve_http(
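A quick smoke test of the streamlined API surface logged above, probing only the GET endpoints; the base URL is an assumption and should be adjusted to wherever the server is actually bound:

    import requests

    BASE = "http://localhost:8000"  # hypothetical address
    for path in ("/health", "/bridge/info", "/lora/status"):
        r = requests.get(f"{BASE}{path}")
        print(path, r.status_code)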