[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-01-19 05:25:49 +00:00 committed by Jai Suphavadeeprasit
parent b8120e8eae
commit d4589e1107
3 changed files with 69 additions and 59 deletions

View file

@ -296,18 +296,19 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
async def openai_chat_completions(request: Request) -> Response:
"""
OpenAI-compatible chat completions endpoint.
This is a thin wrapper around our /generate endpoint that formats
the request/response to match OpenAI's chat completions API.
Used by atroposlib/GSM8k environment for rollout generation.
"""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
import time as time_module
request_dict = await request.json()
# Extract parameters
model = request_dict.get("model", "")
messages = request_dict.get("messages", [])
@ -318,13 +319,11 @@ async def openai_chat_completions(request: Request) -> Response:
stop = request_dict.get("stop", None)
presence_penalty = request_dict.get("presence_penalty", 0.0)
frequency_penalty = request_dict.get("frequency_penalty", 0.0)
# Convert messages to prompt using tokenizer's chat template
try:
prompt = engine.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
# Fallback: simple concatenation if no chat template
@ -334,7 +333,7 @@ async def openai_chat_completions(request: Request) -> Response:
content = msg.get("content", "")
prompt += f"{role}: {content}\n"
prompt += "assistant: "
# Build sampling params (reusing our existing logic)
sampling_params = SamplingParams(
max_tokens=max_tokens,
@ -345,32 +344,34 @@ async def openai_chat_completions(request: Request) -> Response:
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
request_id = random_uuid()
final_output = None
async for request_output in engine.generate(prompt, sampling_params, request_id):
final_output = request_output
if final_output is None:
raise HTTPException(status_code=500, detail="Generation failed")
# Build choices in OpenAI chat format
choices = []
for idx, output in enumerate(final_output.outputs):
choices.append({
"index": idx,
"message": {
"role": "assistant",
"content": output.text,
},
"finish_reason": output.finish_reason or "stop",
})
choices.append(
{
"index": idx,
"message": {
"role": "assistant",
"content": output.text,
},
"finish_reason": output.finish_reason or "stop",
}
)
# Build response
prompt_tokens = len(final_output.prompt_token_ids)
completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)
response = {
"id": f"chatcmpl-{random_uuid()}",
"object": "chat.completion",
@ -383,7 +384,7 @@ async def openai_chat_completions(request: Request) -> Response:
"total_tokens": prompt_tokens + completion_tokens,
},
}
return JSONResponse(response)
@ -391,15 +392,16 @@ async def openai_chat_completions(request: Request) -> Response:
async def openai_completions(request: Request) -> Response:
"""
OpenAI-compatible text completions endpoint.
This is the non-chat version of completions (raw text in, text out).
"""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
import time as time_module
request_dict = await request.json()
# Extract parameters
model = request_dict.get("model", "")
prompt = request_dict.get("prompt", "")
@ -411,10 +413,10 @@ async def openai_completions(request: Request) -> Response:
presence_penalty = request_dict.get("presence_penalty", 0.0)
frequency_penalty = request_dict.get("frequency_penalty", 0.0)
logprobs_requested = request_dict.get("logprobs", None)
# Handle single prompt or list of prompts
prompts = [prompt] if isinstance(prompt, str) else prompt
# Build sampling params
sampling_params = SamplingParams(
max_tokens=max_tokens,
@ -426,50 +428,54 @@ async def openai_completions(request: Request) -> Response:
frequency_penalty=frequency_penalty,
logprobs=logprobs_requested,
)
# Generate for each prompt
all_choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
for prompt_text in prompts:
request_id = random_uuid()
final_output = None
async for request_output in engine.generate(prompt_text, sampling_params, request_id):
async for request_output in engine.generate(
prompt_text, sampling_params, request_id
):
final_output = request_output
if final_output is None:
raise HTTPException(status_code=500, detail="Generation failed")
# Count tokens
total_prompt_tokens += len(final_output.prompt_token_ids)
# Build choices
for output in final_output.outputs:
total_completion_tokens += len(output.token_ids)
choice = {
"text": output.text,
"index": len(all_choices),
"logprobs": None,
"finish_reason": output.finish_reason or "stop",
}
# Add logprobs if requested
if logprobs_requested is not None and output.logprobs:
choice["logprobs"] = {
"tokens": [engine.tokenizer.decode([tid]) for tid in output.token_ids],
"tokens": [
engine.tokenizer.decode([tid]) for tid in output.token_ids
],
"token_logprobs": [
list(lp.values())[0].logprob if lp else None
list(lp.values())[0].logprob if lp else None
for lp in output.logprobs
],
"top_logprobs": None,
"text_offset": [],
}
all_choices.append(choice)
# Build response in OpenAI format
response = {
"id": f"cmpl-{random_uuid()}",
@ -483,7 +489,7 @@ async def openai_completions(request: Request) -> Response:
"total_tokens": total_prompt_tokens + total_completion_tokens,
},
}
return JSONResponse(response)