mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
254 lines
10 KiB
Python
import datetime
|
|
import json
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from examples.word_ladder import generate_reasoning
|
|
|
|
# We alias the functions and globals for easier usage in our tests.
# NOTE: these are snapshots taken at import time — monkeypatching an attribute
# on `generate_reasoning` later does NOT update these local aliases, only the
# module's own globals (which is what the code under test reads).
submit_reasoning_batches = generate_reasoning.submit_reasoning_batches
_submit_single_batch = generate_reasoning._submit_single_batch
DEFAULT_INPUT_JSONL = generate_reasoning.DEFAULT_INPUT_JSONL
COMMON_UUID = generate_reasoning.COMMON_UUID
BATCH_SIZE = generate_reasoning.BATCH_SIZE
client = generate_reasoning.client
|
|
|
|
|
|
# Define a mock batch response class mimicking Anthropic's API response.
class MockBatchResponse:
    """Minimal stand-in for an Anthropic Messages Batch API response.

    Only the attributes the code under test reads are modelled; everything
    else on the real response object is omitted.
    """

    def __init__(self, batch_id="msgbatch_mock", processing_status="in_progress", fail=False):
        # `fail` is kept for interface compatibility with existing callers;
        # it is not used by this mock.
        self.id = batch_id
        self.type = "message_batch"
        self.processing_status = processing_status
        # Make request_counts a SimpleNamespace object with the required attributes.
        self.request_counts = SimpleNamespace(processing=0, succeeded=0, errored=0, canceled=0, expired=0)
        self.ended_at = None
        # Use timezone-aware datetime objects so that isoformat() is available
        # and carries an offset. datetime.utcnow() is deprecated (it returns a
        # naive datetime); now(timezone.utc) is the modern aware equivalent.
        self.created_at = datetime.datetime.now(datetime.timezone.utc)
        self.expires_at = self.created_at + datetime.timedelta(seconds=86400)
        self.cancel_initiated_at = None
        self.results_url = None
|
|
|
|
|
|
# Helper: Create a temporary system prompt file.
@pytest.fixture
def system_prompt_file(tmp_path, monkeypatch):
    """Write a throwaway system-prompt file and point the module at it."""
    prompt_file = tmp_path / "system_prompt.txt"
    prompt_file.write_text("This is a system prompt.", encoding="utf-8")

    # Redirect the module-level default so the code under test reads our file.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_SYSTEM_PROMPT", str(prompt_file))
    return prompt_file
|
|
|
|
|
|
# Helper: Create necessary directories using a temporary location.
@pytest.fixture
def setup_directories(tmp_path, monkeypatch):
    """Provide a temp output dir, patch DEFAULT_OUTPUT_DIR, and chdir into tmp."""
    out_dir = tmp_path / "output"
    out_dir.mkdir(exist_ok=True)

    # The module writes its metadata files under DEFAULT_OUTPUT_DIR (a Path),
    # so point it at the temporary directory for the duration of the test.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_OUTPUT_DIR", out_dir)

    # Run relative-path operations from within the temporary directory.
    monkeypatch.chdir(tmp_path)
    return out_dir
|
|
|
|
|
|
# Helper: Create a temporary input JSONL file with given entries.
@pytest.fixture
def input_jsonl_file(tmp_path, setup_directories, monkeypatch):
    """Write three word-ladder records (one already reasoned) and patch the input path."""
    jsonl_path = setup_directories / "word_ladder_examples.jsonl"

    # (start, middle, end, existing_reasoning) — the middle entry already has
    # reasoning and should therefore be skipped by the batch builder.
    specs = [
        ("A", "X", "B", None),
        ("C", "Y", "D", "Some reasoning"),
        ("E", "Z", "F", None),
    ]
    records = [
        {
            "question": f"Transform '{start}' to '{end}'",
            "answer": f"{start},{mid},{end}",
            "reasoning": reasoning,
            "metadata": {"start_word": start, "end_word": end, "word_length": 1, "chain_length": 3},
        }
        for start, mid, end, reasoning in specs
    ]

    with jsonl_path.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(record) + "\n" for record in records)

    # Monkeypatch DEFAULT_INPUT_JSONL to point to our temporary test file.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(jsonl_path))
    return jsonl_path
