# Mirror of https://github.com/NousResearch/atropos.git
import math
import random


# TODO: move this to the server manager
async def generate_with_diverse_first_tokens(
    self, messages, prefill="", n=8, max_tokens=4096, temperature=1.0
):
    """
    Generate diverse completions by sampling different first tokens.

    Parameters:
    - messages: List of message dictionaries for chat completion
    - prefill: Prefix text to add to the assistant's message
    - n: Number of diverse completions to generate
    - max_tokens: Maximum tokens per completion
    - temperature: Sampling temperature

    Returns:
    - List of completion strings
    """
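    # Approach: (1) query the model once for the top first-token logprobs,
    # (2) sample n distinct first tokens from a temperature-scaled
    # distribution over them, and (3) generate one continuation per sampled
    # token, so the completions are forced to diverge from the first token
    # onward.
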
    # Step 1: First get the logprobs for just the first token
    first_token_messages = messages + [{"role": "assistant", "content": prefill}]

    first_token_completion = await self.server.chat_completion(
        messages=first_token_messages,
        n=1,
        max_tokens=1,
        temperature=0.0,  # Use 0 temperature to get raw logprobs
        logprobs=True,
        top_logprobs=20,  # Get top 20 logprobs for the first token
    )
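
    # The call above assumes an OpenAI-style chat completions response: each
    # choice carries `logprobs.content`, whose first entry exposes a
    # `top_logprobs` list of objects with `.token` and `.logprob` attributes
    # (as served by e.g. vLLM's OpenAI-compatible endpoint). The
    # `self.server.chat_completion` wrapper itself is defined outside this
    # snippet.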

    # Extract logprobs from the completion
    try:
        # Get the logprobs for the first token
        logprobs_dict = (
            first_token_completion.choices[0].logprobs.content[0].top_logprobs
        )

        # Convert to a list of (token, logprob) tuples
        logprobs_list = [(item.token, item.logprob) for item in logprobs_dict]

        # Convert logprobs to probabilities with temperature
        logprobs_array = [lp for _, lp in logprobs_list]
        probs = [math.exp(lp / temperature) for lp in logprobs_array]
        total = sum(probs)
        probs = [p / total for p in probs]
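
        # Up to the normalization above, exp(lp / temperature) applied to
        # log-probabilities is a temperature-scaled softmax restricted to the
        # returned top-20 tokens: higher temperatures flatten the first-token
        # distribution and so increase diversity.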

        # Sample up to n indices; random.choices draws with replacement,
        # so the result can contain duplicates
        sampled_indices = random.choices(
            range(len(logprobs_list)), weights=probs, k=min(n, len(logprobs_list))
        )

        # Ensure unique indices
        sampled_indices = list(set(sampled_indices))

        # If we have fewer than n tokens, sample again to fill
        while len(sampled_indices) < n and len(sampled_indices) < len(logprobs_list):
            remaining = min(
                n - len(sampled_indices), len(logprobs_list) - len(sampled_indices)
            )
            available_indices = [
                i for i in range(len(logprobs_list)) if i not in sampled_indices
            ]
            available_probs = [probs[i] for i in available_indices]
            total = sum(available_probs)
            if total > 0:
                available_probs = [p / total for p in available_probs]
                additional_indices = random.choices(
                    available_indices, weights=available_probs, k=remaining
                )
                sampled_indices.extend(additional_indices)
            else:
                # If all remaining probs are 0, just pick randomly
                additional_indices = random.sample(available_indices, k=remaining)
                sampled_indices.extend(additional_indices)
            # The weighted draw above also samples with replacement, so dedupe
            # again before re-checking the loop condition
            sampled_indices = list(set(sampled_indices))

        # Get the selected first tokens
        first_tokens = [logprobs_list[i][0] for i in sampled_indices]
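
        # Design note: forcing distinct first tokens is what spreads the n
        # rollouts apart; plain n-way sampling at a single temperature can
        # collapse onto a few high-probability openings.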

    except (AttributeError, IndexError, KeyError) as e:
        # Fallback if we can't extract logprobs properly
        print(f"Error extracting logprobs: {e}")
        return await self.fallback_generate(
            messages, prefill, n, max_tokens, temperature
        )

    # Step 2: Generate completions with each selected first token
    completions = []
    for token in first_tokens:
        # Create a prompt with the first token already included
        prompt_with_token = messages + [
            {"role": "assistant", "content": prefill + token}
        ]

        # Generate the rest of the completion
        completion = await self.server.chat_completion(
            messages=prompt_with_token,
            n=1,
            max_tokens=max_tokens - 1,  # Subtract 1 for the token we already used
            temperature=temperature,
            top_p=0.3,
            extra_body={
                "min_p": 0.5,
                "repetition_penalty": 1.05,
            },
        )

        # Extract the completion content; the server is expected to return
        # only the newly generated text, so prepend the sampled first token
        # to rebuild the full continuation (the prefill is intentionally
        # left out)
        full_content = completion.choices[0].message.content
        completions.append(token + full_content)

    return completions
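

# `fallback_generate` is called above but not defined in this snippet. A
# minimal sketch of what it could look like, assuming the same OpenAI-style
# server wrapper: one batched call at the requested temperature, without the
# diverse-first-token trick. The signature mirrors the call site above; the
# body is an assumption, not the repository's actual implementation.
async def fallback_generate(
    self, messages, prefill="", n=8, max_tokens=4096, temperature=1.0
):
    completion = await self.server.chat_completion(
        messages=messages + [{"role": "assistant", "content": prefill}],
        n=n,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    # Return the n sampled continuations as plain strings
    return [choice.message.content for choice in completion.choices]


# Hypothetical call site (illustrative names only; `env` stands for whatever
# class these methods are eventually bound to):
#
#     completions = await env.generate_with_diverse_first_tokens(
#         messages=[{"role": "user", "content": "Name a prime number."}],
#         n=4,
#     )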