mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
first commit
This commit is contained in:
commit
621d00dd80
89 changed files with 15315 additions and 0 deletions
112
atroposlib/utils/force_diverse_samples.py
Normal file
112
atroposlib/utils/force_diverse_samples.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import math
|
||||
import random
|
||||
|
||||
|
||||
# TODO: move this to the server manager
async def generate_with_diverse_first_tokens(
    self, messages, prefill="", n=8, max_tokens=4096, temperature=1.0
):
    """
    Generate diverse completions by sampling different first tokens.

    The model's top first-token logprobs are fetched once, up to ``n``
    *distinct* first tokens are sampled from them (temperature-scaled
    softmax), and one completion is then generated per sampled token.

    Parameters:
    - messages: List of message dictionaries for chat completion
    - prefill: Prefix text to add to assistant's message
    - n: Number of diverse completions to generate
    - max_tokens: Maximum tokens per completion
    - temperature: Sampling temperature (must be > 0; it divides logprobs)

    Returns:
    - List of completion strings (each beginning with its sampled first token)
    """
    # Step 1: First get the logprobs for just the first token
    first_token_messages = messages + [{"role": "assistant", "content": prefill}]

    first_token_completion = await self.server.chat_completion(
        messages=first_token_messages,
        n=1,
        max_tokens=1,
        temperature=0.0,  # Use 0 temperature to get raw logprobs
        logprobs=True,
        top_logprobs=20,  # Get top 20 logprobs for the first token
    )

    # Extract logprobs from the completion
    try:
        # Get the logprobs for the first token
        logprobs_dict = (
            first_token_completion.choices[0].logprobs.content[0].top_logprobs
        )

        # Convert to list of (token, logprob) tuples
        logprobs_list = [(item.token, item.logprob) for item in logprobs_dict]

        # Temperature-scaled softmax over the returned logprobs
        probs = [math.exp(lp / temperature) for _, lp in logprobs_list]
        total = sum(probs)
        probs = [p / total for p in probs]

        # Sample up to n unique indices, weighted by probability.
        # BUG FIX: the original used random.choices (sampling WITH
        # replacement) and only deduplicated once, so duplicate indices
        # could survive the fill loop and duplicate first tokens could be
        # emitted. Accumulating into a set guarantees uniqueness.
        target = min(n, len(logprobs_list))
        chosen = set()
        while len(chosen) < target:
            available_indices = [
                i for i in range(len(logprobs_list)) if i not in chosen
            ]
            available_probs = [probs[i] for i in available_indices]
            avail_total = sum(available_probs)
            if avail_total > 0:
                weights = [p / avail_total for p in available_probs]
                picks = random.choices(
                    available_indices, weights=weights, k=target - len(chosen)
                )
            else:
                # If all remaining probs are 0, just pick uniformly
                picks = random.sample(available_indices, k=target - len(chosen))
            chosen.update(picks)  # set membership discards replacement dupes

        # Get the selected first tokens
        first_tokens = [logprobs_list[i][0] for i in chosen]

    except (AttributeError, IndexError, KeyError) as e:
        # Fallback if we can't extract logprobs properly
        print(f"Error extracting logprobs: {e}")
        return await self.fallback_generate(
            messages, prefill, n, max_tokens, temperature
        )

    # Step 2: Generate completions with each selected first token
    completions = []
    for token in first_tokens:
        # Create a prompt with the first token already included
        prompt_with_token = messages + [
            {"role": "assistant", "content": prefill + token}
        ]

        # Generate the rest of the completion
        completion = await self.server.chat_completion(
            messages=prompt_with_token,
            n=1,
            max_tokens=max_tokens - 1,  # Subtract 1 for the token we already used
            temperature=temperature,
            top_p=0.3,
            extra_body={
                "min_p": 0.5,
                "repetition_penalty": 1.05,
            },
        )

        # Extract the completion content and prepend the sampled first token
        full_content = completion.choices[0].message.content
        completions.append(token + full_content)

    # BUG FIX: the original function fell off the end here and returned
    # None, contradicting the documented "List of completion strings".
    return completions
|
||||
Loading…
Add table
Add a link
Reference in a new issue