|
|
|
|
|
|
# Test that submit_reasoning_batches builds a batch skipping entries with existing reasoning.
def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_directories, monkeypatch):
    def fake_create(requests):
        """Stand-in for client.messages.batches.create that validates each request."""
        for req in requests:
            # Requests may arrive as dicts or attribute objects; normalize once.
            # (The original duplicated the params/model/temperature extraction
            # in both branches — hoisting it removes the copy-paste.)
            if isinstance(req, dict):
                params = req.get("params", {})
                custom_id = req.get("custom_id")
            else:
                params = req.params
                custom_id = req.custom_id
            # params itself may also be a dict or an attribute object.
            if isinstance(params, dict):
                model = params.get("model")
                temperature = params.get("temperature")
            else:
                model = params.model
                temperature = params.temperature

            assert model == "claude-3-5-sonnet-20241022", "Incorrect model version"
            assert temperature == 0.5, "Incorrect temperature"
            # The C->D entry already had reasoning, so it must not be submitted.
            assert "C_D_" not in custom_id
        return MockBatchResponse(batch_id="msgbatch_test_success")

    monkeypatch.setattr(client.messages.batches, "create", fake_create)

    batch_metadata_prefix = "test_metadata"
    submit_reasoning_batches(input_path=str(input_jsonl_file), batch_metadata_prefix=batch_metadata_prefix)

    metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
    meta_file_path = setup_directories / metadata_filename
    assert meta_file_path.exists(), "Metadata file was not created as expected."

    with meta_file_path.open("r", encoding="utf-8") as f:
        lines = f.readlines()
    # Expecting only those entries that did not already have a reasoning.
    # (From our test input, 2 out of 3 entries qualify.)
    assert len(lines) > 0
    for line in lines:
        metadata = json.loads(line)
        api_response = metadata["api_response"]
        assert api_response["id"] == "msgbatch_test_success"
        assert api_response["processing_status"] == "in_progress"
        custom_ids = metadata["custom_ids"]
        assert len(custom_ids) == 2
|
|
|
|
|
|
# Test that _submit_single_batch retries once and eventually succeeds.
def test_retry_logic(system_prompt_file, setup_directories, monkeypatch):
    failures = {"count": 0}

    def flaky_create(requests):
        # Fail exactly once, then succeed on the retry.
        if failures["count"] == 0:
            failures["count"] += 1
            raise Exception("Temporary failure")
        return MockBatchResponse(batch_id="msgbatch_retry_success")

    monkeypatch.setattr(client.messages.batches, "create", flaky_create)

    stub_request = type("DummyRequest", (), {"custom_id": "dummy_1"})()
    _submit_single_batch([stub_request], ["dummy_1"], 0, "test_retry", "dummy_input.jsonl")

    meta_path = setup_directories / f"{COMMON_UUID}_test_retry.jsonl"
    assert meta_path.exists(), "Retry metadata file was not created."

    with meta_path.open("r", encoding="utf-8") as fh:
        record = json.loads(fh.read())
    assert record["api_response"]["id"] == "msgbatch_retry_success"

    # Exactly one failure was recorded before the successful retry.
    assert failures["count"] == 1
|
|
|
|
|
|
# Test that when all attempts to submit a batch fail, the error is logged to the failed file.
def test_failed_batch(system_prompt_file, setup_directories, monkeypatch):
    def always_fail(requests):
        raise Exception("Permanent failure")

    monkeypatch.setattr(client.messages.batches, "create", always_fail)

    stub_request = type("DummyRequest", (), {"custom_id": "dummy_fail"})()
    _submit_single_batch([stub_request], ["dummy_fail"], 0, "test_failed", "dummy_input.jsonl")

    error_path = setup_directories / f"{COMMON_UUID}_failed_batches.jsonl"
    assert error_path.exists(), "Failed batch log file was not created."

    with error_path.open("r", encoding="utf-8") as fh:
        entry = json.loads(fh.readline())
    assert entry["batch_number"] == 0
    assert "Permanent failure" in entry["error"]
    assert entry["batch_requests"] == ["dummy_fail"]
|
|
|
|
|
|
# Test batching behavior when multiple batches are needed.
def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
    # Shrink the batch size so 5 entries require 3 batches (2 + 2 + 1).
    monkeypatch.setattr(generate_reasoning, "BATCH_SIZE", 2)

    # Create input file with five qualifying (reasoning=None) entries.
    input_file = setup_directories / "word_ladder_examples.jsonl"
    with input_file.open("w", encoding="utf-8") as fh:
        for idx in range(5):
            record = {
                "question": f"Transform word ladder {idx}",
                "answer": f"start,mid,end_{idx}",
                "reasoning": None,
                "metadata": {"start_word": f"start_{idx}", "end_word": f"end_{idx}"},
            }
            fh.write(json.dumps(record) + "\n")

    # Monkeypatch DEFAULT_INPUT_JSONL.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(input_file))

    issued_ids = []

    def fake_create(requests):
        # Hand out a distinct batch id per call so we can match them later.
        batch_id = f"msgbatch_batch_{len(issued_ids)}"
        issued_ids.append(batch_id)
        return MockBatchResponse(batch_id=batch_id)

    monkeypatch.setattr(client.messages.batches, "create", fake_create)

    batch_metadata_prefix = "test_multi"
    submit_reasoning_batches(input_path=str(input_file), batch_metadata_prefix=batch_metadata_prefix)

    meta_path = setup_directories / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
    assert meta_path.exists(), "Multiple batch metadata file was not created."

    with meta_path.open("r", encoding="utf-8") as fh:
        metadata_lines = fh.readlines()
    # With 5 qualifying entries and a batch size of 2 we expect 3 batches.
    assert len(metadata_lines) == 3

    collected_ids = []
    for line in metadata_lines:
        record = json.loads(line)
        collected_ids.extend(record["custom_ids"])
        assert record["api_response"]["id"] in issued_ids
    assert len(collected_ids) == 5
|