generate endpoint with logprobs

Jai Suphavadeeprasit 2025-12-08 10:58:14 -05:00
parent 2b240bbd2e
commit f5c847d39d


@@ -203,8 +203,28 @@ async def generate(request: Request) -> Response:
 @with_cancellation
 async def _generate(request_dict: dict, raw_request: Request) -> Response:
-    prompt = request_dict.pop("prompt")
+    prompt_input = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
+    # Handle both string prompts and the {"prompt_token_ids": [...]} format;
+    # the latter is used by atroposlib's VLLMServer.
+    if isinstance(prompt_input, dict) and "prompt_token_ids" in prompt_input:
+        # Token-IDs format from atroposlib
+        prompt_token_ids = prompt_input["prompt_token_ids"]
+        prompt = {"prompt_token_ids": prompt_token_ids}
+    else:
+        # Plain string prompt
+        prompt = prompt_input
+    # Handle the logprobs parameter: atroposlib sends logprobs=0, which means
+    # "return logprobs". vLLM uses None to mean "don't return logprobs" and
+    # an int for "return the top N logprobs per token".
+    if "logprobs" in request_dict:
+        logprobs_val = request_dict["logprobs"]
+        # logprobs=0 means return logprobs (just 1 per token);
+        # logprobs=None or absent means don't return logprobs.
+        if logprobs_val is not None:
+            request_dict["logprobs"] = max(1, logprobs_val)  # at least 1
     request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
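
For reference, the hunk above maps atroposlib's logprobs convention onto vLLM's. A minimal standalone sketch of that mapping; the helper name is hypothetical and not part of the commit:

    from typing import Optional

    # Hypothetical helper mirroring the normalization in the hunk above:
    # atroposlib sends logprobs=0 to mean "return logprobs"; vLLM's
    # SamplingParams expects None (off) or an int >= 1 (top-N per token).
    def normalize_logprobs(value: Optional[int]) -> Optional[int]:
        if value is None:
            return None           # vLLM: don't return logprobs
        return max(1, value)      # vLLM: at least top-1 per token

    assert normalize_logprobs(None) is None
    assert normalize_logprobs(0) == 1
    assert normalize_logprobs(5) == 5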
@@ -215,9 +235,9 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     # Streaming: yield results as they're generated
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
-            prompt = request_output.prompt
-            assert prompt is not None
-            text_outputs = [prompt + output.text for output in request_output.outputs]
+            prompt_text = request_output.prompt
+            assert prompt_text is not None
+            text_outputs = [prompt_text + output.text for output in request_output.outputs]
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\n").encode("utf-8")
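
A minimal client-side sketch for consuming this newline-delimited JSON stream; the host, port, and route are assumptions based on the surrounding server, not part of the commit:

    import json
    import requests  # assumes the requests package is installed

    with requests.post(
        "http://localhost:8000/generate",  # assumed address and route
        json={"prompt": "Hello", "stream": True, "max_tokens": 16},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if line:  # skip keep-alive blank lines
                chunk = json.loads(line)
                print(chunk["text"])  # prompt + generated text so far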
@@ -233,27 +253,33 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
             return Response(status_code=499)
 
     assert final_output is not None
-    prompt = final_output.prompt or engine.tokenizer.decode(
+    prompt_text = final_output.prompt or engine.tokenizer.decode(
         final_output.prompt_token_ids
     )
-    assert prompt is not None
+    assert prompt_text is not None
     text_outputs = [output.text for output in final_output.outputs]
     finish_reasons = [output.finish_reason for output in final_output.outputs]
-    ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}
+    ret = {"text": text_outputs, "prompt": prompt_text, "finish_reasons": finish_reasons}
     # Include logprobs if requested (useful for RL training).
     # Format matches what atroposlib's VLLMServer expects.
     if sampling_params.logprobs is not None:
-        output_logprobs = [
-            [
-                [{key: value.logprob for key, value in logprob.items()}]
-                for logprob in x.logprobs
-            ]
-            for x in final_output.outputs
-        ]
+        output_logprobs = []
+        for x in final_output.outputs:
+            if x.logprobs:
+                # Format: [[{token_id: logprob}, ...], ...] per output
+                seq_logprobs = [
+                    [{str(key): value.logprob for key, value in logprob.items()}]
+                    for logprob in x.logprobs
+                ]
+            else:
+                seq_logprobs = []
+            output_logprobs.append(seq_logprobs)
         prompt_token_ids = final_output.prompt_token_ids
-        output_token_ids = [x.token_ids for x in final_output.outputs]
+        output_token_ids = [list(x.token_ids) for x in final_output.outputs]
         ret["logprobs"] = output_logprobs
-        ret["prompt_token_ids"] = prompt_token_ids
+        ret["prompt_token_ids"] = list(prompt_token_ids) if prompt_token_ids else []
        ret["token_ids"] = output_token_ids
    return JSONResponse(ret)
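
Putting the pieces together, a hedged example of a non-streaming request that exercises both new paths (token-ID prompt and logprobs). The address is an assumption; the response keys match the handler above:

    import requests

    resp = requests.post(
        "http://localhost:8000/generate",  # assumed address and route
        json={
            "prompt": {"prompt_token_ids": [1, 2, 3]},  # atroposlib format
            "logprobs": 0,   # normalized to 1 by the handler above
            "max_tokens": 8,
        },
    )
    data = resp.json()
    # data["text"]             -> list[str], one per output sequence
    # data["prompt"]           -> decoded prompt text
    # data["finish_reasons"]   -> list of finish reasons
    # data["logprobs"]         -> per-output [[{token_id: logprob}], ...]
    # data["prompt_token_ids"] -> list[int]
    # data["token_ids"]        -> list[list[int]] of generated token IDs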