Cavit Erginsoy 2025-02-03 11:35:30 +00:00
parent 1e27021e11
commit 6c564b3dd9
13 changed files with 305 additions and 317 deletions


@@ -11,29 +11,30 @@ from tqdm import tqdm
import reasoning_gym
def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
"""
Check for duplicate word pairs in a word ladder JSONL file.
Returns:
tuple[bool, dict]: (has_duplicates, valid_entries) where:
- has_duplicates: True if any duplicates were found
- valid_entries: Dict mapping line_number -> data for non-duplicate entries
Note: A pair is considered duplicate if either (word1, word2) or (word2, word1)
already exists, since word ladder paths are bidirectional.
"""
pairs_seen = {} # (start, end) -> (line_number, data)
valid_entries = {}
duplicates_found = False
-with open(jsonl_path, 'r', encoding='utf-8') as f:
+with open(jsonl_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f):
data = json.loads(line)
-metadata = data['metadata']
-pair = (metadata['start_word'], metadata['end_word'])
-reverse_pair = (metadata['end_word'], metadata['start_word'])
+metadata = data["metadata"]
+pair = (metadata["start_word"], metadata["end_word"])
+reverse_pair = (metadata["end_word"], metadata["start_word"])
# Check both orientations of the pair
if pair in pairs_seen or reverse_pair in pairs_seen:
duplicates_found = True
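The two-orientation check above can also be keyed on a frozenset, which makes (word1, word2) and (word2, word1) collide by construction. A minimal standalone sketch, not the script's actual helper:

import json

def unordered_duplicates(jsonl_path: str) -> list[int]:
    # A frozenset key is order-insensitive, so both orientations of a pair match.
    seen: set[frozenset] = set()
    duplicate_lines = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f):
            md = json.loads(line)["metadata"]
            key = frozenset((md["start_word"], md["end_word"]))
            if key in seen:
                duplicate_lines.append(line_num)
            seen.add(key)
    return duplicate_lines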
@@ -43,9 +44,10 @@ def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
# Store both the line number and data for valid entries
pairs_seen[pair] = (line_num, data)
valid_entries[line_num] = data
return duplicates_found, valid_entries
def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
"""
Creates a word ladder dataset and writes each sample as a JSON line.
@@ -65,67 +67,67 @@ def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
jsonl_path = output_dir / f"word_ladders_{unique_id}.jsonl"
else:
jsonl_path = Path(jsonl_path)
-target_size = config['dataset_config']['size']
+target_size = config["dataset_config"]["size"]
current_size = 0
max_attempts = 3 # Limit total regeneration attempts
attempt = 0
# Initial generation
-dataset = reasoning_gym.create_dataset(config['dataset_name'], **config['dataset_config'])
-with open(jsonl_path, 'w', encoding='utf-8') as f:
+dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
+with open(jsonl_path, "w", encoding="utf-8") as f:
for item in tqdm(dataset, desc="Generating initial ladder examples"):
row = {
-'question': item['question'],
-'answer': item['answer'],
-'reasoning': None,
-'metadata': item.get('metadata', {})
+"question": item["question"],
+"answer": item["answer"],
+"reasoning": None,
+"metadata": item.get("metadata", {}),
}
-f.write(json.dumps(row) + '\n')
+f.write(json.dumps(row) + "\n")
while attempt < max_attempts:
# Check entire file for duplicates
has_duplicates, valid_entries = check_duplicates(jsonl_path)
current_size = len(valid_entries)
if not has_duplicates and current_size == target_size:
print(f"\nSuccessfully created dataset with {current_size} unique examples.")
return
# If we have duplicates or not enough entries, regenerate the missing amount
needed = target_size - current_size
if needed > 0:
print(f"\nAttempt {attempt + 1}: Regenerating {needed} examples to replace duplicates/missing entries...")
# Generate additional examples
-config['dataset_config']['size'] = needed
-additional_dataset = reasoning_gym.create_dataset(config['dataset_name'], **config['dataset_config'])
+config["dataset_config"]["size"] = needed
+additional_dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
# Write all entries to a temporary file
-temp_path = jsonl_path.with_suffix('.tmp')
-with open(temp_path, 'w', encoding='utf-8') as f:
+temp_path = jsonl_path.with_suffix(".tmp")
+with open(temp_path, "w", encoding="utf-8") as f:
# Write existing valid entries
for data in valid_entries.values():
f.write(json.dumps(data) + '\n')
f.write(json.dumps(data) + "\n")
# Write new entries
for item in additional_dataset:
row = {
-'question': item['question'],
-'answer': item['answer'],
-'reasoning': None,
-'metadata': item.get('metadata', {})
+"question": item["question"],
+"answer": item["answer"],
+"reasoning": None,
+"metadata": item.get("metadata", {}),
}
-f.write(json.dumps(row) + '\n')
+f.write(json.dumps(row) + "\n")
# Replace original file with temporary file
temp_path.replace(jsonl_path)
# Note: We'll check for duplicates again at the start of the next loop
attempt += 1
if current_size < target_size:
print(f"\nWarning: Could only generate {current_size} unique examples after {max_attempts} attempts.")
else:
print(f"\nSuccessfully created dataset with {current_size} unique examples.")
print(f"\nSuccessfully created dataset with {current_size} unique examples.")


@@ -14,17 +14,16 @@ In our informal testing, Sonnet was deemed best performance value.
You can swap out to another API, but this will need a rewrite to remove anthropic-specific code.
"""
-import os
import json
-import uuid
+import os
import time
+import uuid
from pathlib import Path
-from tqdm import tqdm
import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
+from tqdm import tqdm
# Updated default output directory to use the parent directory.
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "output"
@@ -36,18 +35,19 @@ BATCH_SIZE = 2500
COMMON_UUID = uuid.uuid4().hex[:8]
# Set up the Anthropic client (ensure the API key is set in the environment)
-client = anthropic.Anthropic(api_key=os.environ['ANTHROPIC_API_KEY'])
+client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
def submit_reasoning_batches(
input_path: str = DEFAULT_INPUT_JSONL,
batch_metadata_prefix: str = "batch_metadata",
-system_prompt_path: str = DEFAULT_SYSTEM_PROMPT
+system_prompt_path: str = DEFAULT_SYSTEM_PROMPT,
) -> None:
"""
Reads the input JSONL file of word ladder examples, builds batch requests for any example that
does not have reasoning, splits them into groups of BATCH_SIZE, and submits each batch using
Anthropic's Message Batches API.
Args:
input_path: Path to input JSONL file
batch_metadata_prefix: Prefix for batch metadata files
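Splitting requests into groups of BATCH_SIZE is ordinary chunking. The script accumulates lists inline, but the same effect can be sketched with a generic helper (hypothetical, not part of this diff):

from itertools import islice

def chunked(iterable, size):
    # Yield consecutive lists of at most `size` items.
    it = iter(iterable)
    while batch := list(islice(it, size)):
        yield batch

# e.g. for batch in chunked(requests, BATCH_SIZE): submit the batch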
@@ -59,34 +59,36 @@ def submit_reasoning_batches(
# Read the system prompt from file (used as a preamble for every request)
with open(system_prompt_path, "r", encoding="utf-8") as sys_file:
-system_message = [{
-"type": "text",
-"text": sys_file.read(),
-"cache_control": {"type": "ephemeral"} # Enable anthropic prompt caching
-}]
+system_message = [
+{
+"type": "text",
+"text": sys_file.read(),
+"cache_control": {"type": "ephemeral"}, # Enable anthropic prompt caching
+}
+]
batch_requests = []
custom_ids = [] # List of custom_ids for the current batch
batch_num = 0
# Get the total number of lines in advance for tqdm progress bar.
total_lines = sum(1 for _ in open(input_path))
-with open(input_path, 'r', encoding="utf-8") as infile:
+with open(input_path, "r", encoding="utf-8") as infile:
for idx, line in tqdm(enumerate(infile), desc="Preparing batch requests", total=total_lines):
data = json.loads(line)
# Skip example if 'reasoning' already exists.
-if not data.get('reasoning'):
+if not data.get("reasoning"):
# Build a custom id. Here we use the row position and the start/end words:
metadata = data.get("metadata", {})
start = metadata.get("start_word", "unknown")
end = metadata.get("end_word", "unknown")
custom_id = f"{start}_{end}_{idx}"
custom_ids.append(custom_id)
# Build the prompt text exactly as before.
prompt = f"{data['question']}. The correct solution is {data['answer']}. "
# Build the request payload using Request and MessageCreateParamsNonStreaming.
request_payload = Request(
custom_id=custom_id,
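For reference, one fully assembled batch request looks roughly like the sketch below. The params=MessageCreateParamsNonStreaming(...) wrapper and the model name are assumptions here, since the diff elides those lines:

from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request

system_message = [
    {
        "type": "text",
        "text": "...",  # stands in for the system prompt file contents
        "cache_control": {"type": "ephemeral"},  # opt this block into prompt caching
    }
]
request = Request(
    custom_id="cat_dog_0",
    params=MessageCreateParamsNonStreaming(  # assumed wrapper; elided in the diff
        model="claude-3-5-sonnet-latest",  # illustrative model name only
        max_tokens=8192,
        temperature=0.5,
        system=system_message,
        messages=[{"role": "user", "content": "..."}],
    ),
)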
@@ -95,10 +97,8 @@ def submit_reasoning_batches(
max_tokens=8192,
temperature=0.5,
system=system_message,
-messages=[
-{"role": "user", "content": prompt}
-]
-)
+messages=[{"role": "user", "content": prompt}],
+),
+)
# Instead of wrapping in SimpleNamespace, simply ensure custom_id is set.
if isinstance(request_payload, dict):
@@ -106,7 +106,7 @@ def submit_reasoning_batches(
else:
request_payload.custom_id = custom_id
batch_requests.append(request_payload)
# If we have reached our batch size limit, submit the current batch.
if len(batch_requests) >= BATCH_SIZE:
_submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
@@ -114,7 +114,7 @@ def submit_reasoning_batches(
# Reset for the next batch
batch_requests = []
custom_ids = []
# Submit any remaining requests that didn't complete a full batch.
if batch_requests:
_submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
@@ -141,11 +141,11 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_p
if dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None:
iso_str = iso_str.replace("+00:00", "Z")
return iso_str
def extract_custom_id(req):
# Safely extract the custom_id attribute whether req is an object or a dict.
return req.custom_id if hasattr(req, "custom_id") else req.get("custom_id")
max_attempts = 2
attempt = 0
last_exception = None
@@ -153,9 +153,7 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_p
while attempt < max_attempts:
try:
print(f"Submitting batch {batch_num} with {len(batch_requests)} requests... (attempt {attempt+1})")
-message_batch = client.messages.batches.create(
-requests=batch_requests
-)
+message_batch = client.messages.batches.create(requests=batch_requests)
time.sleep(1)
print(f"Batch {batch_num} submitted with ID: {message_batch.id}")
break # Success: exit the loop.
@@ -166,17 +164,21 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_p
if attempt < max_attempts:
print("Retrying...")
time.sleep(1)
if message_batch is None:
error_filename = output_dir / f"{COMMON_UUID}_failed_batches.jsonl"
error_msg = f"{str(last_exception)} after {max_attempts} attempts" if last_exception else f"Failed after {max_attempts} attempts"
error_msg = (
f"{str(last_exception)} after {max_attempts} attempts"
if last_exception
else f"Failed after {max_attempts} attempts"
)
failed_info = {
"batch_number": batch_num,
"error": error_msg,
"batch_requests": [extract_custom_id(req) for req in batch_requests],
"input_file": input_path,
}
-with open(error_filename, 'a', encoding='utf-8') as error_file:
+with open(error_filename, "a", encoding="utf-8") as error_file:
error_file.write(json.dumps(failed_info) + "\n")
print(f"Batch {batch_num} permanently failed. Logged to {error_filename}.")
return
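The retry-then-log flow reduces to a bounded loop that remembers the last exception for the failure record. As a standalone sketch:

import time

def submit_with_retry(submit, max_attempts: int = 2, delay: float = 1.0):
    # `submit` is any zero-argument callable, e.g. a lambda wrapping
    # client.messages.batches.create(requests=...).
    last_exc = None
    for attempt in range(max_attempts):
        try:
            return submit()
        except Exception as exc:
            last_exc = exc
            time.sleep(delay)
    raise RuntimeError(f"failed after {max_attempts} attempts") from last_exc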
@@ -198,14 +200,15 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_p
"batch_id": message_batch.id,
"api_response": api_response,
"custom_ids": custom_ids,
"input_file": os.path.basename(input_path),
"input_file": os.path.basename(input_path),
}
metadata_filename = output_dir / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
-with open(metadata_filename, 'a', encoding='utf-8') as meta_file:
+with open(metadata_filename, "a", encoding="utf-8") as meta_file:
meta_file.write(json.dumps(batch_metadata) + "\n")
print(f"Batch metadata for batch {batch_num} appended to {metadata_filename}.")
if __name__ == "__main__":
# When running this module directly, submit the reasoning batches.
submit_reasoning_batches()


@@ -9,7 +9,7 @@ It then calculates and prints statistics for each usage token field:
- cache_creation_input_tokens
- cache_read_input_tokens
- output_tokens
-+pricing calculations
++pricing calculations
+calculates the savings from caching (vs if we hadn't done any caching)
+forecasts costs for 10,000, 20,000 and 50,000 jobs based on tokens per query
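Each results line nests the usage block inside the API response. A minimal sketch of the extraction, assuming Anthropic's batch-results shape of result.message.usage:

import json
from statistics import mean

FIELDS = ["input_tokens", "cache_creation_input_tokens", "cache_read_input_tokens", "output_tokens"]
usage = {k: [] for k in FIELDS}
with open("msgbatch_results.jsonl", "r", encoding="utf-8") as f:  # illustrative path
    for line in f:
        message = json.loads(line).get("result", {}).get("message", {})
        for k in FIELDS:
            usage[k].append(message.get("usage", {}).get(k, 0))
print({k: mean(v) for k, v in usage.items() if v})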
@@ -17,17 +17,14 @@ Usage:
python usage_stats.py path/to/msgbatch_01X9LgZNVkLFhzrrBd9LNgWb_results.jsonl
"""
-import json
import argparse
+import json
from statistics import mean
def main():
-parser = argparse.ArgumentParser(
-description="Compute usage token statistics from a JSONL file."
-)
-parser.add_argument(
-"file", help="Path to the JSONL file containing usage token data."
-)
+parser = argparse.ArgumentParser(description="Compute usage token statistics from a JSONL file.")
+parser.add_argument("file", help="Path to the JSONL file containing usage token data.")
args = parser.parse_args()
# Usage token fields that we want to track
@@ -43,7 +40,7 @@ def main():
pricing = {
"input_tokens": base_input_rate,
"cache_creation_input_tokens": base_input_rate * 1.25, # More expensive for initial computation
"cache_read_input_tokens": base_input_rate * 0.1, # Cheaper for cache-read tokens
"cache_read_input_tokens": base_input_rate * 0.1, # Cheaper for cache-read tokens
"output_tokens": 7.50,
}
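Since every rate is dollars per million tokens, the cost of one request is a weighted sum over the four fields. A sketch with an illustrative base rate (the actual value of base_input_rate is not visible in this diff):

base_input_rate = 3.00  # assumption: not shown in the diff
pricing = {
    "input_tokens": base_input_rate,
    "cache_creation_input_tokens": base_input_rate * 1.25,  # 25% premium to write the cache
    "cache_read_input_tokens": base_input_rate * 0.1,  # 90% discount on cache hits
    "output_tokens": 7.50,
}
usage = {"input_tokens": 200, "cache_creation_input_tokens": 1800, "cache_read_input_tokens": 0, "output_tokens": 900}
cost = sum(usage[k] / 1_000_000 * pricing[k] for k in pricing)  # ~$0.014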
@@ -82,7 +79,7 @@ def main():
print(f"\nProcessed {total_lines} lines with {error_count} error(s).\n")
print("Usage Tokens Statistics:")
print("-" * 40)
grand_total_cost = 0.0
# Calculate and print stats for each token type
for key in usage_fields:
@@ -115,7 +112,7 @@ def main():
# Without caching, all tokens would have been charged at the standard input rate.
#
# Baseline cost (if no caching were used):
-# = (input_tokens + cache_creation_input_tokens + cache_read_input_tokens)
+# = (input_tokens + cache_creation_input_tokens + cache_read_input_tokens)
# / 1_000_000 * base_input_rate
#
# Actual cost (with caching):
@@ -129,9 +126,11 @@ def main():
sum_cache_read = sum(usage_data["cache_read_input_tokens"])
baseline_input_cost = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]
-actual_input_cost = (sum_input) / 1_000_000 * pricing["input_tokens"] \
-+ (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
-+ (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+actual_input_cost = (
+(sum_input) / 1_000_000 * pricing["input_tokens"]
++ (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
++ (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+)
caching_savings = baseline_input_cost - actual_input_cost
print(f"Caching Savings (input-related tokens): ${caching_savings:.2f}")
@@ -172,12 +171,16 @@ def main():
forecast_output = avg_output_tokens * jobs
# Forecast actual cost (with caching applied for input tokens):
-actual_input_cost_forecast = (forecast_input) / 1_000_000 * pricing["input_tokens"] \
-+ (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
-+ (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+actual_input_cost_forecast = (
+(forecast_input) / 1_000_000 * pricing["input_tokens"]
++ (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
++ (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+)
# Without caching, all input-related tokens would be at base_input_rate:
-baseline_input_cost_forecast = (forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
+baseline_input_cost_forecast = (
+(forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
+)
caching_savings_forecast = baseline_input_cost_forecast - actual_input_cost_forecast
@@ -198,5 +201,6 @@ def main():
else:
print("No valid jobs to forecast future costs.")
if __name__ == "__main__":
-main()
+main()