Update OpenAI and TRL-vLLM servers with the new tokens-and-logprobs wrapper function

This commit is contained in:
Dakota 2025-10-16 13:45:55 -05:00
parent c36ec29656
commit 134cbc09d0
3 changed files with 24 additions and 9 deletions

View file

@ -134,6 +134,14 @@ class OpenAIServer(APIServer):
completions.choices.extend(c.choices)
return completions
async def _tokens_and_logprobs_completion_wrapper(
self, **kwargs
) -> tuple[list, list, list, list]:
"""
Wrapper for the tokens and logprobs completion using the openai client.
"""
raise NotImplementedError("Tokens and logprobs not supported by base OpenAI API, use specific API servers.")
def resolve_openai_configs(
default_server_configs,

View file

@ -173,7 +173,7 @@ class SGLangServer(APIServer):
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.config.base_url}/generate",
f"{self.config.base_url.replace('/v1', '')}/generate",
json=request_data,
headers={"Authorization": f"Bearer {self.config.api_key}"} if self.config.api_key else {},
timeout=aiohttp.ClientTimeout(total=self.config.timeout)

View file

@ -13,6 +13,7 @@ from openai.types.chat.chat_completion import (
ChatCompletionMessage,
Choice,
)
from openai.types.completion import Completion, CompletionChoice
from transformers import AutoTokenizer
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
@ -81,7 +82,7 @@ class TrlVllmServer(APIServer):
)
return completions
async def _completion_wrapper(self, **kwargs) -> ChatCompletion:
async def _completion_wrapper(self, **kwargs) -> Completion:
"""
Wrapper for the completion using the trl's vLLM server.
"""
@ -102,25 +103,31 @@ class TrlVllmServer(APIServer):
},
) as response:
completions = await response.json()
completions = ChatCompletion(
completions = Completion(
id=str(uuid.uuid4()),
object="chat.completion",
object="text_completion",
created=int(time.time()),
model=self.config.model_name,
choices=[
Choice(
CompletionChoice(
finish_reason=(
"stop"
if self.tokenizer.eos_token_id in completion
else "length"
),
index=i,
message=ChatCompletionMessage(
content=self.tokenizer.decode(completion),
role="assistant",
),
text=self.tokenizer.decode(completion),
)
for i, completion in enumerate(completions["completion_ids"])
],
)
return completions
async def _tokens_and_logprobs_completion_wrapper(
self, **kwargs
) -> tuple[list, list, list, list]:
"""
Wrapper for the tokens and logprobs completion using the openai client.
"""
raise NotImplementedError("Not implemented for trl's vLLM server yet.")