diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index ce77ed59..5db8e97d 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -203,8 +203,28 @@ async def generate(request: Request) -> Response:
 
 @with_cancellation
 async def _generate(request_dict: dict, raw_request: Request) -> Response:
-    prompt = request_dict.pop("prompt")
+    prompt_input = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
+
+    # Handle both string prompts and {"prompt_token_ids": [...]} format
+    # The latter is used by atroposlib's VLLMServer
+    if isinstance(prompt_input, dict) and "prompt_token_ids" in prompt_input:
+        # Token IDs format from atroposlib
+        prompt_token_ids = prompt_input["prompt_token_ids"]
+        prompt = {"prompt_token_ids": prompt_token_ids}
+    else:
+        # String prompt
+        prompt = prompt_input
+
+    # Handle logprobs parameter - atroposlib sends logprobs=0 which means "return logprobs"
+    # vLLM uses None to mean "don't return logprobs" and an int for "return N top logprobs"
+    if "logprobs" in request_dict:
+        logprobs_val = request_dict["logprobs"]
+        # logprobs=0 means return logprobs (just 1 per token)
+        # logprobs=None or not present means don't return logprobs
+        if logprobs_val is not None:
+            request_dict["logprobs"] = max(1, logprobs_val)  # At least 1
+
     request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
@@ -215,9 +235,9 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     # Streaming: yield results as they're generated
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
-            prompt = request_output.prompt
-            assert prompt is not None
-            text_outputs = [prompt + output.text for output in request_output.outputs]
+            prompt_text = request_output.prompt
+            assert prompt_text is not None
+            text_outputs = [prompt_text + output.text for output in request_output.outputs]
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\n").encode("utf-8")
 
@@ -233,27 +253,33 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
         return Response(status_code=499)
 
     assert final_output is not None
-    prompt = final_output.prompt or engine.tokenizer.decode(
+    prompt_text = final_output.prompt or engine.tokenizer.decode(
         final_output.prompt_token_ids
     )
-    assert prompt is not None
+    assert prompt_text is not None
     text_outputs = [output.text for output in final_output.outputs]
     finish_reasons = [output.finish_reason for output in final_output.outputs]
-    ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}
+    ret = {"text": text_outputs, "prompt": prompt_text, "finish_reasons": finish_reasons}
 
     # Include logprobs if requested (useful for RL training)
+    # Format matches what atroposlib's VLLMServer expects
     if sampling_params.logprobs is not None:
-        output_logprobs = [
-            [
-                [{key: value.logprob for key, value in logprob.items()}]
-                for logprob in x.logprobs
-            ]
-            for x in final_output.outputs
-        ]
+        output_logprobs = []
+        for x in final_output.outputs:
+            if x.logprobs:
+                # Format: [[{token_id: logprob}, ...], ...] per output
+                seq_logprobs = [
+                    [{str(key): value.logprob for key, value in logprob.items()}]
+                    for logprob in x.logprobs
+                ]
+            else:
+                seq_logprobs = []
+            output_logprobs.append(seq_logprobs)
+
         prompt_token_ids = final_output.prompt_token_ids
-        output_token_ids = [x.token_ids for x in final_output.outputs]
+        output_token_ids = [list(x.token_ids) for x in final_output.outputs]
         ret["logprobs"] = output_logprobs
-        ret["prompt_token_ids"] = prompt_token_ids
+        ret["prompt_token_ids"] = list(prompt_token_ids) if prompt_token_ids else []
         ret["token_ids"] = output_token_ids
 
     return JSONResponse(ret)
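To make the accepted request shapes concrete, here is a client-side sketch exercising both prompt formats and the logprobs=0 convention. It is hypothetical and not part of this change: the /generate route, host, and port are assumptions (the route is not shown in this hunk), and the token IDs are placeholders.

import requests  # hypothetical client sketch, not part of this server

URL = "http://localhost:8000/generate"  # assumed host, port, and route

# Plain string prompt, no logprobs requested.
resp = requests.post(URL, json={"prompt": "Hello, world", "max_tokens": 16})
print(resp.json()["text"])

# atroposlib-style token-ID prompt with logprobs=0, which the handler
# normalizes to logprobs=1 before constructing SamplingParams.
resp = requests.post(
    URL,
    json={
        "prompt": {"prompt_token_ids": [101, 2023, 2003]},  # placeholder IDs
        "max_tokens": 16,
        "temperature": 0.7,
        "logprobs": 0,
    },
)
body = resp.json()
# Non-streaming responses carry text, prompt, and finish_reasons, plus
# logprobs, prompt_token_ids, and token_ids when logprobs were requested.
print(body["finish_reasons"], body["token_ids"])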
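On the consumer side, a minimal sketch of unpacking the nested logprobs payload into the logprob of each sampled token. extract_sampled_logprobs is an illustrative helper, not part of this server; it assumes vLLM's behavior of always including the sampled token in each per-step logprobs dict.

def extract_sampled_logprobs(body: dict) -> list[list[float]]:
    # Unpack the [[{token_id: logprob}, ...], ...] payload built above
    # into the logprob of each sampled token, per output sequence.
    per_output = []
    for seq_logprobs, seq_token_ids in zip(body["logprobs"], body["token_ids"]):
        token_logprobs = []
        for step, token_id in zip(seq_logprobs, seq_token_ids):
            # Each step is a one-element list holding {token_id: logprob};
            # keys are strings because the server stringifies them for JSON.
            candidates = step[0]
            token_logprobs.append(candidates[str(token_id)])
        per_output.append(token_logprobs)
    return per_output

With logprobs normalized to at least 1, each candidates dict should hold the top-1 token and, if different, the sampled token, so the lookup by str(token_id) should succeed for every generated position.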