update openai/trl_vllm server with new fn

This commit is contained in:
Dakota 2025-10-16 13:45:55 -05:00
parent c36ec29656
commit 134cbc09d0
3 changed files with 24 additions and 9 deletions

View file

@ -13,6 +13,7 @@ from openai.types.chat.chat_completion import (
ChatCompletionMessage,
Choice,
)
from openai.types.completion import Completion, CompletionChoice
from transformers import AutoTokenizer
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
@ -81,7 +82,7 @@ class TrlVllmServer(APIServer):
)
return completions
async def _completion_wrapper(self, **kwargs) -> ChatCompletion:
async def _completion_wrapper(self, **kwargs) -> Completion:
"""
Wrapper for the completion using the trl's vLLM server.
"""
@ -102,25 +103,31 @@ class TrlVllmServer(APIServer):
},
) as response:
completions = await response.json()
completions = ChatCompletion(
completions = Completion(
id=str(uuid.uuid4()),
object="chat.completion",
object="text_completion",
created=int(time.time()),
model=self.config.model_name,
choices=[
Choice(
CompletionChoice(
finish_reason=(
"stop"
if self.tokenizer.eos_token_id in completion
else "length"
),
index=i,
message=ChatCompletionMessage(
content=self.tokenizer.decode(completion),
role="assistant",
),
text=self.tokenizer.decode(completion),
)
for i, completion in enumerate(completions["completion_ids"])
],
)
return completions
async def _tokens_and_logprobs_completion_wrapper(
    self, **kwargs
) -> tuple[list, list, list, list]:
    """
    Wrapper for the tokens and logprobs completion using the trl's vLLM server.

    This is a placeholder: the body unconditionally raises, so no request
    is sent and nothing is ever returned.

    Args:
        **kwargs: Accepted so the call signature is usable, but ignored.

    Returns:
        Annotated as a 4-tuple of lists; never actually returned because
        the method always raises.

    Raises:
        NotImplementedError: Always, when the coroutine is awaited.
    """
    raise NotImplementedError("Not implemented for trl's vLLM server yet.")