mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge remote-tracking branch 'origin/add-logprob-server-manager-fn' into add-logprob-server-manager-fn
This commit is contained in:
commit
d240dbb3b7
3 changed files with 59 additions and 34 deletions
|
|
@ -6,10 +6,10 @@ import openai
|
|||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
from pydantic_cli import FailedExecutionException
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class SGLangServer(APIServer):
|
||||
|
|
@ -171,12 +171,17 @@ class SGLangServer(APIServer):
|
|||
|
||||
# Make async request to SGLang /generate endpoint
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.config.base_url.replace('/v1', '')}/generate",
|
||||
json=request_data,
|
||||
headers={"Authorization": f"Bearer {self.config.api_key}"} if self.config.api_key else {},
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
|
||||
headers=(
|
||||
{"Authorization": f"Bearer {self.config.api_key}"}
|
||||
if self.config.api_key
|
||||
else {}
|
||||
),
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
results = await response.json()
|
||||
|
|
@ -194,8 +199,12 @@ class SGLangServer(APIServer):
|
|||
|
||||
# Get output logprobs - extract just the logprob values
|
||||
output_token_logprobs = meta_info.get("output_token_logprobs", [])
|
||||
logprobs = [item[0] for item in output_token_logprobs] # Extract logprob from (logprob, token_id, text) tuples
|
||||
output_ids = [item[1] for item in output_token_logprobs] # Extract token ID from (logprob, token_id, text) tuples
|
||||
logprobs = [
|
||||
item[0] for item in output_token_logprobs
|
||||
] # Extract logprob from (logprob, token_id, text) tuples
|
||||
output_ids = [
|
||||
item[1] for item in output_token_logprobs
|
||||
] # Extract token ID from (logprob, token_id, text) tuples
|
||||
|
||||
# Get finish reason
|
||||
finish_reason = meta_info.get("finish_reason", None)
|
||||
|
|
@ -204,7 +213,12 @@ class SGLangServer(APIServer):
|
|||
output_logprobs_list.append(logprobs)
|
||||
finish_reasons_list.append(finish_reason)
|
||||
|
||||
return prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons_list
|
||||
return (
|
||||
prompt_tokens,
|
||||
output_tokens_list,
|
||||
output_logprobs_list,
|
||||
finish_reasons_list,
|
||||
)
|
||||
|
||||
|
||||
def resolve_openai_configs(
|
||||
|
|
@ -301,13 +315,15 @@ if __name__ == "__main__":
|
|||
# Test the tokens_and_logprobs_completion method
|
||||
print("Testing tokens_and_logprobs_completion...")
|
||||
try:
|
||||
prompt_tokens, output_tokens, output_logprobs, finish_reasons = await server.tokens_and_logprobs_completion(
|
||||
prompt="The capital of France is",
|
||||
n=4,
|
||||
max_tokens=32,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
stop=["User:", "Human:", "Assistant:", "</answer>"],
|
||||
prompt_tokens, output_tokens, output_logprobs, finish_reasons = (
|
||||
await server.tokens_and_logprobs_completion(
|
||||
prompt="The capital of France is",
|
||||
n=4,
|
||||
max_tokens=32,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
stop=["User:", "Human:", "Assistant:", "</answer>"],
|
||||
)
|
||||
)
|
||||
|
||||
print(f"\nResults:")
|
||||
|
|
@ -317,12 +333,15 @@ if __name__ == "__main__":
|
|||
print(f"Finish reasons: {finish_reasons}")
|
||||
print(f"\nNumber of completions: {len(output_tokens)}")
|
||||
print(f"Output length: {[len(tokens) for tokens in output_tokens]}")
|
||||
responses = '\n\n'.join([server.tokenizer.decode(tokens) for tokens in output_tokens])
|
||||
responses = "\n\n".join(
|
||||
[server.tokenizer.decode(tokens) for tokens in output_tokens]
|
||||
)
|
||||
print(f"Responses:\n-{responses}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Run the test
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue