reasoning-gym/examples/word_ladder/utils/generate_reasoning.py
Cavit Erginsoy aff0fecef4 lint
2025-02-03 11:35:30 +00:00

214 lines
9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
generate_reasoning.py Reads the JSONL file containing ladder examples,
creates batch requests of chain-of-thought prompts split into batches of 2,500,
calls Anthropic's Message Batches API for each batch, and writes separate batch metadata
files for later retrieval of the responses.
*** WARNING ***: Running large batches of requests via the Anthropic API (especially in generate_reasoning.py)
can incur significant costs in Anthropic credits. Please review and understand your API quota and budgeting
before running the API call. If you are testing or working with a demo dataset, adjust the batch size or dataset
size appropriately to avoid unexpected charges.
Using Anthropic's Message Batches API with caching enabled for system prompt.
In our informal testing, Sonnet was deemed best performance value.
You can swap out to another API, but this will need a rewrite to remove anthropic-specific code.
"""
import json
import os
import time
import uuid
from pathlib import Path
import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
from tqdm import tqdm
# Updated default output directory to use the parent directory.
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "output"
# Add default constants at the top with other constants
DEFAULT_INPUT_JSONL = "output/word_ladder_examples.jsonl"
DEFAULT_SYSTEM_PROMPT = Path(__file__).resolve().parent.parent / "system_prompt.txt"
BATCH_SIZE = 2500
COMMON_UUID = uuid.uuid4().hex[:8]
# Set up the Anthropic client (ensure the API key is set in the environment)
client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
def submit_reasoning_batches(
input_path: str = DEFAULT_INPUT_JSONL,
batch_metadata_prefix: str = "batch_metadata",
system_prompt_path: str = DEFAULT_SYSTEM_PROMPT,
) -> None:
"""
Reads the input JSONL file of word ladder examples, builds batch requests for any example that
does not have reasoning, splits them into groups of BATCH_SIZE, and submits each batch using
Anthropic's Message Batches API.
Args:
input_path: Path to input JSONL file
batch_metadata_prefix: Prefix for batch metadata files
system_prompt_path: Path to system prompt file
"""
# Create output directory if it doesn't exist
output_dir = DEFAULT_OUTPUT_DIR
output_dir.mkdir(exist_ok=True)
# Read the system prompt from file (used as a preamble for every request)
with open(system_prompt_path, "r", encoding="utf-8") as sys_file:
system_message = [
{
"type": "text",
"text": sys_file.read(),
"cache_control": {"type": "ephemeral"}, # Enable anthropic prompt caching
}
]
batch_requests = []
custom_ids = [] # List of custom_ids for the current batch
batch_num = 0
# Get the total number of lines in advance for tqdm progress bar.
total_lines = sum(1 for _ in open(input_path))
with open(input_path, "r", encoding="utf-8") as infile:
for idx, line in tqdm(enumerate(infile), desc="Preparing batch requests", total=total_lines):
data = json.loads(line)
# Skip example if 'reasoning' already exists.
if not data.get("reasoning"):
# Build a custom id. Here we use the row position and the start/end words:
metadata = data.get("metadata", {})
start = metadata.get("start_word", "unknown")
end = metadata.get("end_word", "unknown")
custom_id = f"{start}_{end}_{idx}"
custom_ids.append(custom_id)
# Build the prompt text exactly as before.
prompt = f"{data['question']}. The correct solution is {data['answer']}. "
# Build the request payload using Request and MessageCreateParamsNonStreaming.
request_payload = Request(
custom_id=custom_id,
params=MessageCreateParamsNonStreaming(
model="claude-3-5-sonnet-20241022", # Or choose the appropriate model version
max_tokens=8192,
temperature=0.5,
system=system_message,
messages=[{"role": "user", "content": prompt}],
),
)
# Instead of wrapping in SimpleNamespace, simply ensure custom_id is set.
if isinstance(request_payload, dict):
request_payload["custom_id"] = custom_id
else:
request_payload.custom_id = custom_id
batch_requests.append(request_payload)
# If we have reached our batch size limit, submit the current batch.
if len(batch_requests) >= BATCH_SIZE:
_submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
batch_num += 1
# Reset for the next batch
batch_requests = []
custom_ids = []
# Submit any remaining requests that didn't complete a full batch.
if batch_requests:
_submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path):
"""
Helper function to submit a single batch request, track the full API response,
and write out the corresponding metadata including the list of custom_ids.
"""
# Use the default output directory
output_dir = DEFAULT_OUTPUT_DIR
output_dir.mkdir(exist_ok=True)
def serialize_datetime(dt):
"""
Convert a datetime object to ISO formatted string.
If dt is None, returns None.
"""
if dt is None:
return None
iso_str = dt.isoformat() # e.g. "2024-08-20T18:37:24.100435+00:00"
if dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None:
iso_str = iso_str.replace("+00:00", "Z")
return iso_str
def extract_custom_id(req):
# Safely extract the custom_id attribute whether req is an object or a dict.
return req.custom_id if hasattr(req, "custom_id") else req.get("custom_id")
max_attempts = 2
attempt = 0
last_exception = None
message_batch = None
while attempt < max_attempts:
try:
print(f"Submitting batch {batch_num} with {len(batch_requests)} requests... (attempt {attempt+1})")
message_batch = client.messages.batches.create(requests=batch_requests)
time.sleep(1)
print(f"Batch {batch_num} submitted with ID: {message_batch.id}")
break # Success: exit the loop.
except Exception as e:
last_exception = e
attempt += 1
print(f"Error submitting batch {batch_num} on attempt {attempt}: {e}")
if attempt < max_attempts:
print("Retrying...")
time.sleep(1)
if message_batch is None:
error_filename = output_dir / f"{COMMON_UUID}_failed_batches.jsonl"
error_msg = (
f"{str(last_exception)} after {max_attempts} attempts"
if last_exception
else f"Failed after {max_attempts} attempts"
)
failed_info = {
"batch_number": batch_num,
"error": error_msg,
"batch_requests": [extract_custom_id(req) for req in batch_requests],
"input_file": input_path,
}
with open(error_filename, "a", encoding="utf-8") as error_file:
error_file.write(json.dumps(failed_info) + "\n")
print(f"Batch {batch_num} permanently failed. Logged to {error_filename}.")
return
# Build a dictionary of the expected response fields.
api_response = {
"id": message_batch.id,
"type": message_batch.type,
"processing_status": message_batch.processing_status,
"request_counts": vars(message_batch.request_counts),
"ended_at": serialize_datetime(message_batch.ended_at),
"created_at": serialize_datetime(message_batch.created_at),
"expires_at": serialize_datetime(message_batch.expires_at),
"cancel_initiated_at": serialize_datetime(message_batch.cancel_initiated_at),
"results_url": message_batch.results_url,
}
batch_metadata = {
"batch_id": message_batch.id,
"api_response": api_response,
"custom_ids": custom_ids,
"input_file": os.path.basename(input_path),
}
metadata_filename = output_dir / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
with open(metadata_filename, "a", encoding="utf-8") as meta_file:
meta_file.write(json.dumps(batch_metadata) + "\n")
print(f"Batch metadata for batch {batch_num} appended to {metadata_filename}.")
if __name__ == "__main__":
# When running this module directly, submit the reasoning batches.
submit_reasoning_batches()