mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
update openai/trl_vllm server with new fn
This commit is contained in:
parent
c36ec29656
commit
134cbc09d0
3 changed files with 24 additions and 9 deletions
|
|
@ -13,6 +13,7 @@ from openai.types.chat.chat_completion import (
|
|||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.completion import Completion, CompletionChoice
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
|
||||
|
|
@ -81,7 +82,7 @@ class TrlVllmServer(APIServer):
|
|||
)
|
||||
return completions
|
||||
|
||||
async def _completion_wrapper(self, **kwargs) -> ChatCompletion:
|
||||
async def _completion_wrapper(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Wrapper for the completion using the trl's vLLM server.
|
||||
"""
|
||||
|
|
@ -102,25 +103,31 @@ class TrlVllmServer(APIServer):
|
|||
},
|
||||
) as response:
|
||||
completions = await response.json()
|
||||
completions = ChatCompletion(
|
||||
completions = Completion(
|
||||
id=str(uuid.uuid4()),
|
||||
object="chat.completion",
|
||||
object="text_completion",
|
||||
created=int(time.time()),
|
||||
model=self.config.model_name,
|
||||
choices=[
|
||||
Choice(
|
||||
CompletionChoice(
|
||||
finish_reason=(
|
||||
"stop"
|
||||
if self.tokenizer.eos_token_id in completion
|
||||
else "length"
|
||||
),
|
||||
index=i,
|
||||
message=ChatCompletionMessage(
|
||||
content=self.tokenizer.decode(completion),
|
||||
role="assistant",
|
||||
),
|
||||
text=self.tokenizer.decode(completion),
|
||||
)
|
||||
for i, completion in enumerate(completions["completion_ids"])
|
||||
],
|
||||
)
|
||||
return completions
|
||||
|
||||
|
||||
async def _tokens_and_logprobs_completion_wrapper(
|
||||
self, **kwargs
|
||||
) -> tuple[list, list, list, list]:
|
||||
"""
|
||||
Wrapper for the tokens and logprobs completion using the openai client.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented for trl's vLLM server yet.")
|
||||
Loading…
Add table
Add a link
Reference in a new issue