Merge remote-tracking branch 'origin/add-logprob-server-manager-fn' into add-logprob-server-manager-fn

Dakota 2025-10-16 13:46:03 -05:00
commit d240dbb3b7
3 changed files with 59 additions and 34 deletions


@@ -6,10 +6,10 @@ import openai
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.completion import Completion
 from pydantic_cli import FailedExecutionException
+from transformers import AutoTokenizer

 from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
 from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
-from transformers import AutoTokenizer


 class SGLangServer(APIServer):
@@ -171,12 +171,17 @@ class SGLangServer(APIServer):
         # Make async request to SGLang /generate endpoint
         import aiohttp

         async with aiohttp.ClientSession() as session:
             async with session.post(
                 f"{self.config.base_url.replace('/v1', '')}/generate",
                 json=request_data,
-                headers={"Authorization": f"Bearer {self.config.api_key}"} if self.config.api_key else {},
-                timeout=aiohttp.ClientTimeout(total=self.config.timeout)
+                headers=(
+                    {"Authorization": f"Bearer {self.config.api_key}"}
+                    if self.config.api_key
+                    else {}
+                ),
+                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
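Aside: a minimal sketch of the request_data payload this POST sends to SGLang's native /generate endpoint. The field names follow SGLang's native generation API as I read it (an assumption, since request_data is built outside this hunk); the prompt and sampling values are illustrative only.

    # Hedged sketch: plausible shape of request_data for SGLang /generate.
    # "return_logprob" is what makes meta_info carry output_token_logprobs.
    request_data = {
        "text": "The capital of France is",  # prompt (illustrative)
        "sampling_params": {
            "temperature": 1.0,
            "top_p": 1.0,
            "max_new_tokens": 32,
            "n": 4,  # number of completions per prompt
        },
        "return_logprob": True,  # ask the server to return token logprobs
    }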
@@ -194,8 +199,12 @@ class SGLangServer(APIServer):
             # Get output logprobs - extract just the logprob values
             output_token_logprobs = meta_info.get("output_token_logprobs", [])
-            logprobs = [item[0] for item in output_token_logprobs]  # Extract logprob from (logprob, token_id, text) tuples
-            output_ids = [item[1] for item in output_token_logprobs]  # Extract token ID from (logprob, token_id, text) tuples
+            logprobs = [
+                item[0] for item in output_token_logprobs
+            ]  # Extract logprob from (logprob, token_id, text) tuples
+            output_ids = [
+                item[1] for item in output_token_logprobs
+            ]  # Extract token ID from (logprob, token_id, text) tuples

             # Get finish reason
             finish_reason = meta_info.get("finish_reason", None)
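Aside: per the inline comments above, each entry of output_token_logprobs is a (logprob, token_id, text) triple. A small self-contained sketch of what the two comprehensions unpack, with made-up values:

    # Illustrative data only; real values come from the /generate response.
    output_token_logprobs = [
        (-0.12, 6342, " Paris"),  # (logprob, token_id, decoded text)
        (-1.87, 13, "."),
    ]
    logprobs = [item[0] for item in output_token_logprobs]  # [-0.12, -1.87]
    output_ids = [item[1] for item in output_token_logprobs]  # [6342, 13]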
@@ -204,7 +213,12 @@ class SGLangServer(APIServer):
             output_logprobs_list.append(logprobs)
             finish_reasons_list.append(finish_reason)

-        return prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons_list
+        return (
+            prompt_tokens,
+            output_tokens_list,
+            output_logprobs_list,
+            finish_reasons_list,
+        )

     def resolve_openai_configs(
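Aside: the shapes of the returned 4-tuple, as implied by the surrounding code (one shared prompt, then per-completion lists). The type alias below is an assumption for illustration, not the file's actual signature:

    from typing import List, Optional, Tuple

    # Assumed shapes, with n = number of completions requested:
    TokensAndLogprobs = Tuple[
        List[int],            # prompt_tokens: token IDs of the prompt
        List[List[int]],      # output_tokens_list: n lists of sampled token IDs
        List[List[float]],    # output_logprobs_list: parallel per-token logprobs
        List[Optional[str]],  # finish_reasons_list: e.g. "stop" or "length"
    ]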
@@ -301,13 +315,15 @@ if __name__ == "__main__":
         # Test the tokens_and_logprobs_completion method
         print("Testing tokens_and_logprobs_completion...")
         try:
-            prompt_tokens, output_tokens, output_logprobs, finish_reasons = await server.tokens_and_logprobs_completion(
-                prompt="The capital of France is",
-                n=4,
-                max_tokens=32,
-                temperature=1.0,
-                top_p=1.0,
-                stop=["User:", "Human:", "Assistant:", "</answer>"],
-            )
+            prompt_tokens, output_tokens, output_logprobs, finish_reasons = (
+                await server.tokens_and_logprobs_completion(
+                    prompt="The capital of France is",
+                    n=4,
+                    max_tokens=32,
+                    temperature=1.0,
+                    top_p=1.0,
+                    stop=["User:", "Human:", "Assistant:", "</answer>"],
+                )
+            )

             print(f"\nResults:")
@@ -317,12 +333,15 @@ if __name__ == "__main__":
             print(f"Finish reasons: {finish_reasons}")
             print(f"\nNumber of completions: {len(output_tokens)}")
             print(f"Output length: {[len(tokens) for tokens in output_tokens]}")
-            responses = '\n\n'.join([server.tokenizer.decode(tokens) for tokens in output_tokens])
+            responses = "\n\n".join(
+                [server.tokenizer.decode(tokens) for tokens in output_tokens]
+            )
             print(f"Responses:\n-{responses}")
         except Exception as e:
             print(f"Error: {e}")
             import traceback
+
             traceback.print_exc()

     # Run the test