Cavit Erginsoy 2025-02-03 11:35:30 +00:00
parent 1e27021e11
commit 6c564b3dd9
13 changed files with 305 additions and 317 deletions


@@ -129,4 +129,3 @@ The dataset generation parameters are centralized in `main.py` under the `config
## License
This project is licensed under the MIT License.


@@ -6,10 +6,10 @@ main.py Orchestrates the overall flow:
3. Upload the final dataset to HuggingFace Hub (if needed)
"""
-import uuid
import sys
+import uuid
from pathlib import Path
-from typing import Dict, Any
+from typing import Any, Dict
from examples.word_ladder.utils import create_word_ladders, generate_reasoning
@@ -40,17 +40,18 @@ def create_dataset(jsonl_path: Path, config: Dict[str, Any]) -> bool:
print(f"\nError: Failed to create dataset: {str(e)}")
return False
def main():
# Centralized configuration for the dataset
config = {
-'dataset_name': 'word_ladder',
-'dataset_config': {
-'min_word_length': 3,
-'max_word_length': 5,
-'min_chain_length':-1, # set to -1 for the shortest possible path
-'max_chain_length':10,
-'size': 100, # Generate a small-ish dataset for demonstration
-}
+"dataset_name": "word_ladder",
+"dataset_config": {
+"min_word_length": 3,
+"max_word_length": 3,
+"min_chain_length": -1, # set to -1 for the shortest possible path
+"max_chain_length": 7,
+"size": 2000, # Generate a small-ish dataset for demonstration
+},
}
# Generate a friendly unique identifier and compose the file path
@@ -64,21 +65,20 @@ def main():
print("Exiting due to dataset creation failure.")
sys.exit(1)
# Step 2: Generate reasoning
'''
try:
print("\nStep 2: Submitting reasoning batches for the dataset...")
generate_reasoning.submit_reasoning_batches(input_path=str(jsonl_path))
except Exception as e:
print(f"\nError: Failed to submit reasoning batches: {str(e)}")
sys.exit(1)
'''
# Step 3: Check Anthropic batch results
# Step 4: Upload to HuggingFace 🤗
print("\nComplete!")
if __name__ == "__main__":
main()
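For context, the dataset_config block above is passed straight through to reasoning_gym.create_dataset by the create_word_ladders helper shown further down. A minimal sketch of that call, assuming the repository's reasoning_gym package is importable and using a small size for illustration:

    import reasoning_gym

    config = {
        "dataset_name": "word_ladder",
        "dataset_config": {"min_word_length": 3, "max_word_length": 3, "min_chain_length": -1, "max_chain_length": 7, "size": 5},
    }
    # create_dataset looks up the dataset registered under "word_ladder" (WordLadderDataset) and instantiates it.
    dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
    for item in dataset:
        print(item["question"], "->", item["answer"])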


@@ -3,4 +3,3 @@
anthropic>=0.45.2 # Client library for interacting with Anthropic's API
tqdm>=4.67.1 # For progress bars


@@ -1,11 +1,12 @@
-import os
-import json
-import time
import datetime
-import pytest
+import json
+import os
+import time
from pathlib import Path
from types import SimpleNamespace
+import pytest
from examples.word_ladder import generate_reasoning
# We alias the functions and globals for easier usage in our tests.
@@ -16,6 +17,7 @@ COMMON_UUID = generate_reasoning.COMMON_UUID
BATCH_SIZE = generate_reasoning.BATCH_SIZE
client = generate_reasoning.client
# Define a mock batch response class mimicking Anthropic's API response.
class MockBatchResponse:
def __init__(self, batch_id="msgbatch_mock", processing_status="in_progress", fail=False):
@@ -23,13 +25,7 @@ class MockBatchResponse:
self.type = "message_batch"
self.processing_status = processing_status
# Make request_counts a SimpleNamespace object with the required attributes
-self.request_counts = SimpleNamespace(
-processing=0,
-succeeded=0,
-errored=0,
-canceled=0,
-expired=0
-)
+self.request_counts = SimpleNamespace(processing=0, succeeded=0, errored=0, canceled=0, expired=0)
self.ended_at = None
# Use datetime objects so that isoformat() is available
self.created_at = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
@@ -37,6 +33,7 @@ class MockBatchResponse:
self.cancel_initiated_at = None
self.results_url = None
# Helper: Create a temporary system prompt file.
@pytest.fixture
def system_prompt_file(tmp_path, monkeypatch):
@@ -48,6 +45,7 @@ def system_prompt_file(tmp_path, monkeypatch):
monkeypatch.setattr(generate_reasoning, "DEFAULT_SYSTEM_PROMPT", str(sys_file))
return sys_file
# Helper: Create necessary directories using a temporary location.
@pytest.fixture
def setup_directories(tmp_path, monkeypatch):
@@ -62,18 +60,31 @@ def setup_directories(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
return output_dir
# Helper: Create a temporary input JSONL file with given entries.
@pytest.fixture
def input_jsonl_file(tmp_path, setup_directories, monkeypatch):
# Create input file in temporary directory
file_path = setup_directories / "word_ladder_examples.jsonl"
entries = [
-{ "question": "Transform 'A' to 'B'", "answer": "A,X,B", "reasoning": None,
-"metadata": { "start_word": "A", "end_word": "B", "word_length": 1, "chain_length": 3 } },
-{ "question": "Transform 'C' to 'D'", "answer": "C,Y,D", "reasoning": "Some reasoning",
-"metadata": { "start_word": "C", "end_word": "D", "word_length": 1, "chain_length": 3 } },
-{ "question": "Transform 'E' to 'F'", "answer": "E,Z,F", "reasoning": None,
-"metadata": { "start_word": "E", "end_word": "F", "word_length": 1, "chain_length": 3 } }
+{
+"question": "Transform 'A' to 'B'",
+"answer": "A,X,B",
+"reasoning": None,
+"metadata": {"start_word": "A", "end_word": "B", "word_length": 1, "chain_length": 3},
+},
+{
+"question": "Transform 'C' to 'D'",
+"answer": "C,Y,D",
+"reasoning": "Some reasoning",
+"metadata": {"start_word": "C", "end_word": "D", "word_length": 1, "chain_length": 3},
+},
+{
+"question": "Transform 'E' to 'F'",
+"answer": "E,Z,F",
+"reasoning": None,
+"metadata": {"start_word": "E", "end_word": "F", "word_length": 1, "chain_length": 3},
+},
]
with file_path.open("w", encoding="utf-8") as f:
for entry in entries:
@@ -83,6 +94,7 @@ def input_jsonl_file(tmp_path, setup_directories, monkeypatch):
monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(file_path))
return file_path
# Test that submit_reasoning_batches builds a batch skipping entries with existing reasoning.
def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_directories, monkeypatch):
def fake_create(requests):
@@ -116,8 +128,7 @@ def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_dire
monkeypatch.setattr(client.messages.batches, "create", fake_create)
batch_metadata_prefix = "test_metadata"
-submit_reasoning_batches(input_path=str(input_jsonl_file),
-batch_metadata_prefix=batch_metadata_prefix)
+submit_reasoning_batches(input_path=str(input_jsonl_file), batch_metadata_prefix=batch_metadata_prefix)
metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
meta_file_path = setup_directories / metadata_filename
@ -136,9 +147,11 @@ def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_dire
custom_ids = metadata["custom_ids"]
assert len(custom_ids) == 2
# Test that _submit_single_batch retries once and eventually succeeds.
def test_retry_logic(system_prompt_file, setup_directories, monkeypatch):
call_count = {"count": 0}
def fake_create_retry(requests):
if call_count["count"] == 0:
call_count["count"] += 1
@@ -163,6 +176,7 @@ def test_retry_logic(system_prompt_file, setup_directories, monkeypatch):
assert call_count["count"] == 1
# Test that when all attempts to submit a batch fail, the error is logged to the failed file.
def test_failed_batch(system_prompt_file, setup_directories, monkeypatch):
def fake_create_fail(requests):
@@ -186,6 +200,7 @@ def test_failed_batch(system_prompt_file, setup_directories, monkeypatch):
assert "Permanent failure" in error_entry["error"]
assert error_entry["batch_requests"] == ["dummy_fail"]
# Test batching behavior when multiple batches are needed.
def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
test_batch_size = 2
@@ -198,7 +213,7 @@ def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
"question": f"Transform word ladder {idx}",
"answer": f"start,mid,end_{idx}",
"reasoning": None,
-"metadata": {"start_word": f"start_{idx}", "end_word": f"end_{idx}"}
+"metadata": {"start_word": f"start_{idx}", "end_word": f"end_{idx}"},
}
for idx in range(5)
]
@@ -211,6 +226,7 @@ def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(input_file))
batch_ids = []
def fake_create(requests):
new_id = f"msgbatch_batch_{len(batch_ids)}"
batch_ids.append(new_id)
@@ -219,8 +235,7 @@ def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
monkeypatch.setattr(client.messages.batches, "create", fake_create)
batch_metadata_prefix = "test_multi"
-submit_reasoning_batches(input_path=str(input_file),
-batch_metadata_prefix=batch_metadata_prefix)
+submit_reasoning_batches(input_path=str(input_file), batch_metadata_prefix=batch_metadata_prefix)
metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
meta_file_path = setup_directories / metadata_filename


@@ -11,6 +11,7 @@ from tqdm import tqdm
import reasoning_gym
def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
"""
Check for duplicate word pairs in a word ladder JSONL file.
@@ -27,12 +28,12 @@ def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
valid_entries = {}
duplicates_found = False
-with open(jsonl_path, 'r', encoding='utf-8') as f:
+with open(jsonl_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f):
data = json.loads(line)
-metadata = data['metadata']
-pair = (metadata['start_word'], metadata['end_word'])
-reverse_pair = (metadata['end_word'], metadata['start_word'])
+metadata = data["metadata"]
+pair = (metadata["start_word"], metadata["end_word"])
+reverse_pair = (metadata["end_word"], metadata["start_word"])
# Check both orientations of the pair
if pair in pairs_seen or reverse_pair in pairs_seen:
@@ -46,6 +47,7 @@ def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
return duplicates_found, valid_entries
def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
"""
Creates a word ladder dataset and writes each sample as a JSON line.
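As a quick illustration of how the (duplicates_found, valid_entries) pair returned by check_duplicates is meant to be consumed, here is a small self-contained sketch; the file path and sample rows are made up, and it assumes the helper keeps only the first orientation of a mirrored start/end pair:

    import json
    import tempfile
    from pathlib import Path

    from examples.word_ladder.utils import create_word_ladders  # import path as used in main.py

    rows = [
        {"question": "Transform 'COLD' to 'WARM'", "answer": "COLD,CORD,WORD,WARD,WARM", "reasoning": None,
         "metadata": {"start_word": "COLD", "end_word": "WARM", "word_length": 4, "chain_length": 5}},
        # Same pair, reversed orientation: should be flagged as a duplicate.
        {"question": "Transform 'WARM' to 'COLD'", "answer": "WARM,WARD,WORD,CORD,COLD", "reasoning": None,
         "metadata": {"start_word": "WARM", "end_word": "COLD", "word_length": 4, "chain_length": 5}},
    ]
    path = Path(tempfile.mkdtemp()) / "sample.jsonl"
    path.write_text("\n".join(json.dumps(r) for r in rows) + "\n", encoding="utf-8")

    duplicates_found, valid_entries = create_word_ladders.check_duplicates(str(path))
    print(duplicates_found)    # True: the reversed pair counts as a duplicate
    print(len(valid_entries))  # 1 expected: only the first orientation is kept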
@@ -66,22 +68,22 @@ def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> N
else:
jsonl_path = Path(jsonl_path)
-target_size = config['dataset_config']['size']
+target_size = config["dataset_config"]["size"]
current_size = 0
max_attempts = 3 # Limit total regeneration attempts
attempt = 0
# Initial generation
-dataset = reasoning_gym.create_dataset(config['dataset_name'], **config['dataset_config'])
-with open(jsonl_path, 'w', encoding='utf-8') as f:
+dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
+with open(jsonl_path, "w", encoding="utf-8") as f:
for item in tqdm(dataset, desc="Generating initial ladder examples"):
row = {
-'question': item['question'],
-'answer': item['answer'],
-'reasoning': None,
-'metadata': item.get('metadata', {})
+"question": item["question"],
+"answer": item["answer"],
+"reasoning": None,
+"metadata": item.get("metadata", {}),
}
-f.write(json.dumps(row) + '\n')
+f.write(json.dumps(row) + "\n")
while attempt < max_attempts:
# Check entire file for duplicates
@@ -98,25 +100,25 @@ def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> N
print(f"\nAttempt {attempt + 1}: Regenerating {needed} examples to replace duplicates/missing entries...")
# Generate additional examples
-config['dataset_config']['size'] = needed
-additional_dataset = reasoning_gym.create_dataset(config['dataset_name'], **config['dataset_config'])
+config["dataset_config"]["size"] = needed
+additional_dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
# Write all entries to a temporary file
-temp_path = jsonl_path.with_suffix('.tmp')
-with open(temp_path, 'w', encoding='utf-8') as f:
+temp_path = jsonl_path.with_suffix(".tmp")
+with open(temp_path, "w", encoding="utf-8") as f:
# Write existing valid entries
for data in valid_entries.values():
-f.write(json.dumps(data) + '\n')
+f.write(json.dumps(data) + "\n")
# Write new entries
for item in additional_dataset:
row = {
-'question': item['question'],
-'answer': item['answer'],
-'reasoning': None,
-'metadata': item.get('metadata', {})
+"question": item["question"],
+"answer": item["answer"],
+"reasoning": None,
+"metadata": item.get("metadata", {}),
}
-f.write(json.dumps(row) + '\n')
+f.write(json.dumps(row) + "\n")
# Replace original file with temporary file
temp_path.replace(jsonl_path)


@@ -14,17 +14,16 @@ In our informal testing, Sonnet was deemed best performance value.
You can swap out to another API, but this will need a rewrite to remove anthropic-specific code.
"""
-import os
import json
-import uuid
+import os
import time
+import uuid
from pathlib import Path
-from tqdm import tqdm
import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
+from tqdm import tqdm
# Updated default output directory to use the parent directory.
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "output"
@@ -36,12 +35,13 @@ BATCH_SIZE = 2500
COMMON_UUID = uuid.uuid4().hex[:8]
# Set up the Anthropic client (ensure the API key is set in the environment)
-client = anthropic.Anthropic(api_key=os.environ['ANTHROPIC_API_KEY'])
+client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
def submit_reasoning_batches(
input_path: str = DEFAULT_INPUT_JSONL,
batch_metadata_prefix: str = "batch_metadata",
-system_prompt_path: str = DEFAULT_SYSTEM_PROMPT
+system_prompt_path: str = DEFAULT_SYSTEM_PROMPT,
) -> None:
"""
Reads the input JSONL file of word ladder examples, builds batch requests for any example that
@@ -59,11 +59,13 @@ def submit_reasoning_batches(
# Read the system prompt from file (used as a preamble for every request)
with open(system_prompt_path, "r", encoding="utf-8") as sys_file:
-system_message = [{
+system_message = [
+{
"type": "text",
"text": sys_file.read(),
-"cache_control": {"type": "ephemeral"} # Enable anthropic prompt caching
-}]
+"cache_control": {"type": "ephemeral"}, # Enable anthropic prompt caching
+}
+]
batch_requests = []
custom_ids = [] # List of custom_ids for the current batch
batch_num = 0
@@ -71,12 +73,12 @@ def submit_reasoning_batches(
# Get the total number of lines in advance for tqdm progress bar.
total_lines = sum(1 for _ in open(input_path))
-with open(input_path, 'r', encoding="utf-8") as infile:
+with open(input_path, "r", encoding="utf-8") as infile:
for idx, line in tqdm(enumerate(infile), desc="Preparing batch requests", total=total_lines):
data = json.loads(line)
# Skip example if 'reasoning' already exists.
-if not data.get('reasoning'):
+if not data.get("reasoning"):
# Build a custom id. Here we use the row position and the start/end words:
metadata = data.get("metadata", {})
start = metadata.get("start_word", "unknown")
@@ -95,10 +97,8 @@ def submit_reasoning_batches(
max_tokens=8192,
temperature=0.5,
system=system_message,
-messages=[
-{"role": "user", "content": prompt}
-]
-)
+messages=[{"role": "user", "content": prompt}],
+),
)
# Instead of wrapping in SimpleNamespace, simply ensure custom_id is set.
if isinstance(request_payload, dict):
@@ -153,9 +153,7 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_p
while attempt < max_attempts:
try:
print(f"Submitting batch {batch_num} with {len(batch_requests)} requests... (attempt {attempt+1})")
-message_batch = client.messages.batches.create(
-requests=batch_requests
-)
+message_batch = client.messages.batches.create(requests=batch_requests)
time.sleep(1)
print(f"Batch {batch_num} submitted with ID: {message_batch.id}")
break # Success: exit the loop.
@@ -169,14 +167,18 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_p
if message_batch is None:
error_filename = output_dir / f"{COMMON_UUID}_failed_batches.jsonl"
-error_msg = f"{str(last_exception)} after {max_attempts} attempts" if last_exception else f"Failed after {max_attempts} attempts"
+error_msg = (
+f"{str(last_exception)} after {max_attempts} attempts"
+if last_exception
+else f"Failed after {max_attempts} attempts"
+)
failed_info = {
"batch_number": batch_num,
"error": error_msg,
"batch_requests": [extract_custom_id(req) for req in batch_requests],
"input_file": input_path,
}
-with open(error_filename, 'a', encoding='utf-8') as error_file:
+with open(error_filename, "a", encoding="utf-8") as error_file:
error_file.write(json.dumps(failed_info) + "\n")
print(f"Batch {batch_num} permanently failed. Logged to {error_filename}.")
return
@@ -201,11 +203,12 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_p
"input_file": os.path.basename(input_path),
}
metadata_filename = output_dir / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
-with open(metadata_filename, 'a', encoding='utf-8') as meta_file:
+with open(metadata_filename, "a", encoding="utf-8") as meta_file:
meta_file.write(json.dumps(batch_metadata) + "\n")
print(f"Batch metadata for batch {batch_num} appended to {metadata_filename}.")
if __name__ == "__main__":
# When running this module directly, submit the reasoning batches.
submit_reasoning_batches()
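Pieced together from the hunks above, each batch entry is a Request wrapping MessageCreateParamsNonStreaming, and a whole batch goes up via client.messages.batches.create. A trimmed, self-contained sketch of a single-request batch; the model id, prompt text, and custom_id value are illustrative placeholders, not taken from this commit:

    import anthropic
    from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
    from anthropic.types.messages.batch_create_params import Request

    client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

    # System prompt as a text block with ephemeral cache_control, as in submit_reasoning_batches.
    system_message = [{"type": "text", "text": "You explain word ladder solutions step by step.", "cache_control": {"type": "ephemeral"}}]

    request = Request(
        custom_id="0_COLD_WARM",  # illustrative id built from the row index plus start/end words
        params=MessageCreateParamsNonStreaming(
            model="claude-3-5-sonnet-latest",  # placeholder model id
            max_tokens=8192,
            temperature=0.5,
            system=system_message,
            messages=[{"role": "user", "content": "Transform 'COLD' to 'WARM'."}],
        ),
    )

    message_batch = client.messages.batches.create(requests=[request])
    print(message_batch.id, message_batch.processing_status)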


@@ -17,17 +17,14 @@ Usage:
python usage_stats.py path/to/msgbatch_01X9LgZNVkLFhzrrBd9LNgWb_results.jsonl
"""
-import json
import argparse
+import json
from statistics import mean
def main():
-parser = argparse.ArgumentParser(
-description="Compute usage token statistics from a JSONL file."
-)
-parser.add_argument(
-"file", help="Path to the JSONL file containing usage token data."
-)
+parser = argparse.ArgumentParser(description="Compute usage token statistics from a JSONL file.")
+parser.add_argument("file", help="Path to the JSONL file containing usage token data.")
args = parser.parse_args()
# Usage token fields that we want to track
@@ -129,9 +126,11 @@ def main():
sum_cache_read = sum(usage_data["cache_read_input_tokens"])
baseline_input_cost = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]
-actual_input_cost = (sum_input) / 1_000_000 * pricing["input_tokens"] \
-+ (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
-+ (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+actual_input_cost = (
+(sum_input) / 1_000_000 * pricing["input_tokens"]
++ (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
++ (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+)
caching_savings = baseline_input_cost - actual_input_cost
print(f"Caching Savings (input-related tokens): ${caching_savings:.2f}")
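To make the savings calculation above concrete, here is a small worked example; the token counts and per-million-token prices are assumed for illustration (the script's real pricing table is defined elsewhere in this file and is not shown in this hunk):

    pricing = {"input_tokens": 3.00, "cache_creation_input_tokens": 3.75, "cache_read_input_tokens": 0.30}  # assumed $/MTok
    sum_input, sum_cache_creation, sum_cache_read = 200_000, 1_000_000, 9_000_000

    # Baseline: every input-related token billed at the base input rate -> 10.2M tokens * $3.00 = $30.60
    baseline_input_cost = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]

    # Actual: cache writes cost a little more, cache reads cost far less -> $0.60 + $3.75 + $2.70 = $7.05
    actual_input_cost = (
        sum_input / 1_000_000 * pricing["input_tokens"]
        + sum_cache_creation / 1_000_000 * pricing["cache_creation_input_tokens"]
        + sum_cache_read / 1_000_000 * pricing["cache_read_input_tokens"]
    )

    print(f"Caching Savings (input-related tokens): ${baseline_input_cost - actual_input_cost:.2f}")  # $23.55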
@@ -172,12 +171,16 @@ def main():
forecast_output = avg_output_tokens * jobs
# Forecast actual cost (with caching applied for input tokens):
-actual_input_cost_forecast = (forecast_input) / 1_000_000 * pricing["input_tokens"] \
-+ (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
-+ (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+actual_input_cost_forecast = (
+(forecast_input) / 1_000_000 * pricing["input_tokens"]
++ (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
++ (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+)
# Without caching, all input-related tokens would be at base_input_rate:
-baseline_input_cost_forecast = (forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
+baseline_input_cost_forecast = (
+(forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
+)
caching_savings_forecast = baseline_input_cost_forecast - actual_input_cost_forecast
@@ -198,5 +201,6 @@ def main():
else:
print("No valid jobs to forecast future costs.")
if __name__ == "__main__":
main()


@@ -52,8 +52,10 @@ class WordLadderConfig:
return 3 <= length <= self.max_chain_length
# Otherwise check against both min and max
-return (self.min_chain_length <= length <=
-(self.max_chain_length if self.max_chain_length != -1 else float('inf')))
+return (
+self.min_chain_length <= length <= (self.max_chain_length if self.max_chain_length != -1 else float("inf"))
+)
class WordLadderDataset(ProceduralDataset):
"""Generates word ladder transformation tasks"""
@@ -65,8 +67,7 @@ class WordLadderDataset(ProceduralDataset):
# Load words from CSV
self.word_sets = self._load_words_from_csv(
-min_length=self.config.min_word_length,
-max_length=self.config.max_word_length
+min_length=self.config.min_word_length, max_length=self.config.max_word_length
)
# Precompute word graphs for all lengths
@@ -76,7 +77,6 @@ class WordLadderDataset(ProceduralDataset):
config.validate()
super().__init__(config=config, seed=config.seed, size=config.size)
@classmethod
def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> Dict[int, Set[str]]:
"""Load words from CSV file organized by length"""
@@ -99,8 +99,8 @@ class WordLadderDataset(ProceduralDataset):
for row in reader:
# Process each word length column using config range
for length in range(min_length, max_length + 1):
-col_name = f'{length}_letter'
-word = row.get(col_name, '')
+col_name = f"{length}_letter"
+word = row.get(col_name, "")
if not word: # Skip empty entries
continue
@@ -126,7 +126,7 @@ class WordLadderDataset(ProceduralDataset):
# Fall back to computing neighbors directly for custom word sets
neighbors = set()
for i in range(len(word)):
-for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
+for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
neighbor = word[:i] + c + word[i + 1 :]
if neighbor != word and neighbor in word_set:
neighbors.add(neighbor)
@@ -146,7 +146,7 @@ class WordLadderDataset(ProceduralDataset):
for word in word_set:
neighbors = set()
for i in range(word_length):
-for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
+for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
neighbor = word[:i] + c + word[i + 1 :]
if neighbor != word and neighbor in word_set:
neighbors.add(neighbor)
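The nested loops above boil down to: for each position in the word, substitute every letter and keep candidates that exist in the word set. A standalone sketch of the same idea with a toy word set:

    def neighbors_of(word: str, word_set: set[str]) -> set[str]:
        # Words reachable by changing exactly one letter.
        found = set()
        for i in range(len(word)):
            for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
                candidate = word[:i] + c + word[i + 1:]
                if candidate != word and candidate in word_set:
                    found.add(candidate)
        return found

    words = {"COLD", "CORD", "CARD", "WARD", "WARM", "WORD"}
    print(neighbors_of("CORD", words))  # {'CARD', 'COLD', 'WORD'} (set order may vary)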
@@ -220,4 +220,5 @@ class WordLadderDataset(ProceduralDataset):
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
}
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)
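The parenthesized return in is_valid_path_length near the top of this file's diff encodes the sentinel convention used by these configs: min_chain_length=-1 means "shortest possible path" and max_chain_length=-1 means "no upper bound". A quick illustration of the mixed mode, mirroring the assertions in the word ladder test file below:

    from reasoning_gym.algorithmic.word_ladder import WordLadderConfig

    cfg = WordLadderConfig(min_chain_length=-1, max_chain_length=5)
    assert cfg.is_valid_path_length(3)      # minimum sensible ladder length
    assert cfg.is_valid_path_length(5)      # at the cap
    assert not cfg.is_valid_path_length(6)  # beyond the cap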


@@ -1,6 +1,7 @@
-import pytest
-from random import Random
import time
+from random import Random
+import pytest
from reasoning_gym.algorithmic.word_ladder import WordLadderConfig, WordLadderDataset
@@ -70,7 +71,7 @@ def test_word_ladder_dataset_unique_pairs():
item = dataset[i]
pair = (
min(item["metadata"]["start_word"], item["metadata"]["end_word"]),
-max(item["metadata"]["start_word"], item["metadata"]["end_word"])
+max(item["metadata"]["start_word"], item["metadata"]["end_word"]),
)
assert pair not in seen_pairs, f"Duplicate pair found: {pair}"
seen_pairs.add(pair)
@@ -80,23 +81,13 @@ def test_word_ladder_dataset_items():
"""Test basic properties of generated items"""
# Test with specific chain length constraints
config1 = WordLadderConfig(
-min_word_length=3,
-max_word_length=5,
-min_chain_length=3,
-max_chain_length=5,
-size=10,
-seed=42
+min_word_length=3, max_word_length=5, min_chain_length=3, max_chain_length=5, size=10, seed=42
)
dataset1 = WordLadderDataset(config1)
# Test with shortest path mode
config2 = WordLadderConfig(
-min_word_length=3,
-max_word_length=5,
-min_chain_length=-1,
-max_chain_length=-1,
-size=10,
-seed=42
+min_word_length=3, max_word_length=5, min_chain_length=-1, max_chain_length=-1, size=10, seed=42
)
dataset2 = WordLadderDataset(config2)
@@ -171,7 +162,7 @@ def test_word_ladder_path_finding():
min_chain_length=-1, # Shortest path mode
max_chain_length=-1,
size=10,
-seed=42
+seed=42,
)
dataset = WordLadderDataset(config)
@@ -194,10 +185,7 @@ def test_word_ladder_path_finding():
def test_word_ladder_csv_loading():
"""Test word loading from CSV"""
-config = WordLadderConfig(
-min_word_length=3,
-max_word_length=5
-)
+config = WordLadderConfig(min_word_length=3, max_word_length=5)
dataset = WordLadderDataset(config)
# Verify word sets for each length
@@ -219,12 +207,7 @@ def test_word_ladder_csv_loading():
def test_word_ladder_pair_generation():
"""Test word pair generation logic"""
-config = WordLadderConfig(
-min_word_length=4,
-max_word_length=4,
-size=10,
-seed=42
-)
+config = WordLadderConfig(min_word_length=4, max_word_length=4, size=10, seed=42)
dataset = WordLadderDataset(config)
# Test pair generation
@@ -269,10 +252,7 @@ def test_word_graph_caching():
def test_word_ladder_path_validation():
"""Test path length validation logic"""
-config = WordLadderConfig(
-min_chain_length=4,
-max_chain_length=6
-)
+config = WordLadderConfig(min_chain_length=4, max_chain_length=6)
# Test specific length mode
assert config.is_valid_path_length(4) # Min length
@@ -282,20 +262,14 @@ def test_word_ladder_path_validation():
assert not config.is_valid_path_length(7) # Too long
# Test shortest path mode
-config_shortest = WordLadderConfig(
-min_chain_length=-1,
-max_chain_length=-1
-)
+config_shortest = WordLadderConfig(min_chain_length=-1, max_chain_length=-1)
assert config_shortest.is_valid_path_length(3)
assert config_shortest.is_valid_path_length(4)
assert config_shortest.is_valid_path_length(10)
assert not config_shortest.is_valid_path_length(2)
# Test mixed mode (shortest with max limit)
-config_mixed = WordLadderConfig(
-min_chain_length=-1,
-max_chain_length=5
-)
+config_mixed = WordLadderConfig(min_chain_length=-1, max_chain_length=5)
assert config_mixed.is_valid_path_length(3)
assert config_mixed.is_valid_path_length(4)
assert config_mixed.is_valid_path_length(5)
@@ -305,12 +279,7 @@ def test_word_ladder_path_validation():
def test_word_ladder_solution_optimality():
"""Test that generated solutions are optimal when min_chain_length=-1"""
config = WordLadderConfig(
-min_word_length=4,
-max_word_length=4,
-min_chain_length=-1,
-max_chain_length=-1,
-size=20,
-seed=42
+min_word_length=4, max_word_length=4, min_chain_length=-1, max_chain_length=-1, size=20, seed=42
)
dataset = WordLadderDataset(config)
@@ -325,6 +294,7 @@ def test_word_ladder_solution_optimality():
# Build graph and use BFS to find shortest path
from collections import deque
queue = deque([(start_word, [start_word])])
visited = {start_word}
shortest_path = None
@@ -341,8 +311,9 @@ def test_word_ladder_solution_optimality():
queue.append((neighbor, path + [neighbor]))
assert shortest_path is not None, f"No path found between {start_word} and {end_word}"
-assert len(solution_chain) == len(shortest_path), \
-f"Solution {solution_chain} is not optimal. Shortest path: {shortest_path}"
+assert len(solution_chain) == len(
+shortest_path
+), f"Solution {solution_chain} is not optimal. Shortest path: {shortest_path}"
def test_word_ladder_performance():
@@ -371,13 +342,7 @@ def test_word_ladder_edge_cases():
assert len(dataset) == 1
# Test with same start/end word length but maximum distance
-config = WordLadderConfig(
-min_word_length=4,
-max_word_length=4,
-min_chain_length=-1,
-max_chain_length=-1,
-size=10
-)
+config = WordLadderConfig(min_word_length=4, max_word_length=4, min_chain_length=-1, max_chain_length=-1, size=10)
dataset = WordLadderDataset(config)
# Find the pair with longest solution