mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00

commit 2b240bbd2e ("changes"), parent 79842edba7
1 changed file with 299 additions and 0 deletions

@@ -259,6 +259,305 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
    return JSONResponse(ret)


# =============================================================================
# OpenAI-Compatible Completions Endpoint
# =============================================================================


@app.post("/v1/completions")
async def openai_completions(request: Request) -> Response:
    """
    OpenAI-compatible completions endpoint.

    Translates the OpenAI completions API format to this server's internal
    format.

    Request JSON fields (OpenAI format):
    - model: str - Model name (ignored; the loaded model is used)
    - prompt: str or List[str] - The input text(s) to complete
    - max_tokens: int - Maximum tokens to generate
    - temperature: float - Sampling temperature
    - top_p: float - Nucleus sampling threshold
    - n: int - Number of completions per prompt
    - stream: bool - Whether to stream results (accepted but ignored;
      streaming is not implemented)
    - logprobs: int - Number of top logprobs to return per token
    - echo: bool - Whether to echo the prompt in the returned text
    - stop: str or List[str] - Stop sequences

    Returns an OpenAI-compatible response body.
    """
    request_dict = await request.json()

    # Extract OpenAI-specific fields
    prompt = request_dict.get("prompt", "")
    model = request_dict.get("model", "")
    max_tokens = request_dict.get("max_tokens", 16)
    temperature = request_dict.get("temperature", 1.0)
    top_p = request_dict.get("top_p", 1.0)
    n = request_dict.get("n", 1)
    # Streaming is not implemented; the flag is accepted and ignored.
    stream = request_dict.get("stream", False)  # noqa: F841
    logprobs_count = request_dict.get("logprobs")
    echo = request_dict.get("echo", False)
    stop = request_dict.get("stop")

    # Handle prompt as string or list
    if isinstance(prompt, list):
        # For simplicity, just use the first prompt.
        # A full implementation would handle batches.
        prompt = prompt[0] if prompt else ""

    # Build sampling params
    sampling_kwargs = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "n": n,
    }

    if logprobs_count is not None:
        sampling_kwargs["logprobs"] = logprobs_count

    if stop is not None:
        if isinstance(stop, str):
            stop = [stop]
        sampling_kwargs["stop"] = stop

    sampling_params = SamplingParams(**sampling_kwargs)
    sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Non-streaming response: drain the generator and keep the final output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        # Client disconnected before generation finished.
        return Response(status_code=499)
    if final_output is None:
        return JSONResponse(
            {"error": {"message": "No output generated", "type": "server_error"}},
            status_code=500,
        )

    # Build OpenAI-compatible response
    choices = []
    for i, output in enumerate(final_output.outputs):
        text = output.text
        if echo:
            text = prompt + text

        choice = {
            "text": text,
            "index": i,
            "logprobs": None,
            "finish_reason": output.finish_reason or "stop",
        }

        # Add logprobs if requested. Each entry of output.logprobs maps
        # token id -> Logprob (.logprob, .decoded_token); report the decoded
        # token strings so the payload matches the OpenAI schema.
        if logprobs_count is not None and output.logprobs:
            choice["logprobs"] = {
                "tokens": [
                    (list(lp.values())[0].decoded_token or "") if lp else ""
                    for lp in output.logprobs
                ],
                "token_logprobs": [
                    list(lp.values())[0].logprob if lp else None
                    for lp in output.logprobs
                ],
                "top_logprobs": [
                    {(v.decoded_token or ""): v.logprob for v in lp.values()}
                    if lp
                    else {}
                    for lp in output.logprobs
                ],
                "text_offset": [],  # Not implemented
            }

        choices.append(choice)

    prompt_tokens = len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0
    completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)

    response = {
        "id": f"cmpl-{request_id}",
        "object": "text_completion",
        "created": int(time.time()),  # Unix timestamp per the OpenAI schema
        "model": model or "vllm-model",
        "choices": choices,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }

    return JSONResponse(response)
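
# Example request against /v1/completions (illustrative sketch, not part of
# the server itself). The host/port are assumptions; substitute wherever this
# server is actually bound.
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:8000/v1/completions",
#         json={
#             "prompt": "The capital of France is",
#             "max_tokens": 8,
#             "temperature": 0.0,
#             "logprobs": 5,
#         },
#     )
#     data = resp.json()
#     print(data["choices"][0]["text"])
#     print(data["usage"]["total_tokens"])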


@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request) -> Response:
    """
    OpenAI-compatible chat completions endpoint.

    Request JSON fields:
    - model: str - Model name (ignored; the loaded model is used)
    - messages: List[dict] - Chat messages with 'role' and 'content'
    - max_tokens: int - Maximum tokens to generate
    - temperature: float - Sampling temperature
    - top_p: float - Nucleus sampling threshold
    - n: int - Number of completions
    - stream: bool - Whether to stream results (accepted but ignored;
      streaming is not implemented)
    - stop: str or List[str] - Stop sequences

    Returns an OpenAI-compatible chat completion response.
    """
    request_dict = await request.json()

    # Extract fields
    messages = request_dict.get("messages", [])
    model = request_dict.get("model", "")
    max_tokens = request_dict.get("max_tokens", 512)
    temperature = request_dict.get("temperature", 1.0)
    top_p = request_dict.get("top_p", 1.0)
    n = request_dict.get("n", 1)
    # Streaming is not implemented; the flag is accepted and ignored.
    stream = request_dict.get("stream", False)  # noqa: F841
    stop = request_dict.get("stop")

    # Convert messages to a prompt string
    assert engine is not None

    # Try to use the tokenizer's chat template
    try:
        # engine.tokenizer is expected to wrap the underlying HF tokenizer;
        # the except below covers engines with a different attribute layout.
        tokenizer = engine.tokenizer.tokenizer
        if hasattr(tokenizer, "apply_chat_template"):
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            # Fallback: simple ChatML-style concatenation
            prompt = ""
            for msg in messages:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
            prompt += "<|im_start|>assistant\n"
    except Exception:
        # Last-resort fallback: plain "role: content" lines
        prompt = "\n".join(
            f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages
        )
        prompt += "\nassistant:"
    # Build sampling params
    sampling_kwargs = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "n": n,
    }

    if stop is not None:
        if isinstance(stop, str):
            stop = [stop]
        sampling_kwargs["stop"] = stop

    sampling_params = SamplingParams(**sampling_kwargs)
    sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
    request_id = random_uuid()

    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Non-streaming response
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    if final_output is None:
        return JSONResponse(
            {"error": {"message": "No output generated", "type": "server_error"}},
            status_code=500,
        )

    # Build OpenAI-compatible chat response
    choices = []
    for i, output in enumerate(final_output.outputs):
        choice = {
            "index": i,
            "message": {
                "role": "assistant",
                "content": output.text,
            },
            "finish_reason": output.finish_reason or "stop",
        }
        choices.append(choice)

    prompt_tokens = len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0
    completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)

    response = {
        "id": f"chatcmpl-{request_id}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model or "vllm-model",
        "choices": choices,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }

    return JSONResponse(response)
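
# Example request against /v1/chat/completions (illustrative sketch, not part
# of the server itself), using the official `openai` client pointed at this
# server. The base_url is an assumption, and the api_key is a placeholder the
# client requires but this server never checks.
#
#     from openai import OpenAI
#
#     client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
#     resp = client.chat.completions.create(
#         model="vllm-model",
#         messages=[{"role": "user", "content": "Say hello in one word."}],
#         max_tokens=16,
#     )
#     print(resp.choices[0].message.content)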


@app.get("/v1/models")
async def list_models() -> JSONResponse:
    """
    List available models (OpenAI-compatible).

    Returns the currently loaded model.
    """
    assert engine is not None

    model_name = (
        str(engine.engine.model_config.model) if hasattr(engine, "engine") else "unknown"
    )

    return JSONResponse({
        "object": "list",
        "data": [
            {
                "id": model_name,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "vllm",
                "permission": [],
                "root": model_name,
                "parent": None,
            }
        ],
    })
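
# Example (illustrative sketch): discovering the loaded model id, assuming
# the server is reachable on localhost:8000.
#
#     import requests
#
#     models = requests.get("http://localhost:8000/v1/models").json()
#     print([m["id"] for m in models["data"]])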


@app.get("/v1/models/{model_id}")
async def get_model(model_id: str) -> JSONResponse:
    """
    Get model info (OpenAI-compatible).

    Returns info for the currently loaded model regardless of ``model_id``.
    """
    assert engine is not None

    model_name = (
        str(engine.engine.model_config.model) if hasattr(engine, "engine") else "unknown"
    )

    return JSONResponse({
        "id": model_name,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "vllm",
        "permission": [],
        "root": model_name,
        "parent": None,
    })


# =============================================================================
# Bridge Endpoints (for shared-weight training)
# =============================================================================