mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-27 17:23:19 +00:00
Completed: full example suite
This commit is contained in:
parent
c0a16d7f2b
commit
de7d37f14f
13 changed files with 1309 additions and 220 deletions
0
examples/word_ladder/utils/__init__.py
Normal file
0
examples/word_ladder/utils/__init__.py
Normal file
131
examples/word_ladder/utils/create_word_ladders.py
Normal file
131
examples/word_ladder/utils/create_word_ladders.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""
|
||||
create_word_ladders.py – Generates the word ladder dataset as a JSONL file.
|
||||
Each line is a JSON object with keys: question, answer, metadata, and reasoning (set to None).
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import reasoning_gym
|
||||
|
||||
def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
    """
    Check for duplicate word pairs in a word ladder JSONL file.

    Args:
        jsonl_path: Path to a JSONL file where each line is a JSON object
            whose 'metadata' contains 'start_word' and 'end_word'.

    Returns:
        tuple[bool, dict]: (has_duplicates, valid_entries) where:
            - has_duplicates: True if any duplicates were found
            - valid_entries: Dict mapping line_number -> data for non-duplicate entries

    Note: A pair is considered duplicate if either (word1, word2) or (word2, word1)
    already exists, since word ladder paths are bidirectional.
    """
    # Only membership is ever tested, so a set suffices (the original stored
    # (line_num, data) values that were never read back).
    pairs_seen = set()
    valid_entries = {}
    duplicates_found = False

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f):
            data = json.loads(line)
            metadata = data['metadata']
            pair = (metadata['start_word'], metadata['end_word'])

            # Check both orientations of the pair: ladders are bidirectional.
            if pair in pairs_seen or pair[::-1] in pairs_seen:
                duplicates_found = True
                # Skip this entry - it's a duplicate
                continue

            pairs_seen.add(pair)
            valid_entries[line_num] = data

    return duplicates_found, valid_entries
|
||||
|
||||
def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
    """
    Creates a word ladder dataset and writes each sample as a JSON line.
    Ensures no duplicate word pairs by regenerating as needed.

    Args:
        jsonl_path: Output file path. When None, a uuid-named file is created
            under the parent folder's "output" directory.
        config: Must contain 'dataset_name' and 'dataset_config' (which must
            include a 'size' entry).

    Raises:
        ValueError: If config is not provided.
    """
    if config is None:
        raise ValueError("Configuration (config) must be provided.")

    # Work on a copy: the original mutated config['dataset_config']['size']
    # in place during regeneration, corrupting the caller's config dict.
    dataset_config = dict(config['dataset_config'])

    # Create output directory if it doesn't exist.
    # Updated path to point to the parent folder's output.
    output_dir = Path(__file__).resolve().parent.parent / "output"
    output_dir.mkdir(exist_ok=True)

    # Determine the file path based on uuid when not provided
    if jsonl_path is None:
        unique_id = uuid.uuid4().hex[:8]
        jsonl_path = output_dir / f"word_ladders_{unique_id}.jsonl"
    else:
        jsonl_path = Path(jsonl_path)

    target_size = dataset_config['size']
    max_attempts = 3  # Limit total regeneration attempts

    def _serialize(item: dict) -> str:
        # One JSONL line per sample; 'reasoning' is filled in later by the
        # batch-generation pipeline.
        return json.dumps({
            'question': item['question'],
            'answer': item['answer'],
            'reasoning': None,
            'metadata': item.get('metadata', {}),
        }) + '\n'

    # Initial generation
    dataset = reasoning_gym.create_dataset(config['dataset_name'], **dataset_config)
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for item in tqdm(dataset, desc="Generating initial ladder examples"):
            f.write(_serialize(item))

    for attempt in range(max_attempts):
        # Check entire file for duplicates
        has_duplicates, valid_entries = check_duplicates(jsonl_path)
        current_size = len(valid_entries)

        if not has_duplicates and current_size == target_size:
            print(f"\nSuccessfully created dataset with {current_size} unique examples.")
            return

        # If we have duplicates or not enough entries, regenerate the missing amount
        needed = target_size - current_size
        if needed > 0:
            print(f"\nAttempt {attempt + 1}: Regenerating {needed} examples to replace duplicates/missing entries...")

            # Generate additional examples (only the local copy is resized).
            dataset_config['size'] = needed
            additional_dataset = reasoning_gym.create_dataset(config['dataset_name'], **dataset_config)

            # Write all entries to a temporary file
            temp_path = jsonl_path.with_suffix('.tmp')
            with open(temp_path, 'w', encoding='utf-8') as f:
                # Write existing valid entries
                for data in valid_entries.values():
                    f.write(json.dumps(data) + '\n')
                # Write new entries
                for item in additional_dataset:
                    f.write(_serialize(item))

            # Replace original file with temporary file
            temp_path.replace(jsonl_path)

        # Duplicates are re-checked at the start of the next loop iteration.

    # Re-check after the final regeneration: the loop above checks *before*
    # regenerating, so the last computed current_size would be stale here.
    _, valid_entries = check_duplicates(jsonl_path)
    current_size = len(valid_entries)

    if current_size < target_size:
        print(f"\nWarning: Could only generate {current_size} unique examples after {max_attempts} attempts.")
    else:
        print(f"\nSuccessfully created dataset with {current_size} unique examples.")
