completions

This commit is contained in:
Jai Suphavadeeprasit 2026-01-19 00:12:07 -05:00
parent f57ef091aa
commit b8120e8eae

View file

@@ -287,6 +287,206 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
return JSONResponse(ret)
# =============================================================================
# OpenAI-Compatible Chat Completions Endpoint
# =============================================================================
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request) -> Response:
    """
    OpenAI-compatible chat completions endpoint.

    Thin adapter over our /generate machinery: it accepts a payload shaped
    like OpenAI's chat completions API and replies in the same shape.
    Used by atroposlib/GSM8k environment for rollout generation.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    import time as time_module

    payload = await request.json()

    # OpenAI-style knobs, with OpenAI's documented defaults.
    model_name = payload.get("model", "")
    chat_messages = payload.get("messages", [])

    # Render the conversation into a single prompt string.
    try:
        prompt = engine.tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # No usable chat template — fall back to plain "role: content" lines
        # followed by an assistant cue.
        rendered = [
            f"{msg.get('role', 'user')}: {msg.get('content', '')}\n"
            for msg in chat_messages
        ]
        prompt = "".join(rendered) + "assistant: "

    # Sampling parameters (same fields and defaults as the /generate path).
    sampling_params = SamplingParams(
        max_tokens=payload.get("max_tokens", 256),
        temperature=payload.get("temperature", 1.0),
        top_p=payload.get("top_p", 1.0),
        n=payload.get("n", 1),
        stop=payload.get("stop", None),
        presence_penalty=payload.get("presence_penalty", 0.0),
        frequency_penalty=payload.get("frequency_penalty", 0.0),
    )

    # Drain the async generator; the last yielded item holds the final state.
    request_id = random_uuid()
    final_output = None
    async for partial in engine.generate(prompt, sampling_params, request_id):
        final_output = partial
    if final_output is None:
        raise HTTPException(status_code=500, detail="Generation failed")

    # One OpenAI "choice" per sampled completion.
    choices = [
        {
            "index": idx,
            "message": {
                "role": "assistant",
                "content": out.text,
            },
            "finish_reason": out.finish_reason or "stop",
        }
        for idx, out in enumerate(final_output.outputs)
    ]

    # Token accounting for the usage block.
    n_prompt = len(final_output.prompt_token_ids)
    n_completion = sum(len(out.token_ids) for out in final_output.outputs)

    return JSONResponse(
        {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time_module.time()),
            "model": model_name,
            "choices": choices,
            "usage": {
                "prompt_tokens": n_prompt,
                "completion_tokens": n_completion,
                "total_tokens": n_prompt + n_completion,
            },
        }
    )
@app.post("/v1/completions")
async def openai_completions(request: Request) -> Response:
    """
    OpenAI-compatible text completions endpoint.

    This is the non-chat version of completions (raw text in, text out).
    Accepts either a single prompt string or a list of prompts; choices
    from all prompts are flattened into one list, matching OpenAI's
    response shape. When ``logprobs`` is requested, the logprob reported
    for each position is that of the token actually sampled.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    import time as time_module

    request_dict = await request.json()

    # Extract parameters (OpenAI defaults).
    model = request_dict.get("model", "")
    prompt = request_dict.get("prompt", "")
    max_tokens = request_dict.get("max_tokens", 256)
    temperature = request_dict.get("temperature", 1.0)
    top_p = request_dict.get("top_p", 1.0)
    n = request_dict.get("n", 1)
    stop = request_dict.get("stop", None)
    presence_penalty = request_dict.get("presence_penalty", 0.0)
    frequency_penalty = request_dict.get("frequency_penalty", 0.0)
    logprobs_requested = request_dict.get("logprobs", None)

    # Handle single prompt or list of prompts.
    prompts = [prompt] if isinstance(prompt, str) else prompt

    # Build sampling params.
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        n=n,
        stop=stop,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        logprobs=logprobs_requested,
    )

    # Generate for each prompt, flattening all choices into one list.
    # NOTE(review): prompts are generated sequentially; concurrent dispatch
    # would be an optimization but is not needed for correctness.
    all_choices = []
    total_prompt_tokens = 0
    total_completion_tokens = 0
    for prompt_text in prompts:
        request_id = random_uuid()
        final_output = None
        async for request_output in engine.generate(prompt_text, sampling_params, request_id):
            final_output = request_output
        if final_output is None:
            raise HTTPException(status_code=500, detail="Generation failed")

        total_prompt_tokens += len(final_output.prompt_token_ids)

        for output in final_output.outputs:
            total_completion_tokens += len(output.token_ids)
            choice = {
                "text": output.text,
                "index": len(all_choices),
                "logprobs": None,
                "finish_reason": output.finish_reason or "stop",
            }
            # Add logprobs if requested.
            if logprobs_requested is not None and output.logprobs:
                # BUG FIX: the previous code used the FIRST entry of each
                # per-position logprob dict (the top-ranked candidate), which
                # is not necessarily the token that was actually sampled.
                # Look up each sampled token id in its own dict so that
                # tokens[i] and token_logprobs[i] describe the same token.
                choice["logprobs"] = {
                    "tokens": [
                        engine.tokenizer.decode([tid]) for tid in output.token_ids
                    ],
                    "token_logprobs": [
                        lp[tid].logprob if lp and tid in lp else None
                        for tid, lp in zip(output.token_ids, output.logprobs)
                    ],
                    "top_logprobs": None,
                    "text_offset": [],
                }
            all_choices.append(choice)

    # Build response in OpenAI format.
    response = {
        "id": f"cmpl-{random_uuid()}",
        "object": "text_completion",
        "created": int(time_module.time()),
        "model": model,
        "choices": all_choices,
        "usage": {
            "prompt_tokens": total_prompt_tokens,
            "completion_tokens": total_completion_tokens,
            "total_tokens": total_prompt_tokens + total_completion_tokens,
        },
    }
    return JSONResponse(response)
# =============================================================================
# Bridge Endpoints (Weight Synchronization)
# =============================================================================