managed_Server pass through and centralize sem logic

This commit is contained in:
Jai Suphavadeeprasit 2026-03-05 15:46:33 -05:00
parent c85a3e5ee7
commit b91922082e
4 changed files with 208 additions and 26 deletions

View file

@ -260,7 +260,7 @@ class VLLMServer(APIServer):
return [], []
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
"""
Fetch normalized prompt logprobs from vLLM /generate with optional top-k.
@ -315,26 +315,19 @@ class VLLMServer(APIServer):
request_data["top_p"] = 1.0
request_data.setdefault("max_tokens", 1)
# Keep semaphore behavior consistent with other server calls.
split = request_data.pop("split", "train")
sem = self.sem if split == "train" else self.eval_sem
while not self.server_healthy:
await asyncio.sleep(1)
async with sem:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.config.base_url.replace('/v1', '')}/generate",
json=request_data,
headers=(
{"Authorization": f"Bearer {self.config.api_key}"}
if self.config.api_key
else {}
),
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
) as response:
response.raise_for_status()
results = await response.json()
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.config.base_url.replace('/v1', '')}/generate",
json=request_data,
headers=(
{"Authorization": f"Bearer {self.config.api_key}"}
if self.config.api_key
else {}
),
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
) as response:
response.raise_for_status()
results = await response.json()
raw_prompt_logprobs = results.get("prompt_logprobs")
if raw_prompt_logprobs is None: