mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add sglang specific token level logprob handling and server manager/baseline logprob/token fn
This commit is contained in:
parent
4862e9972f
commit
c36ec29656
4 changed files with 512 additions and 37 deletions
|
|
@ -16,6 +16,7 @@ from atroposlib.envs.server_handling.server_baseline import (
|
|||
)
|
||||
from atroposlib.envs.server_handling.server_harness import ServerHarness
|
||||
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
|
||||
from atroposlib.envs.server_handling.sglang_server import SGLangServer
|
||||
|
||||
|
||||
class ServerManagerConfig(BaseModel):
|
||||
|
|
@ -54,6 +55,8 @@ class ServerManager:
|
|||
server_class = OpenAIServer
|
||||
elif configs.server_type == "trl":
|
||||
server_class = TrlVllmServer
|
||||
elif configs.server_type == "sglang":
|
||||
server_class = SGLangServer
|
||||
else:
|
||||
raise ValueError(f"Invalid server type: {configs.server_type}")
|
||||
else:
|
||||
|
|
@ -61,6 +64,8 @@ class ServerManager:
|
|||
server_class = OpenAIServer
|
||||
elif configs[0].server_type == "trl":
|
||||
server_class = TrlVllmServer
|
||||
elif configs[0].server_type == "sglang":
|
||||
server_class = SGLangServer
|
||||
else:
|
||||
raise ValueError(f"Invalid server type: {configs[0].server_type}")
|
||||
if testing:
|
||||
|
|
@ -241,6 +246,51 @@ class ServerManager:
|
|||
)
|
||||
return await self.servers[most_available_server].completion(**kwargs)
|
||||
|
||||
async def tokens_and_logprobs_completion(
    self, **kwargs
) -> tuple[list, list, list, list]:
    """
    Get tokens and logprobs from completion.
    Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
    """
    requested_n = kwargs.get("n", 1)
    if requested_n > self.max_n_completions:
        # Fan the request out into several smaller ones, each within the
        # per-call cap, and run them concurrently.
        pending = []
        remaining = requested_n
        while remaining > 0:
            batch_n = min(remaining, self.max_n_completions)
            # Safe to mutate: **kwargs is snapshotted when each coroutine
            # object is created on the next line.
            kwargs["n"] = batch_n
            pending.append(self.tokens_and_logprobs_completion(**kwargs))
            remaining -= batch_n
        chunks = await asyncio.gather(*pending)
        # The prompt tokens are identical across chunks; the per-sample
        # output lists are concatenated in request order.
        merged_tokens: list = []
        merged_logprobs: list = []
        merged_reasons: list = []
        for _, chunk_tokens, chunk_logprobs, chunk_reasons in chunks:
            merged_tokens.extend(chunk_tokens)
            merged_logprobs.extend(chunk_logprobs)
            merged_reasons.extend(chunk_reasons)
        return (chunks[0][0], merged_tokens, merged_logprobs, merged_reasons)

    # Route to the healthy server with the most free semaphore slots,
    # using the train or eval semaphore depending on the split.
    is_train = kwargs.get("split", "train") == "train"
    best_idx = 0
    best_slots = -1
    await self.wait_for_sem(is_train)
    for idx, candidate in enumerate(self.servers):
        if not candidate.server_healthy:
            continue
        # No await inside this loop, so reading the private _value is
        # race-free within the event loop.
        free_slots = (
            candidate.sem._value if is_train else candidate.eval_sem._value
        )
        if free_slots > best_slots:
            best_idx = idx
            best_slots = free_slots
    return await self.servers[best_idx].tokens_and_logprobs_completion(**kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
|
||||
most_available_server = 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue