first commit

This commit is contained in:
Dakota Nous 2025-04-29 12:10:10 -07:00
commit 621d00dd80
89 changed files with 15315 additions and 0 deletions

View file

@ -0,0 +1,112 @@
import math
import random
# TODO: move this to the server manager
async def generate_with_diverse_first_tokens(
self, messages, prefill="", n=8, max_tokens=4096, temperature=1.0
):
"""
Generate diverse completions by sampling different first tokens.
Parameters:
- messages: List of message dictionaries for chat completion
- prefill: Prefix text to add to assistant's message
- n: Number of diverse completions to generate
- max_tokens: Maximum tokens per completion
- temperature: Sampling temperature
Returns:
- List of completion strings
"""
# Step 1: First get the logprobs for just the first token
first_token_messages = messages + [{"role": "assistant", "content": prefill}]
first_token_completion = await self.server.chat_completion(
messages=first_token_messages,
n=1,
max_tokens=1,
temperature=0.0, # Use 0 temperature to get raw logprobs
logprobs=True,
top_logprobs=20, # Get top 20 logprobs for the first token
)
# Extract logprobs from the completion
try:
# Get the logprobs for the first token
logprobs_dict = (
first_token_completion.choices[0].logprobs.content[0].top_logprobs
)
# Convert to list of (token, logprob) tuples
logprobs_list = [(item.token, item.logprob) for item in logprobs_dict]
# Convert logprobs to probabilities with temperature
logprobs_array = [lp for _, lp in logprobs_list]
probs = [math.exp(lp / temperature) for lp in logprobs_array]
total = sum(probs)
probs = [p / total for p in probs]
# Sample n unique tokens
sampled_indices = random.choices(
range(len(logprobs_list)), weights=probs, k=min(n, len(logprobs_list))
)
# Ensure unique indices
sampled_indices = list(set(sampled_indices))
# If we have fewer than n tokens, sample again to fill
while len(sampled_indices) < n and len(sampled_indices) < len(logprobs_list):
remaining = min(
n - len(sampled_indices), len(logprobs_list) - len(sampled_indices)
)
available_indices = [
i for i in range(len(logprobs_list)) if i not in sampled_indices
]
available_probs = [probs[i] for i in available_indices]
total = sum(available_probs)
if total > 0:
available_probs = [p / total for p in available_probs]
additional_indices = random.choices(
available_indices, weights=available_probs, k=remaining
)
sampled_indices.extend(additional_indices)
else:
# If all remaining probs are 0, just pick randomly
additional_indices = random.sample(available_indices, k=remaining)
sampled_indices.extend(additional_indices)
# Get the selected first tokens
first_tokens = [logprobs_list[i][0] for i in sampled_indices]
except (AttributeError, IndexError, KeyError) as e:
# Fallback if we can't extract logprobs properly
print(f"Error extracting logprobs: {e}")
return await self.fallback_generate(
messages, prefill, n, max_tokens, temperature
)
# Step 2: Generate completions with each selected first token
completions = []
for token in first_tokens:
# Create a prompt with the first token already included
prompt_with_token = messages + [
{"role": "assistant", "content": prefill + token}
]
# Generate the rest of the completion
completion = await self.server.chat_completion(
messages=prompt_with_token,
n=1,
max_tokens=max_tokens - 1, # Subtract 1 for the token we already used
temperature=temperature,
top_p=0.3,
extra_body={
"min_p": 0.5,
"repetition_penalty": 1.05,
},
)
# Extract the completion content and remove the prefill+token
full_content = completion.choices[0].message.content
completions.append(token + full_content)