# reasoning-gym/examples/word_ladder/tests/test_generate_reasoning.py

import datetime
import json
from types import SimpleNamespace

import pytest

from examples.word_ladder import generate_reasoning

# We alias the functions and globals for easier usage in our tests.
submit_reasoning_batches = generate_reasoning.submit_reasoning_batches
_submit_single_batch = generate_reasoning._submit_single_batch
DEFAULT_INPUT_JSONL = generate_reasoning.DEFAULT_INPUT_JSONL
COMMON_UUID = generate_reasoning.COMMON_UUID
BATCH_SIZE = generate_reasoning.BATCH_SIZE
client = generate_reasoning.client


# Define a mock batch response class mimicking Anthropic's API response.
class MockBatchResponse:
    def __init__(self, batch_id="msgbatch_mock", processing_status="in_progress", fail=False):
        self.id = batch_id
        self.type = "message_batch"
        self.processing_status = processing_status
        # Make request_counts a SimpleNamespace so the counts can be read as attributes.
        self.request_counts = SimpleNamespace(processing=0, succeeded=0, errored=0, canceled=0, expired=0)
        self.ended_at = None
        # Use timezone-aware datetime objects so that isoformat() is available.
        self.created_at = datetime.datetime.now(datetime.timezone.utc)
        self.expires_at = self.created_at + datetime.timedelta(seconds=86400)
        self.cancel_initiated_at = None
        self.results_url = None
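

# Note: the real client.messages.batches.create presumably returns an
# anthropic.types.messages.MessageBatch carrying this same field set (id, type,
# processing_status, request_counts, created_at, expires_at, ...); the mock
# only reproduces the attributes that generate_reasoning reads and serialises,
# hence request_counts as a SimpleNamespace and timezone-aware datetimes.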


# Helper: Create a temporary system prompt file.
@pytest.fixture
def system_prompt_file(tmp_path, monkeypatch):
    prompt_text = "This is a system prompt."
    sys_file = tmp_path / "system_prompt.txt"
    sys_file.write_text(prompt_text, encoding="utf-8")
    # Monkeypatch the system prompt path.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_SYSTEM_PROMPT", str(sys_file))
    return sys_file


# Helper: Create necessary directories using a temporary location.
@pytest.fixture
def setup_directories(tmp_path, monkeypatch):
    # Create the output directory inside the temporary path.
    output_dir = tmp_path / "output"
    output_dir.mkdir(exist_ok=True)
    # Monkeypatch DEFAULT_OUTPUT_DIR to the temporary directory (a Path).
    monkeypatch.setattr(generate_reasoning, "DEFAULT_OUTPUT_DIR", output_dir)
    # Ensure we're working in the temporary directory.
    monkeypatch.chdir(tmp_path)
    return output_dir


# Helper: Create a temporary input JSONL file with given entries.
@pytest.fixture
def input_jsonl_file(tmp_path, setup_directories, monkeypatch):
    # Create the input file in the temporary directory.
    file_path = setup_directories / "word_ladder_examples.jsonl"
    entries = [
        {
            "question": "Transform 'A' to 'B'",
            "answer": "A,X,B",
            "reasoning": None,
            "metadata": {"start_word": "A", "end_word": "B", "word_length": 1, "chain_length": 3},
        },
        {
            "question": "Transform 'C' to 'D'",
            "answer": "C,Y,D",
            "reasoning": "Some reasoning",
            "metadata": {"start_word": "C", "end_word": "D", "word_length": 1, "chain_length": 3},
        },
        {
            "question": "Transform 'E' to 'F'",
            "answer": "E,Z,F",
            "reasoning": None,
            "metadata": {"start_word": "E", "end_word": "F", "word_length": 1, "chain_length": 3},
        },
    ]
    with file_path.open("w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")
    # Monkeypatch DEFAULT_INPUT_JSONL to point to our temporary test file.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(file_path))
    return file_path
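

# The fakes below stand in for client.messages.batches.create. Judging by the
# normalisation in test_submit_batches_success, the module under test may pass
# requests either as plain dicts or as Request-like objects; the dict form would
# look roughly like this (hypothetical example, field values inferred from the
# assertions in that test):
#     {"custom_id": "A_B_0",
#      "params": {"model": "claude-3-5-sonnet-20241022", "temperature": 0.5, ...}}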


# Test that submit_reasoning_batches builds a batch, skipping entries that already have reasoning.
def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_directories, monkeypatch):
    def fake_create(requests):
        for req in requests:
            # Requests may arrive as plain dicts or as objects with attributes;
            # normalise both shapes before asserting.
            if isinstance(req, dict):
                params = req.get("params", {})
                custom_id = req.get("custom_id")
            else:
                params = req.params
                custom_id = req.custom_id
            # params itself may likewise be a dict or an object.
            if isinstance(params, dict):
                model = params.get("model")
                temperature = params.get("temperature")
            else:
                model = params.model
                temperature = params.temperature
            assert model == "claude-3-5-sonnet-20241022", "Incorrect model version"
            assert temperature == 0.5, "Incorrect temperature"
            # The 'C' -> 'D' entry already has reasoning, so it must not be submitted.
            assert "C_D_" not in custom_id
        return MockBatchResponse(batch_id="msgbatch_test_success")

    monkeypatch.setattr(client.messages.batches, "create", fake_create)
    batch_metadata_prefix = "test_metadata"
    submit_reasoning_batches(input_path=str(input_jsonl_file), batch_metadata_prefix=batch_metadata_prefix)
    metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
    meta_file_path = setup_directories / metadata_filename
    assert meta_file_path.exists(), "Metadata file was not created as expected."
    with meta_file_path.open("r", encoding="utf-8") as f:
        lines = f.readlines()
    # Expect only the entries that did not already have reasoning
    # (2 of the 3 entries in our test input qualify).
    assert len(lines) > 0
    for line in lines:
        metadata = json.loads(line)
        api_response = metadata["api_response"]
        assert api_response["id"] == "msgbatch_test_success"
        assert api_response["processing_status"] == "in_progress"
        custom_ids = metadata["custom_ids"]
        assert len(custom_ids) == 2
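

# For reference: each line of the metadata JSONL written by _submit_single_batch
# appears to be a record of roughly this shape (only the fields asserted in
# these tests are guaranteed):
#     {"api_response": {"id": "msgbatch_...", "processing_status": "in_progress", ...},
#      "custom_ids": ["...", "..."]}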


# Test that _submit_single_batch retries once and eventually succeeds.
def test_retry_logic(system_prompt_file, setup_directories, monkeypatch):
    call_count = {"count": 0}

    def fake_create_retry(requests):
        # Fail on the first attempt only; succeed on the retry.
        if call_count["count"] == 0:
            call_count["count"] += 1
            raise Exception("Temporary failure")
        return MockBatchResponse(batch_id="msgbatch_retry_success")

    monkeypatch.setattr(client.messages.batches, "create", fake_create_retry)
    dummy_request = type("DummyRequest", (), {"custom_id": "dummy_1"})()
    batch_requests = [dummy_request]
    custom_ids = ["dummy_1"]
    _submit_single_batch(batch_requests, custom_ids, 0, "test_retry", "dummy_input.jsonl")
    metadata_filename = f"{COMMON_UUID}_test_retry.jsonl"
    meta_file_path = setup_directories / metadata_filename
    assert meta_file_path.exists(), "Retry metadata file was not created."
    with meta_file_path.open("r", encoding="utf-8") as f:
        metadata = json.loads(f.read())
    assert metadata["api_response"]["id"] == "msgbatch_retry_success"
    # The counter is only incremented on the failing attempt, so one failure
    # followed by a successful retry leaves it at 1.
    assert call_count["count"] == 1


# Test that when all attempts to submit a batch fail, the error is logged to the failed-batches file.
def test_failed_batch(system_prompt_file, setup_directories, monkeypatch):
    def fake_create_fail(requests):
        raise Exception("Permanent failure")

    monkeypatch.setattr(client.messages.batches, "create", fake_create_fail)
    dummy_request = type("DummyRequest", (), {"custom_id": "dummy_fail"})()
    batch_requests = [dummy_request]
    custom_ids = ["dummy_fail"]
    # No exception is expected to propagate; the failure should be logged instead.
    _submit_single_batch(batch_requests, custom_ids, 0, "test_failed", "dummy_input.jsonl")
    error_filename = f"{COMMON_UUID}_failed_batches.jsonl"
    error_file_path = setup_directories / error_filename
    assert error_file_path.exists(), "Failed batch log file was not created."
    with error_file_path.open("r", encoding="utf-8") as f:
        error_entry = json.loads(f.readline())
    assert error_entry["batch_number"] == 0
    assert "Permanent failure" in error_entry["error"]
    assert error_entry["batch_requests"] == ["dummy_fail"]


# Test batching behavior when multiple batches are needed.
def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
    test_batch_size = 2
    monkeypatch.setattr(generate_reasoning, "BATCH_SIZE", test_batch_size)
    # Create an input file with five qualifying entries (reasoning is None).
    input_file = setup_directories / "word_ladder_examples.jsonl"
    entries = [
        {
            "question": f"Transform word ladder {idx}",
            "answer": f"start,mid,end_{idx}",
            "reasoning": None,
            "metadata": {"start_word": f"start_{idx}", "end_word": f"end_{idx}"},
        }
        for idx in range(5)
    ]
    with input_file.open("w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")
    # Monkeypatch DEFAULT_INPUT_JSONL.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(input_file))
    batch_ids = []

    def fake_create(requests):
        new_id = f"msgbatch_batch_{len(batch_ids)}"
        batch_ids.append(new_id)
        return MockBatchResponse(batch_id=new_id)

    monkeypatch.setattr(client.messages.batches, "create", fake_create)
    batch_metadata_prefix = "test_multi"
    submit_reasoning_batches(input_path=str(input_file), batch_metadata_prefix=batch_metadata_prefix)
    metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
    meta_file_path = setup_directories / metadata_filename
    assert meta_file_path.exists(), "Multiple batch metadata file was not created."
    with meta_file_path.open("r", encoding="utf-8") as f:
        metadata_lines = f.readlines()
    # With 5 qualifying entries and a batch size of 2 we expect 3 batches.
    assert len(metadata_lines) == 3
    seen_custom_ids = []
    for line in metadata_lines:
        metadata = json.loads(line)
        seen_custom_ids.extend(metadata["custom_ids"])
        assert metadata["api_response"]["id"] in batch_ids
    assert len(seen_custom_ids) == 5
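

# A typical invocation for just this module (path assumes a repo-root checkout):
#     pytest examples/word_ladder/tests/test_generate_reasoning.py -q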