Completed: full example suite

Cavit Erginsoy 2025-02-03 07:21:12 +00:00
parent c0a16d7f2b
commit de7d37f14f
13 changed files with 1309 additions and 220 deletions

create_word_ladders.py
@@ -0,0 +1,131 @@
"""
create_word_ladders.py Generates the word ladder dataset as a JSONL file.
Each line is a JSON object with keys: question, answer, metadata, and reasoning (set to None).
"""
import json
import uuid
from pathlib import Path

from tqdm import tqdm

import reasoning_gym


def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
    """
    Check for duplicate word pairs in a word ladder JSONL file.

    Returns:
        tuple[bool, dict]: (has_duplicates, valid_entries) where:
            - has_duplicates: True if any duplicates were found
            - valid_entries: Dict mapping line_number -> data for non-duplicate entries

    Note: A pair is considered a duplicate if either (word1, word2) or (word2, word1)
    already exists, since word ladder paths are bidirectional.
    """
    pairs_seen = {}  # (start, end) -> (line_number, data)
    valid_entries = {}
    duplicates_found = False
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f):
            data = json.loads(line)
            metadata = data['metadata']
            pair = (metadata['start_word'], metadata['end_word'])
            reverse_pair = (metadata['end_word'], metadata['start_word'])
            # Check both orientations of the pair
            if pair in pairs_seen or reverse_pair in pairs_seen:
                duplicates_found = True
                # Skip this entry - it's a duplicate
                continue
            else:
                # Store both the line number and data for valid entries
                pairs_seen[pair] = (line_num, data)
                valid_entries[line_num] = data
    return duplicates_found, valid_entries
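
# Usage sketch (hypothetical path):
#   has_dupes, entries = check_duplicates("output/word_ladders_ab12cd34.jsonl")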


def create_word_ladder_dataset(jsonl_path: str | None = None, config: dict | None = None) -> None:
    """
    Creates a word ladder dataset and writes each sample as a JSON line.
    Ensures no duplicate word pairs by regenerating as needed.
    """
    if config is None:
        raise ValueError("Configuration (config) must be provided.")
    # Create the output directory if it doesn't exist.
    # The path points to the parent folder's output directory.
    output_dir = Path(__file__).resolve().parent.parent / "output"
    output_dir.mkdir(exist_ok=True)
    # Derive the file path from a short uuid when none is provided
    if jsonl_path is None:
        unique_id = uuid.uuid4().hex[:8]
        jsonl_path = output_dir / f"word_ladders_{unique_id}.jsonl"
    else:
        jsonl_path = Path(jsonl_path)
    target_size = config['dataset_config']['size']
    current_size = 0
    max_attempts = 3  # Limit total regeneration attempts
    attempt = 0
    # Initial generation
    dataset = reasoning_gym.create_dataset(config['dataset_name'], **config['dataset_config'])
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for item in tqdm(dataset, desc="Generating initial ladder examples"):
            row = {
                'question': item['question'],
                'answer': item['answer'],
                'reasoning': None,
                'metadata': item.get('metadata', {})
            }
            f.write(json.dumps(row) + '\n')
    while attempt < max_attempts:
        # Check the entire file for duplicates
        has_duplicates, valid_entries = check_duplicates(jsonl_path)
        current_size = len(valid_entries)
        if not has_duplicates and current_size == target_size:
            print(f"\nSuccessfully created dataset with {current_size} unique examples.")
            return
        # If we have duplicates or too few entries, regenerate the missing amount
        needed = target_size - current_size
        if needed > 0:
            print(f"\nAttempt {attempt + 1}: Regenerating {needed} examples to replace duplicates/missing entries...")
            # Generate additional examples
            config['dataset_config']['size'] = needed
            additional_dataset = reasoning_gym.create_dataset(config['dataset_name'], **config['dataset_config'])
            # Write all entries to a temporary file
            temp_path = jsonl_path.with_suffix('.tmp')
            with open(temp_path, 'w', encoding='utf-8') as f:
                # Write the existing valid entries
                for data in valid_entries.values():
                    f.write(json.dumps(data) + '\n')
                # Write the new entries
                for item in additional_dataset:
                    row = {
                        'question': item['question'],
                        'answer': item['answer'],
                        'reasoning': None,
                        'metadata': item.get('metadata', {})
                    }
                    f.write(json.dumps(row) + '\n')
            # Replace the original file with the temporary file
            temp_path.replace(jsonl_path)
            # Duplicates are re-checked at the start of the next loop iteration
        attempt += 1
    # Re-check once more so the final report reflects the last regeneration pass.
    _, valid_entries = check_duplicates(jsonl_path)
    current_size = len(valid_entries)
    if current_size < target_size:
        print(f"\nWarning: Could only generate {current_size} unique examples after {max_attempts} attempts.")
    else:
        print(f"\nSuccessfully created dataset with {current_size} unique examples.")

generate_reasoning.py
@@ -0,0 +1,211 @@
"""
generate_reasoning.py Reads the JSONL file containing ladder examples,
creates batch requests of chain-of-thought prompts split into batches of 2,500,
calls Anthropic's Message Batches API for each batch, and writes separate batch metadata
files for later retrieval of the responses.
*** WARNING ***: Running large batches of requests via the Anthropic API (especially in generate_reasoning.py)
can incur significant costs in Anthropic credits. Please review and understand your API quota and budgeting
before running the API call. If you are testing or working with a demo dataset, adjust the batch size or dataset
size appropriately to avoid unexpected charges.
Using Anthropic's Message Batches API with caching enabled for system prompt.
In our informal testing, Sonnet was deemed best performance value.
You can swap out to another API, but this will need a rewrite to remove anthropic-specific code.
"""

import os
import json
import uuid
import time
from pathlib import Path

from tqdm import tqdm

import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request

# Default output directory in the parent folder.
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "output"
# Default input/prompt locations and batch settings.
DEFAULT_INPUT_JSONL = "output/word_ladder_examples.jsonl"
DEFAULT_SYSTEM_PROMPT = Path(__file__).resolve().parent.parent / "system_prompt.txt"
BATCH_SIZE = 2500
COMMON_UUID = uuid.uuid4().hex[:8]

# Set up the Anthropic client (ensure the API key is set in the environment)
client = anthropic.Anthropic(api_key=os.environ['ANTHROPIC_API_KEY'])


def submit_reasoning_batches(
    input_path: str = DEFAULT_INPUT_JSONL,
    batch_metadata_prefix: str = "batch_metadata",
    system_prompt_path: str = DEFAULT_SYSTEM_PROMPT
) -> None:
    """
    Reads the input JSONL file of word ladder examples, builds batch requests for any
    example that does not yet have reasoning, splits them into groups of BATCH_SIZE,
    and submits each batch using Anthropic's Message Batches API.

    Args:
        input_path: Path to the input JSONL file
        batch_metadata_prefix: Prefix for batch metadata files
        system_prompt_path: Path to the system prompt file
    """
    # Create the output directory if it doesn't exist
    output_dir = DEFAULT_OUTPUT_DIR
    output_dir.mkdir(exist_ok=True)
    # Read the system prompt from file (used as a preamble for every request)
    with open(system_prompt_path, "r", encoding="utf-8") as sys_file:
        system_message = [{
            "type": "text",
            "text": sys_file.read(),
            "cache_control": {"type": "ephemeral"}  # Enable Anthropic prompt caching
        }]
    batch_requests = []
    custom_ids = []  # List of custom_ids for the current batch
    batch_num = 0
    # Count the lines in advance for the tqdm progress bar.
    with open(input_path, 'r', encoding='utf-8') as count_file:
        total_lines = sum(1 for _ in count_file)
    with open(input_path, 'r', encoding="utf-8") as infile:
        for idx, line in tqdm(enumerate(infile), desc="Preparing batch requests", total=total_lines):
            data = json.loads(line)
            # Skip examples whose 'reasoning' field is already populated.
            if not data.get('reasoning'):
                # Build a custom id from the row position and the start/end words:
                metadata = data.get("metadata", {})
                start = metadata.get("start_word", "unknown")
                end = metadata.get("end_word", "unknown")
                custom_id = f"{start}_{end}_{idx}"
                custom_ids.append(custom_id)
                # Build the prompt text.
                prompt = f"{data['question']}. The correct solution is {data['answer']}. "
                # Build the request payload using Request and MessageCreateParamsNonStreaming.
                request_payload = Request(
                    custom_id=custom_id,
                    params=MessageCreateParamsNonStreaming(
                        model="claude-3-5-sonnet-20241022",  # Or choose the appropriate model version
                        max_tokens=8192,
                        temperature=0.5,
                        system=system_message,
                        messages=[
                            {"role": "user", "content": prompt}
                        ]
                    )
                )
                # Defensive: ensure custom_id is set whether the payload is a dict or an object.
                if isinstance(request_payload, dict):
                    request_payload["custom_id"] = custom_id
                else:
                    request_payload.custom_id = custom_id
                batch_requests.append(request_payload)
                # Once we reach the batch size limit, submit the current batch.
                if len(batch_requests) >= BATCH_SIZE:
                    _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
                    batch_num += 1
                    # Reset for the next batch
                    batch_requests = []
                    custom_ids = []
    # Submit any remaining requests that didn't fill a full batch.
    if batch_requests:
        _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)


def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path):
    """
    Helper function to submit a single batch request, track the full API response,
    and write out the corresponding metadata including the list of custom_ids.
    """
    # Use the default output directory
    output_dir = DEFAULT_OUTPUT_DIR
    output_dir.mkdir(exist_ok=True)

    def serialize_datetime(dt):
        """
        Convert a datetime object to an ISO-formatted string.
        If dt is None, returns None.
        """
        if dt is None:
            return None
        iso_str = dt.isoformat()  # e.g. "2024-08-20T18:37:24.100435+00:00"
        if dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None:
            iso_str = iso_str.replace("+00:00", "Z")
        return iso_str
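
    # e.g. a UTC-aware datetime(2024, 8, 20, 18, 37) serializes to
    # "2024-08-20T18:37:00Z" (illustrative value, not from a real batch)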

    def extract_custom_id(req):
        # Safely extract the custom_id attribute whether req is an object or a dict.
        return req.custom_id if hasattr(req, "custom_id") else req.get("custom_id")

    max_attempts = 2
    attempt = 0
    last_exception = None
    message_batch = None
    while attempt < max_attempts:
        try:
            print(f"Submitting batch {batch_num} with {len(batch_requests)} requests... (attempt {attempt+1})")
            message_batch = client.messages.batches.create(
                requests=batch_requests
            )
            time.sleep(1)
            print(f"Batch {batch_num} submitted with ID: {message_batch.id}")
            break  # Success: exit the loop.
        except Exception as e:
            last_exception = e
            attempt += 1
            print(f"Error submitting batch {batch_num} on attempt {attempt}: {e}")
            if attempt < max_attempts:
                print("Retrying...")
                time.sleep(1)
    if message_batch is None:
        error_filename = output_dir / f"{COMMON_UUID}_failed_batches.jsonl"
        error_msg = f"{str(last_exception)} after {max_attempts} attempts" if last_exception else f"Failed after {max_attempts} attempts"
        failed_info = {
            "batch_number": batch_num,
            "error": error_msg,
            "batch_requests": [extract_custom_id(req) for req in batch_requests],
            "input_file": input_path,
        }
        with open(error_filename, 'a', encoding='utf-8') as error_file:
            error_file.write(json.dumps(failed_info) + "\n")
        print(f"Batch {batch_num} permanently failed. Logged to {error_filename}.")
        return
    # Build a dictionary of the expected response fields.
    api_response = {
        "id": message_batch.id,
        "type": message_batch.type,
        "processing_status": message_batch.processing_status,
        "request_counts": vars(message_batch.request_counts),
        "ended_at": serialize_datetime(message_batch.ended_at),
        "created_at": serialize_datetime(message_batch.created_at),
        "expires_at": serialize_datetime(message_batch.expires_at),
        "cancel_initiated_at": serialize_datetime(message_batch.cancel_initiated_at),
        "results_url": message_batch.results_url,
    }
    batch_metadata = {
        "batch_id": message_batch.id,
        "api_response": api_response,
        "custom_ids": custom_ids,
        "input_file": os.path.basename(input_path),
    }
    metadata_filename = output_dir / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
    with open(metadata_filename, 'a', encoding='utf-8') as meta_file:
        meta_file.write(json.dumps(batch_metadata) + "\n")
    print(f"Batch metadata for batch {batch_num} appended to {metadata_filename}.")


if __name__ == "__main__":
    # When running this module directly, submit the reasoning batches.
    submit_reasoning_batches()
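
# Retrieval sketch (a separate, later step; not executed by this module).
# Once a batch has finished processing, its results can be streamed back, e.g.:
#   for entry in client.messages.batches.results(batch_id):
#       print(entry.custom_id, entry.result.type)
# (assumes the anthropic SDK's messages.batches.results helper; check your SDK version)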

usage_stats.py
@@ -0,0 +1,202 @@
#!/usr/bin/env python3
"""
This script reads a JSONL file that contains messages with usage statistics.
For each JSON record, it expects to find the token usage information under:
    record["result"]["message"]["usage"]

It then calculates and prints statistics for each usage token field:
    - input_tokens
    - cache_creation_input_tokens
    - cache_read_input_tokens
    - output_tokens

It also:
    - performs pricing calculations
    - calculates the savings from caching (versus not caching at all)
    - forecasts costs for 2,000, 4,000, 10,000, 20,000 and 50,000 jobs based on tokens per query

Usage:
    python usage_stats.py path/to/msgbatch_01X9LgZNVkLFhzrrBd9LNgWb_results.jsonl
"""

import json
import argparse
from statistics import mean


def main():
    parser = argparse.ArgumentParser(
        description="Compute usage token statistics from a JSONL file."
    )
    parser.add_argument(
        "file", help="Path to the JSONL file containing usage token data."
    )
    args = parser.parse_args()
    # Usage token fields that we want to track
    usage_fields = [
        "input_tokens",
        "cache_creation_input_tokens",
        "cache_read_input_tokens",
        "output_tokens",
    ]
    # Pricing for Sonnet as of 2 Feb 2025, in USD per million tokens
    base_input_rate = 1.50
    pricing = {
        "input_tokens": base_input_rate,
        "cache_creation_input_tokens": base_input_rate * 1.25,  # More expensive for the initial cache write
        "cache_read_input_tokens": base_input_rate * 0.1,  # Cheaper for cache-read tokens
        "output_tokens": 7.50,
    }
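
    # e.g. 1,000,000 plain input tokens cost $1.50 at these rates; the same
    # volume costs $1.875 to write to the cache and $0.15 to read back from it.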
    # A dictionary to store lists of values for each usage field
    usage_data = {key: [] for key in usage_fields}
    total_lines = 0
    error_count = 0
    with open(args.file, "r", encoding="utf-8") as f:
        for line in f:
            total_lines += 1
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"[Warning] Failed to parse JSON on line {total_lines}.")
                error_count += 1
                continue
            # Navigate to the usage stats
            try:
                usage = record["result"]["message"]["usage"]
            except KeyError:
                print(f"[Warning] Missing usage field in line {total_lines}.")
                error_count += 1
                continue
            # Extract token values from the usage data
            for key in usage_fields:
                # Default to 0 if the token field is missing or non-numeric
                try:
                    token_value = int(usage.get(key, 0))
                except (ValueError, TypeError):
                    token_value = 0
                usage_data[key].append(token_value)
    print(f"\nProcessed {total_lines} lines with {error_count} error(s).\n")
    print("Usage Token Statistics:")
    print("-" * 40)
    grand_total_cost = 0.0
    # Calculate and print stats for each token type
    for key in usage_fields:
        values = usage_data[key]
        if values:
            total = sum(values)
            count = len(values)
            min_val = min(values)
            max_val = max(values)
            avg = mean(values)
            # Pricing scales by tokens per million
            cost = total / 1_000_000 * pricing[key]
            grand_total_cost += cost
            print(f"{key}:")
            print(f" Total = {total}")
            print(f" Count = {count}")
            print(f" Min = {min_val}")
            print(f" Max = {max_val}")
            print(f" Mean = {avg:.2f}")
            print(f" Cost = ${cost:.2f}\n")
        else:
            print(f"{key}: No data found.\n")
    print("-" * 40)
    print(f"Grand Total Estimated Cost: ${grand_total_cost:.2f}")
    # -----------------------------------------------
    # Calculate caching savings (for input-related tokens).
    # Without caching, all tokens would have been charged at the standard input rate.
    #
    # Baseline cost (if no caching were used):
    #   = (input_tokens + cache_creation_input_tokens + cache_read_input_tokens)
    #     / 1_000_000 * base_input_rate
    #
    # Actual cost (with caching):
    #   = input_tokens * base_input_rate
    #   + cache_creation_input_tokens * (base_input_rate * 1.25)
    #   + cache_read_input_tokens * (base_input_rate * 0.1)
    #
    # The savings from caching is the difference.
    sum_input = sum(usage_data["input_tokens"])
    sum_cache_creation = sum(usage_data["cache_creation_input_tokens"])
    sum_cache_read = sum(usage_data["cache_read_input_tokens"])
    baseline_input_cost = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]
    actual_input_cost = (sum_input) / 1_000_000 * pricing["input_tokens"] \
        + (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
        + (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
    caching_savings = baseline_input_cost - actual_input_cost
    print(f"Caching Savings (input-related tokens): ${caching_savings:.2f}")
    # -----------------------------------------------
    # Forecast future cost estimates based on the average tokens per job.
    #
    # We compute the average tokens per job (i.e. tokens per query) for:
    #   - input_tokens
    #   - cache_creation_input_tokens
    #   - cache_read_input_tokens
    #   - output_tokens
    #
    # Then, for each job count in forecast_jobs:
    #   - Apply the relevant pricing to compute the cost per token type.
    #   - Also compute the baseline cost for input-related tokens and the
    #     savings from caching.
    if usage_data["input_tokens"]:
        job_count = len(usage_data["input_tokens"])
        avg_input_tokens = sum(usage_data["input_tokens"]) / job_count
        avg_cache_creation_tokens = sum(usage_data["cache_creation_input_tokens"]) / job_count
        avg_cache_read_tokens = sum(usage_data["cache_read_input_tokens"]) / job_count
        avg_output_tokens = sum(usage_data["output_tokens"]) / job_count
        print("\nAverage Tokens per Job:")
        print(f" input_tokens = {avg_input_tokens:.2f}")
        print(f" cache_creation_input_tokens = {avg_cache_creation_tokens:.2f}")
        print(f" cache_read_input_tokens = {avg_cache_read_tokens:.2f}")
        print(f" output_tokens = {avg_output_tokens:.2f}")
        forecast_jobs = [2000, 4000, 10000, 20000, 50000]
        print("\nForecasting Future Job Costs:")
        for jobs in forecast_jobs:
            # Forecast token usage by scaling the per-job averages.
            forecast_input = avg_input_tokens * jobs
            forecast_cache_creation = avg_cache_creation_tokens * jobs
            forecast_cache_read = avg_cache_read_tokens * jobs
            forecast_output = avg_output_tokens * jobs
            # Forecast actual cost (with caching applied for input tokens):
            actual_input_cost_forecast = (forecast_input) / 1_000_000 * pricing["input_tokens"] \
                + (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
                + (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
            # Without caching, all input-related tokens would be at base_input_rate:
            baseline_input_cost_forecast = (forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
            caching_savings_forecast = baseline_input_cost_forecast - actual_input_cost_forecast
            forecast_output_cost = forecast_output / 1_000_000 * pricing["output_tokens"]
            total_forecast_cost = actual_input_cost_forecast + forecast_output_cost
            print(f"\nFor {jobs:,} jobs:")
            print(" Forecasted Token Usage:")
            print(f" input_tokens = {forecast_input:,.0f}")
            print(f" cache_creation_input_tokens = {forecast_cache_creation:,.0f}")
            print(f" cache_read_input_tokens = {forecast_cache_read:,.0f}")
            print(f" output_tokens = {forecast_output:,.0f}")
            print(" Estimated Costs:")
            print(f" Input Cost (with caching) = ${actual_input_cost_forecast:,.2f}")
            print(f" Output Cost = ${forecast_output_cost:,.2f}")
            print(f" Grand Total Cost = ${total_forecast_cost:,.2f}")
            print(f" Caching Savings (input) = ${caching_savings_forecast:,.2f}")
    else:
        print("No valid jobs to forecast future costs.")


if __name__ == "__main__":
    main()