mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
b8120e8eae
commit
d4589e1107
3 changed files with 69 additions and 59 deletions
|
|
@ -296,18 +296,19 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
|
|||
async def openai_chat_completions(request: Request) -> Response:
|
||||
"""
|
||||
OpenAI-compatible chat completions endpoint.
|
||||
|
||||
|
||||
This is a thin wrapper around our /generate endpoint that formats
|
||||
the request/response to match OpenAI's chat completions API.
|
||||
|
||||
|
||||
Used by atroposlib/GSM8k environment for rollout generation.
|
||||
"""
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
|
||||
|
||||
import time as time_module
|
||||
|
||||
request_dict = await request.json()
|
||||
|
||||
|
||||
# Extract parameters
|
||||
model = request_dict.get("model", "")
|
||||
messages = request_dict.get("messages", [])
|
||||
|
|
@ -318,13 +319,11 @@ async def openai_chat_completions(request: Request) -> Response:
|
|||
stop = request_dict.get("stop", None)
|
||||
presence_penalty = request_dict.get("presence_penalty", 0.0)
|
||||
frequency_penalty = request_dict.get("frequency_penalty", 0.0)
|
||||
|
||||
|
||||
# Convert messages to prompt using tokenizer's chat template
|
||||
try:
|
||||
prompt = engine.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
# Fallback: simple concatenation if no chat template
|
||||
|
|
@ -334,7 +333,7 @@ async def openai_chat_completions(request: Request) -> Response:
|
|||
content = msg.get("content", "")
|
||||
prompt += f"{role}: {content}\n"
|
||||
prompt += "assistant: "
|
||||
|
||||
|
||||
# Build sampling params (reusing our existing logic)
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_tokens,
|
||||
|
|
@ -345,32 +344,34 @@ async def openai_chat_completions(request: Request) -> Response:
|
|||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
)
|
||||
|
||||
|
||||
request_id = random_uuid()
|
||||
|
||||
|
||||
final_output = None
|
||||
async for request_output in engine.generate(prompt, sampling_params, request_id):
|
||||
final_output = request_output
|
||||
|
||||
|
||||
if final_output is None:
|
||||
raise HTTPException(status_code=500, detail="Generation failed")
|
||||
|
||||
|
||||
# Build choices in OpenAI chat format
|
||||
choices = []
|
||||
for idx, output in enumerate(final_output.outputs):
|
||||
choices.append({
|
||||
"index": idx,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": output.text,
|
||||
},
|
||||
"finish_reason": output.finish_reason or "stop",
|
||||
})
|
||||
|
||||
choices.append(
|
||||
{
|
||||
"index": idx,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": output.text,
|
||||
},
|
||||
"finish_reason": output.finish_reason or "stop",
|
||||
}
|
||||
)
|
||||
|
||||
# Build response
|
||||
prompt_tokens = len(final_output.prompt_token_ids)
|
||||
completion_tokens = sum(len(o.token_ids) for o in final_output.outputs)
|
||||
|
||||
|
||||
response = {
|
||||
"id": f"chatcmpl-{random_uuid()}",
|
||||
"object": "chat.completion",
|
||||
|
|
@ -383,7 +384,7 @@ async def openai_chat_completions(request: Request) -> Response:
|
|||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
|
|
@ -391,15 +392,16 @@ async def openai_chat_completions(request: Request) -> Response:
|
|||
async def openai_completions(request: Request) -> Response:
|
||||
"""
|
||||
OpenAI-compatible text completions endpoint.
|
||||
|
||||
|
||||
This is the non-chat version of completions (raw text in, text out).
|
||||
"""
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
|
||||
|
||||
import time as time_module
|
||||
|
||||
request_dict = await request.json()
|
||||
|
||||
|
||||
# Extract parameters
|
||||
model = request_dict.get("model", "")
|
||||
prompt = request_dict.get("prompt", "")
|
||||
|
|
@ -411,10 +413,10 @@ async def openai_completions(request: Request) -> Response:
|
|||
presence_penalty = request_dict.get("presence_penalty", 0.0)
|
||||
frequency_penalty = request_dict.get("frequency_penalty", 0.0)
|
||||
logprobs_requested = request_dict.get("logprobs", None)
|
||||
|
||||
|
||||
# Handle single prompt or list of prompts
|
||||
prompts = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
|
||||
# Build sampling params
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_tokens,
|
||||
|
|
@ -426,50 +428,54 @@ async def openai_completions(request: Request) -> Response:
|
|||
frequency_penalty=frequency_penalty,
|
||||
logprobs=logprobs_requested,
|
||||
)
|
||||
|
||||
|
||||
# Generate for each prompt
|
||||
all_choices = []
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
|
||||
|
||||
for prompt_text in prompts:
|
||||
request_id = random_uuid()
|
||||
|
||||
|
||||
final_output = None
|
||||
async for request_output in engine.generate(prompt_text, sampling_params, request_id):
|
||||
async for request_output in engine.generate(
|
||||
prompt_text, sampling_params, request_id
|
||||
):
|
||||
final_output = request_output
|
||||
|
||||
|
||||
if final_output is None:
|
||||
raise HTTPException(status_code=500, detail="Generation failed")
|
||||
|
||||
|
||||
# Count tokens
|
||||
total_prompt_tokens += len(final_output.prompt_token_ids)
|
||||
|
||||
|
||||
# Build choices
|
||||
for output in final_output.outputs:
|
||||
total_completion_tokens += len(output.token_ids)
|
||||
|
||||
|
||||
choice = {
|
||||
"text": output.text,
|
||||
"index": len(all_choices),
|
||||
"logprobs": None,
|
||||
"finish_reason": output.finish_reason or "stop",
|
||||
}
|
||||
|
||||
|
||||
# Add logprobs if requested
|
||||
if logprobs_requested is not None and output.logprobs:
|
||||
choice["logprobs"] = {
|
||||
"tokens": [engine.tokenizer.decode([tid]) for tid in output.token_ids],
|
||||
"tokens": [
|
||||
engine.tokenizer.decode([tid]) for tid in output.token_ids
|
||||
],
|
||||
"token_logprobs": [
|
||||
list(lp.values())[0].logprob if lp else None
|
||||
list(lp.values())[0].logprob if lp else None
|
||||
for lp in output.logprobs
|
||||
],
|
||||
"top_logprobs": None,
|
||||
"text_offset": [],
|
||||
}
|
||||
|
||||
|
||||
all_choices.append(choice)
|
||||
|
||||
|
||||
# Build response in OpenAI format
|
||||
response = {
|
||||
"id": f"cmpl-{random_uuid()}",
|
||||
|
|
@ -483,7 +489,7 @@ async def openai_completions(request: Request) -> Response:
|
|||
"total_tokens": total_prompt_tokens + total_completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue