Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
generate endpoint with logprobs
This commit is contained in:
parent 2b240bbd2e
commit f5c847d39d
1 changed file with 42 additions and 16 deletions
@@ -203,8 +203,28 @@ async def generate(request: Request) -> Response:
 @with_cancellation
 async def _generate(request_dict: dict, raw_request: Request) -> Response:
-    prompt = request_dict.pop("prompt")
+    prompt_input = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
 
+    # Handle both string prompts and {"prompt_token_ids": [...]} format
+    # The latter is used by atroposlib's VLLMServer
+    if isinstance(prompt_input, dict) and "prompt_token_ids" in prompt_input:
+        # Token IDs format from atroposlib
+        prompt_token_ids = prompt_input["prompt_token_ids"]
+        prompt = {"prompt_token_ids": prompt_token_ids}
+    else:
+        # String prompt
+        prompt = prompt_input
+
+    # Handle logprobs parameter - atroposlib sends logprobs=0 which means "return logprobs"
+    # vLLM uses None to mean "don't return logprobs" and an int for "return N top logprobs"
+    if "logprobs" in request_dict:
+        logprobs_val = request_dict["logprobs"]
+        # logprobs=0 means return logprobs (just 1 per token)
+        # logprobs=None or not present means don't return logprobs
+        if logprobs_val is not None:
+            request_dict["logprobs"] = max(1, logprobs_val)  # At least 1
+
     request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
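With this hunk, the generate endpoint accepts either a plain string prompt or atroposlib's {"prompt_token_ids": [...]} form, and normalizes logprobs=0 up to 1 before the dict is handed to SamplingParams. A minimal client sketch, assuming the server listens on localhost:8000 and the route is /generate (both are assumptions; the token IDs are purely illustrative):

    import requests

    URL = "http://localhost:8000/generate"  # assumed host/port and route

    # String-prompt form
    resp = requests.post(URL, json={
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "logprobs": 0,  # atroposlib convention: 0 still means "return logprobs"
    })

    # Token-ID form, as sent by atroposlib's VLLMServer
    resp = requests.post(URL, json={
        "prompt": {"prompt_token_ids": [101, 202, 303]},  # illustrative IDs
        "max_tokens": 16,
        "logprobs": 0,
    })
    print(resp.json()["text"])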
@@ -215,9 +235,9 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     # Streaming: yield results as they're generated
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
-            prompt = request_output.prompt
-            assert prompt is not None
-            text_outputs = [prompt + output.text for output in request_output.outputs]
+            prompt_text = request_output.prompt
+            assert prompt_text is not None
+            text_outputs = [prompt_text + output.text for output in request_output.outputs]
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\n").encode("utf-8")
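The streaming branch emits newline-delimited JSON, one {"text": [...]} object per yielded result. A sketch of a client consuming that stream, under the same assumed address as above:

    import json
    import requests

    with requests.post(
        "http://localhost:8000/generate",  # assumed host/port and route
        json={"prompt": "Hello", "stream": True, "max_tokens": 16},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if line:
                chunk = json.loads(line)
                print(chunk["text"])  # prompt_text + generated text per sequence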
@@ -233,27 +253,33 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
         return Response(status_code=499)
 
     assert final_output is not None
-    prompt = final_output.prompt or engine.tokenizer.decode(
+    prompt_text = final_output.prompt or engine.tokenizer.decode(
         final_output.prompt_token_ids
     )
-    assert prompt is not None
+    assert prompt_text is not None
     text_outputs = [output.text for output in final_output.outputs]
     finish_reasons = [output.finish_reason for output in final_output.outputs]
-    ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}
+    ret = {"text": text_outputs, "prompt": prompt_text, "finish_reasons": finish_reasons}
 
     # Include logprobs if requested (useful for RL training)
     # Format matches what atroposlib's VLLMServer expects
     if sampling_params.logprobs is not None:
-        output_logprobs = [
-            [
-                [{key: value.logprob for key, value in logprob.items()}]
-                for logprob in x.logprobs
-            ]
-            for x in final_output.outputs
-        ]
+        output_logprobs = []
+        for x in final_output.outputs:
+            if x.logprobs:
+                # Format: [[{token_id: logprob}, ...], ...] per output
+                seq_logprobs = [
+                    [{str(key): value.logprob for key, value in logprob.items()}]
+                    for logprob in x.logprobs
+                ]
+            else:
+                seq_logprobs = []
+            output_logprobs.append(seq_logprobs)
+
         prompt_token_ids = final_output.prompt_token_ids
-        output_token_ids = [x.token_ids for x in final_output.outputs]
+        output_token_ids = [list(x.token_ids) for x in final_output.outputs]
         ret["logprobs"] = output_logprobs
-        ret["prompt_token_ids"] = prompt_token_ids
+        ret["prompt_token_ids"] = list(prompt_token_ids) if prompt_token_ids else []
         ret["token_ids"] = output_token_ids
 
     return JSONResponse(ret)
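After this hunk, the non-streaming response carries "text", "prompt", and "finish_reasons", plus "logprobs", "prompt_token_ids", and "token_ids" when logprobs were requested. A sketch of unpacking one response (field names mirror the keys written into ret above; per the diff, each per-token entry is a single-element list holding a {str(token_id): logprob} dict):

    data = resp.json()
    for seq_idx, finish in enumerate(data["finish_reasons"]):
        token_ids = data["token_ids"][seq_idx]   # generated token IDs for this sequence
        steps = data["logprobs"][seq_idx]        # one [{token_id: logprob}] entry per token
        for tid, step in zip(token_ids, steps):
            # step[0] maps stringified token IDs to logprobs; the sampled
            # token's own logprob should be among the entries
            print(seq_idx, finish, tid, step[0].get(str(tid)))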