|
||||
211
examples/word_ladder/utils/generate_reasoning.py
Normal file
211
examples/word_ladder/utils/generate_reasoning.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
"""
|
||||
generate_reasoning.py – Reads the JSONL file containing ladder examples,
|
||||
creates batch requests of chain-of-thought prompts split into batches of 2,500,
|
||||
calls Anthropic's Message Batches API for each batch, and writes separate batch metadata
|
||||
files for later retrieval of the responses.
|
||||
|
||||
*** WARNING ***: Running large batches of requests via the Anthropic API (especially in generate_reasoning.py)
|
||||
can incur significant costs in Anthropic credits. Please review and understand your API quota and budgeting
|
||||
before running the API call. If you are testing or working with a demo dataset, adjust the batch size or dataset
|
||||
size appropriately to avoid unexpected charges.
|
||||
|
||||
Using Anthropic's Message Batches API with caching enabled for system prompt.
|
||||
In our informal testing, Sonnet was deemed best performance value.
|
||||
You can swap out to another API, but this will need a rewrite to remove anthropic-specific code.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import anthropic
|
||||
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
|
||||
from anthropic.types.messages.batch_create_params import Request
|
||||
|
||||
# Updated default output directory to use the parent directory.
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "output"

# Add default constants at the top with other constants
DEFAULT_INPUT_JSONL = "output/word_ladder_examples.jsonl"
DEFAULT_SYSTEM_PROMPT = Path(__file__).resolve().parent.parent / "system_prompt.txt"
# Maximum number of requests per Message Batches submission.
BATCH_SIZE = 2500
# One shared run id: prefixes the metadata and failed-batch filenames so all
# output files from a single run of this script group together.
COMMON_UUID = uuid.uuid4().hex[:8]

# Set up the Anthropic client (ensure the API key is set in the environment)
# NOTE: this runs at import time and raises KeyError if ANTHROPIC_API_KEY
# is not set in the environment.
client = anthropic.Anthropic(api_key=os.environ['ANTHROPIC_API_KEY'])
|
||||
|
||||
def submit_reasoning_batches(
    input_path: str = DEFAULT_INPUT_JSONL,
    batch_metadata_prefix: str = "batch_metadata",
    system_prompt_path: str = DEFAULT_SYSTEM_PROMPT
) -> None:
    """
    Reads the input JSONL file of word ladder examples, builds batch requests for any example that
    does not have reasoning, splits them into groups of BATCH_SIZE, and submits each batch using
    Anthropic's Message Batches API.

    Args:
        input_path: Path to input JSONL file
        batch_metadata_prefix: Prefix for batch metadata files
        system_prompt_path: Path to system prompt file
    """
    # Create output directory if it doesn't exist
    output_dir = DEFAULT_OUTPUT_DIR
    output_dir.mkdir(exist_ok=True)

    # Read the system prompt from file (used as a preamble for every request)
    with open(system_prompt_path, "r", encoding="utf-8") as sys_file:
        system_message = [{
            "type": "text",
            "text": sys_file.read(),
            "cache_control": {"type": "ephemeral"}  # Enable anthropic prompt caching
        }]
    batch_requests = []
    custom_ids = []  # List of custom_ids for the current batch
    batch_num = 0

    # Get the total number of lines in advance for the tqdm progress bar.
    # Use a context manager so the counting handle is closed deterministically
    # (the original bare `open()` in a genexp leaked the file object).
    with open(input_path, 'r', encoding="utf-8") as count_file:
        total_lines = sum(1 for _ in count_file)

    with open(input_path, 'r', encoding="utf-8") as infile:
        for idx, line in tqdm(enumerate(infile), desc="Preparing batch requests", total=total_lines):
            data = json.loads(line)

            # Only build a request when 'reasoning' is missing or empty;
            # examples that already have reasoning are skipped.
            if not data.get('reasoning'):
                # Build a custom id. Here we use the row position and the start/end words:
                metadata = data.get("metadata", {})
                start = metadata.get("start_word", "unknown")
                end = metadata.get("end_word", "unknown")
                custom_id = f"{start}_{end}_{idx}"
                custom_ids.append(custom_id)

                # Build the prompt text exactly as before.
                prompt = f"{data['question']}. The correct solution is {data['answer']}. "

                # Build the request payload using Request and MessageCreateParamsNonStreaming.
                request_payload = Request(
                    custom_id=custom_id,
                    params=MessageCreateParamsNonStreaming(
                        model="claude-3-5-sonnet-20241022",  # Or choose the appropriate model version
                        max_tokens=8192,
                        temperature=0.5,
                        system=system_message,
                        messages=[
                            {"role": "user", "content": prompt}
                        ]
                    )
                )
                # Instead of wrapping in SimpleNamespace, simply ensure custom_id is set.
                if isinstance(request_payload, dict):
                    request_payload["custom_id"] = custom_id
                else:
                    request_payload.custom_id = custom_id
                batch_requests.append(request_payload)

                # If we have reached our batch size limit, submit the current batch.
                if len(batch_requests) >= BATCH_SIZE:
                    _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
                    batch_num += 1
                    # Reset for the next batch
                    batch_requests = []
                    custom_ids = []

    # Submit any remaining requests that didn't complete a full batch.
    if batch_requests:
        _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
|
||||
|
||||
|
||||
def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path):
    """
    Helper function to submit a single batch request, track the full API response,
    and write out the corresponding metadata including the list of custom_ids.

    Args:
        batch_requests: Request payloads submitted in one Message Batches call.
        custom_ids: custom_id strings matching batch_requests (recorded in metadata).
        batch_num: Zero-based index of this batch, used for logging only.
        batch_metadata_prefix: Filename prefix for the appended metadata JSONL file.
        input_path: Path of the source JSONL file, recorded in the metadata.
    """
    # Use the default output directory
    output_dir = DEFAULT_OUTPUT_DIR
    output_dir.mkdir(exist_ok=True)

    def serialize_datetime(dt):
        """
        Convert a datetime object to ISO formatted string.
        If dt is None, returns None.
        """
        if dt is None:
            return None
        iso_str = dt.isoformat()  # e.g. "2024-08-20T18:37:24.100435+00:00"

        # Prefer the "Z" suffix for timezone-aware UTC timestamps.
        # NOTE(review): only an exact "+00:00" substring is rewritten; other
        # offsets pass through unchanged.
        if dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None:
            iso_str = iso_str.replace("+00:00", "Z")
        return iso_str

    def extract_custom_id(req):
        # Safely extract the custom_id attribute whether req is an object or a dict.
        return req.custom_id if hasattr(req, "custom_id") else req.get("custom_id")

    # Submit with a small bounded retry; any exception counts as one attempt.
    max_attempts = 2
    attempt = 0
    last_exception = None
    message_batch = None
    while attempt < max_attempts:
        try:
            print(f"Submitting batch {batch_num} with {len(batch_requests)} requests... (attempt {attempt+1})")
            message_batch = client.messages.batches.create(
                requests=batch_requests
            )
            # Brief pause between successive API submissions.
            time.sleep(1)
            print(f"Batch {batch_num} submitted with ID: {message_batch.id}")
            break  # Success: exit the loop.
        except Exception as e:
            last_exception = e
            attempt += 1
            print(f"Error submitting batch {batch_num} on attempt {attempt}: {e}")
            if attempt < max_attempts:
                print("Retrying...")
                time.sleep(1)

    # All attempts failed: append the failure (with the affected custom_ids)
    # to a per-run failure log and give up on this batch.
    if message_batch is None:
        error_filename = output_dir / f"{COMMON_UUID}_failed_batches.jsonl"
        error_msg = f"{str(last_exception)} after {max_attempts} attempts" if last_exception else f"Failed after {max_attempts} attempts"
        failed_info = {
            "batch_number": batch_num,
            "error": error_msg,
            "batch_requests": [extract_custom_id(req) for req in batch_requests],
            "input_file": input_path,
        }
        with open(error_filename, 'a', encoding='utf-8') as error_file:
            error_file.write(json.dumps(failed_info) + "\n")
        print(f"Batch {batch_num} permanently failed. Logged to {error_filename}.")
        return

    # Build a dictionary of the expected response fields.
    api_response = {
        "id": message_batch.id,
        "type": message_batch.type,
        "processing_status": message_batch.processing_status,
        "request_counts": vars(message_batch.request_counts),
        "ended_at": serialize_datetime(message_batch.ended_at),
        "created_at": serialize_datetime(message_batch.created_at),
        "expires_at": serialize_datetime(message_batch.expires_at),
        "cancel_initiated_at": serialize_datetime(message_batch.cancel_initiated_at),
        "results_url": message_batch.results_url,
    }

    # Persist everything needed later to retrieve the batch results and
    # match them back to their source rows.
    batch_metadata = {
        "batch_id": message_batch.id,
        "api_response": api_response,
        "custom_ids": custom_ids,
        "input_file": os.path.basename(input_path),
    }
    metadata_filename = output_dir / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
    with open(metadata_filename, 'a', encoding='utf-8') as meta_file:
        meta_file.write(json.dumps(batch_metadata) + "\n")

    print(f"Batch metadata for batch {batch_num} appended to {metadata_filename}.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# When running this module directly, submit the reasoning batches.
|
||||
submit_reasoning_batches()
|
||||
202
examples/word_ladder/utils/usage_stats.py
Normal file
202
examples/word_ladder/utils/usage_stats.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
This script reads a JSONL file that contains messages with usage statistics.
|
||||
For each JSON record, it expects to find the token usage information under:
|
||||
record["result"]["message"]["usage"]
|
||||
|
||||
It then calculates and prints statistics for each usage token field:
|
||||
- input_tokens
|
||||
- cache_creation_input_tokens
|
||||
- cache_read_input_tokens
|
||||
- output_tokens
|
||||
+pricing calculations
|
||||
+calculates the savings from caching (vs if we hadn't done any caching)
|
||||
+forecasts costs for 10,000, 20,000 and 50,000 jobs based on tokens per query
|
||||
|
||||
Usage:
|
||||
python usage_stats.py path/to/msgbatch_01X9LgZNVkLFhzrrBd9LNgWb_results.jsonl
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
from statistics import mean
|
||||
|
||||
def main():
    """
    Parse CLI arguments, read token-usage records from the given JSONL file,
    and print per-field statistics, estimated costs, caching savings, and
    cost forecasts for several future job counts.

    Expects each record's usage under record["result"]["message"]["usage"].
    Lines that fail to parse or lack usage data are counted and skipped.
    """
    parser = argparse.ArgumentParser(
        description="Compute usage token statistics from a JSONL file."
    )
    parser.add_argument(
        "file", help="Path to the JSONL file containing usage token data."
    )
    args = parser.parse_args()

    # Usage token fields that we want to track
    usage_fields = [
        "input_tokens",
        "cache_creation_input_tokens",
        "cache_read_input_tokens",
        "output_tokens",
    ]

    # Pricing for Sonnet, 2 Feb 2025 (USD per million tokens)
    base_input_rate = 1.50
    pricing = {
        "input_tokens": base_input_rate,
        "cache_creation_input_tokens": base_input_rate * 1.25,  # More expensive for initial computation
        "cache_read_input_tokens": base_input_rate * 0.1,  # Cheaper for cache-read tokens
        "output_tokens": 7.50,
    }

    # A dictionary to store lists of values for each usage field
    usage_data = {key: [] for key in usage_fields}
    total_lines = 0
    error_count = 0

    with open(args.file, "r", encoding="utf-8") as f:
        for line in f:
            total_lines += 1
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"[Warning] Failed to parse JSON on line {total_lines}.")
                error_count += 1
                continue

            # Navigate to the usage stats. TypeError is caught as well so a
            # record whose "result" or "message" is null (e.g. an errored
            # batch entry) is skipped instead of crashing the whole run.
            try:
                usage = record["result"]["message"]["usage"]
            except (KeyError, TypeError):
                print(f"[Warning] Missing usage field in line {total_lines}.")
                error_count += 1
                continue

            # Extract token values from the usage data
            for key in usage_fields:
                # Defaulting to 0 if the token field is missing or non-numeric
                try:
                    token_value = int(usage.get(key, 0))
                except (ValueError, TypeError):
                    token_value = 0
                usage_data[key].append(token_value)

    print(f"\nProcessed {total_lines} lines with {error_count} error(s).\n")
    print("Usage Tokens Statistics:")
    print("-" * 40)

    grand_total_cost = 0.0
    # Calculate and print stats for each token type
    for key in usage_fields:
        values = usage_data[key]
        if values:
            total = sum(values)
            count = len(values)
            min_val = min(values)
            max_val = max(values)
            avg = mean(values)
            # Calculate pricing cost scaling by tokens per million
            cost = total / 1_000_000 * pricing[key]
            grand_total_cost += cost

            print(f"{key}:")
            print(f" Total = {total}")
            print(f" Count = {count}")
            print(f" Min = {min_val}")
            print(f" Max = {max_val}")
            print(f" Mean = {avg:.2f}")
            print(f" Cost = ${cost:.2f}\n")
        else:
            print(f"{key}: No data found.\n")

    print("-" * 40)
    print(f"Grand Total Estimated Cost: ${grand_total_cost:.2f}")

    # -----------------------------------------------
    # Calculate caching savings (for input-related tokens)
    # Without caching, all tokens would have been charged at the standard input rate.
    #
    # Baseline cost (if no caching were used):
    #   = (input_tokens + cache_creation_input_tokens + cache_read_input_tokens)
    #     / 1_000_000 * base_input_rate
    #
    # Actual cost (with caching):
    #   = input_tokens * base_input_rate +
    #     cache_creation_input_tokens * (base_input_rate * 1.25) +
    #     cache_read_input_tokens * (base_input_rate * 0.1)
    #
    # Savings from caching is then the difference.
    sum_input = sum(usage_data["input_tokens"])
    sum_cache_creation = sum(usage_data["cache_creation_input_tokens"])
    sum_cache_read = sum(usage_data["cache_read_input_tokens"])

    baseline_input_cost = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]
    actual_input_cost = (sum_input) / 1_000_000 * pricing["input_tokens"] \
        + (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
        + (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
    caching_savings = baseline_input_cost - actual_input_cost

    print(f"Caching Savings (input-related tokens): ${caching_savings:.2f}")

    # -----------------------------------------------
    # Forecast future cost estimates based on the average tokens per job:
    # compute per-job averages for each token type, then for each forecast
    # job count apply the pricing, plus the baseline (no-caching) input cost
    # and the implied caching savings.
    if usage_data["input_tokens"]:
        job_count = len(usage_data["input_tokens"])
        avg_input_tokens = sum(usage_data["input_tokens"]) / job_count
        avg_cache_creation_tokens = sum(usage_data["cache_creation_input_tokens"]) / job_count
        avg_cache_read_tokens = sum(usage_data["cache_read_input_tokens"]) / job_count
        avg_output_tokens = sum(usage_data["output_tokens"]) / job_count

        print("\nAverage Tokens per Job:")
        print(f" input_tokens = {avg_input_tokens:.2f}")
        print(f" cache_creation_input_tokens = {avg_cache_creation_tokens:.2f}")
        print(f" cache_read_input_tokens = {avg_cache_read_tokens:.2f}")
        print(f" output_tokens = {avg_output_tokens:.2f}")

        forecast_jobs = [2000, 4000, 10000, 20000, 50000]
        print("\nForecasting Future Job Costs:")
        for jobs in forecast_jobs:
            # Forecast token usage for the job count by multiplying the per-job averages.
            forecast_input = avg_input_tokens * jobs
            forecast_cache_creation = avg_cache_creation_tokens * jobs
            forecast_cache_read = avg_cache_read_tokens * jobs
            forecast_output = avg_output_tokens * jobs

            # Forecast actual cost (with caching applied for input tokens):
            actual_input_cost_forecast = (forecast_input) / 1_000_000 * pricing["input_tokens"] \
                + (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
                + (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]

            # Without caching, all input-related tokens would be at base_input_rate:
            baseline_input_cost_forecast = (forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]

            caching_savings_forecast = baseline_input_cost_forecast - actual_input_cost_forecast

            forecast_output_cost = forecast_output / 1_000_000 * pricing["output_tokens"]
            total_forecast_cost = actual_input_cost_forecast + forecast_output_cost

            print(f"\nFor {jobs:,} jobs:")
            print(" Forecasted Token Usage:")
            print(f" input_tokens = {forecast_input:,.0f}")
            print(f" cache_creation_input_tokens = {forecast_cache_creation:,.0f}")
            print(f" cache_read_input_tokens = {forecast_cache_read:,.0f}")
            print(f" output_tokens = {forecast_output:,.0f}")
            print(" Estimated Costs:")
            print(f" Input Cost (with caching) = ${actual_input_cost_forecast:,.2f}")
            print(f" Output Cost = ${forecast_output_cost:,.2f}")
            print(f" Grand Total Cost = ${total_forecast_cost:,.2f}")
            print(f" Caching Savings (input) = ${caching_savings_forecast:,.2f}")
    else:
        print("No valid jobs to forecast future costs.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue