Mirror of https://github.com/open-thought/reasoning-gym.git

Commit 6c564b3dd9 "lint" (parent 1e27021e11)
13 changed files with 305 additions and 317 deletions

@@ -129,4 +129,3 @@ The dataset generation parameters are centralized in `main.py` under the `config`
 ## License

 This project is licensed under the MIT License.
-

@@ -6,10 +6,10 @@ main.py – Orchestrates the overall flow:
 3. Upload the final dataset to HuggingFace Hub (if needed)
 """

-import uuid
 import sys
+import uuid
 from pathlib import Path
-from typing import Dict, Any
+from typing import Any, Dict

 from examples.word_ladder.utils import create_word_ladders, generate_reasoning

@@ -40,17 +40,18 @@ def create_dataset(jsonl_path: Path, config: Dict[str, Any]) -> bool:
         print(f"\nError: Failed to create dataset: {str(e)}")
         return False

+
 def main():
     # Centralized configuration for the dataset
     config = {
-        'dataset_name': 'word_ladder',
-        'dataset_config': {
-            'min_word_length': 3,
-            'max_word_length': 5,
-            'min_chain_length':-1, # set to -1 for the shortest possible path
-            'max_chain_length':10,
-            'size': 100, # Generate a small-ish dataset for demonstration
-        }
+        "dataset_name": "word_ladder",
+        "dataset_config": {
+            "min_word_length": 3,
+            "max_word_length": 3,
+            "min_chain_length": -1,  # set to -1 for the shortest possible path
+            "max_chain_length": 7,
+            "size": 2000,  # Generate a small-ish dataset for demonstration
+        },
     }

     # Generate a friendly unique identifier and compose the file path

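Note: the `config` dict above is what `create_word_ladder_dataset` in `utils.py` (further down this diff) unpacks into `reasoning_gym.create_dataset`. A minimal sketch of that flow under the same config shape; the output filename "demo.jsonl" is hypothetical:

    import json

    import reasoning_gym

    # Same shape as the config above; chain-length keys omitted to rely on defaults.
    config = {
        "dataset_name": "word_ladder",
        "dataset_config": {"min_word_length": 3, "max_word_length": 3, "size": 5},
    }
    dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
    with open("demo.jsonl", "w", encoding="utf-8") as f:
        for item in dataset:
            # Rows start with reasoning=None; a later batch run fills it in.
            row = {"question": item["question"], "answer": item["answer"], "reasoning": None}
            f.write(json.dumps(row) + "\n")
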
@@ -64,21 +65,20 @@ def main():
         print("Exiting due to dataset creation failure.")
         sys.exit(1)

+
     # Step 2: Generate reasoning
-    '''
     try:
         print("\nStep 2: Submitting reasoning batches for the dataset...")
         generate_reasoning.submit_reasoning_batches(input_path=str(jsonl_path))
     except Exception as e:
         print(f"\nError: Failed to submit reasoning batches: {str(e)}")
         sys.exit(1)
-    '''

     # Step 3: Check Anthropic batch results
     # Step 4: Upload to HuggingFace 🤗

     print("\nComplete!")


 if __name__ == "__main__":
     main()

@@ -3,4 +3,3 @@

 anthropic>=0.45.2 # Client library for interacting with Anthropic's API
 tqdm>=4.67.1 # For progress bars
-

@@ -1,11 +1,12 @@
-import os
-import json
-import time
 import datetime
-import pytest
+import json
+import os
+import time
 from pathlib import Path
 from types import SimpleNamespace

+import pytest
+
 from examples.word_ladder import generate_reasoning

 # We alias the functions and globals for easier usage in our tests.

@@ -16,6 +17,7 @@ COMMON_UUID = generate_reasoning.COMMON_UUID
 BATCH_SIZE = generate_reasoning.BATCH_SIZE
 client = generate_reasoning.client

+
 # Define a mock batch response class mimicking Anthropic's API response.
 class MockBatchResponse:
     def __init__(self, batch_id="msgbatch_mock", processing_status="in_progress", fail=False):

@@ -23,13 +25,7 @@ class MockBatchResponse:
         self.type = "message_batch"
         self.processing_status = processing_status
         # Make request_counts a SimpleNamespace object with the required attributes
-        self.request_counts = SimpleNamespace(
-            processing=0,
-            succeeded=0,
-            errored=0,
-            canceled=0,
-            expired=0
-        )
+        self.request_counts = SimpleNamespace(processing=0, succeeded=0, errored=0, canceled=0, expired=0)
         self.ended_at = None
         # Use datetime objects so that isoformat() is available
         self.created_at = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)

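Note: `SimpleNamespace` is used here because the real SDK response exposes attributes rather than dict keys, so the collapsed one-liner above is behaviorally identical to the old multi-line form. A tiny sketch:

    from types import SimpleNamespace

    # Attribute access mirrors the SDK's response shape (resp.request_counts.succeeded).
    counts = SimpleNamespace(processing=0, succeeded=2, errored=0, canceled=0, expired=0)
    assert counts.succeeded == 2
    assert counts.errored == 0
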
@@ -37,6 +33,7 @@ class MockBatchResponse:
         self.cancel_initiated_at = None
         self.results_url = None

+
 # Helper: Create a temporary system prompt file.
 @pytest.fixture
 def system_prompt_file(tmp_path, monkeypatch):

@@ -48,6 +45,7 @@ def system_prompt_file(tmp_path, monkeypatch):
     monkeypatch.setattr(generate_reasoning, "DEFAULT_SYSTEM_PROMPT", str(sys_file))
     return sys_file

+
 # Helper: Create necessary directories using a temporary location.
 @pytest.fixture
 def setup_directories(tmp_path, monkeypatch):

@@ -62,18 +60,31 @@ def setup_directories(tmp_path, monkeypatch):
     monkeypatch.chdir(tmp_path)
     return output_dir

+
 # Helper: Create a temporary input JSONL file with given entries.
 @pytest.fixture
 def input_jsonl_file(tmp_path, setup_directories, monkeypatch):
     # Create input file in temporary directory
     file_path = setup_directories / "word_ladder_examples.jsonl"
     entries = [
-        { "question": "Transform 'A' to 'B'", "answer": "A,X,B", "reasoning": None,
-          "metadata": { "start_word": "A", "end_word": "B", "word_length": 1, "chain_length": 3 } },
-        { "question": "Transform 'C' to 'D'", "answer": "C,Y,D", "reasoning": "Some reasoning",
-          "metadata": { "start_word": "C", "end_word": "D", "word_length": 1, "chain_length": 3 } },
-        { "question": "Transform 'E' to 'F'", "answer": "E,Z,F", "reasoning": None,
-          "metadata": { "start_word": "E", "end_word": "F", "word_length": 1, "chain_length": 3 } }
+        {
+            "question": "Transform 'A' to 'B'",
+            "answer": "A,X,B",
+            "reasoning": None,
+            "metadata": {"start_word": "A", "end_word": "B", "word_length": 1, "chain_length": 3},
+        },
+        {
+            "question": "Transform 'C' to 'D'",
+            "answer": "C,Y,D",
+            "reasoning": "Some reasoning",
+            "metadata": {"start_word": "C", "end_word": "D", "word_length": 1, "chain_length": 3},
+        },
+        {
+            "question": "Transform 'E' to 'F'",
+            "answer": "E,Z,F",
+            "reasoning": None,
+            "metadata": {"start_word": "E", "end_word": "F", "word_length": 1, "chain_length": 3},
+        },
     ]
     with file_path.open("w", encoding="utf-8") as f:
         for entry in entries:

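Note: every row in this fixture follows the question/answer/reasoning/metadata schema used throughout the commit. A small sketch of the filter the submit step applies (rows whose `reasoning` is still empty get batched; on this fixture it keeps 2 of 3 rows, which matches the `len(custom_ids) == 2` assertion below):

    import json

    def pending_rows(jsonl_path: str) -> list[dict]:
        """Return rows that still need reasoning (reasoning is None or empty)."""
        pending = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                row = json.loads(line)
                if not row.get("reasoning"):
                    pending.append(row)
        return pending
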
@@ -83,6 +94,7 @@ def input_jsonl_file(tmp_path, setup_directories, monkeypatch):
     monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(file_path))
     return file_path

+
 # Test that submit_reasoning_batches builds a batch skipping entries with existing reasoning.
 def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_directories, monkeypatch):
     def fake_create(requests):

@@ -116,8 +128,7 @@ def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_directories, monkeypatch):
     monkeypatch.setattr(client.messages.batches, "create", fake_create)

     batch_metadata_prefix = "test_metadata"
-    submit_reasoning_batches(input_path=str(input_jsonl_file),
-                             batch_metadata_prefix=batch_metadata_prefix)
+    submit_reasoning_batches(input_path=str(input_jsonl_file), batch_metadata_prefix=batch_metadata_prefix)

     metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
     meta_file_path = setup_directories / metadata_filename

@@ -136,9 +147,11 @@ def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_directories, monkeypatch):
     custom_ids = metadata["custom_ids"]
     assert len(custom_ids) == 2

+

 # Test that _submit_single_batch retries once and eventually succeeds.
 def test_retry_logic(system_prompt_file, setup_directories, monkeypatch):
     call_count = {"count": 0}
+
     def fake_create_retry(requests):
         if call_count["count"] == 0:
             call_count["count"] += 1

@@ -163,6 +176,7 @@ def test_retry_logic(system_prompt_file, setup_directories, monkeypatch):

     assert call_count["count"] == 1

+
 # Test that when all attempts to submit a batch fail, the error is logged to the failed file.
 def test_failed_batch(system_prompt_file, setup_directories, monkeypatch):
     def fake_create_fail(requests):

@@ -186,6 +200,7 @@ def test_failed_batch(system_prompt_file, setup_directories, monkeypatch):
     assert "Permanent failure" in error_entry["error"]
     assert error_entry["batch_requests"] == ["dummy_fail"]

+
 # Test batching behavior when multiple batches are needed.
 def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
     test_batch_size = 2

@@ -198,7 +213,7 @@ def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
             "question": f"Transform word ladder {idx}",
             "answer": f"start,mid,end_{idx}",
             "reasoning": None,
-            "metadata": {"start_word": f"start_{idx}", "end_word": f"end_{idx}"}
+            "metadata": {"start_word": f"start_{idx}", "end_word": f"end_{idx}"},
         }
         for idx in range(5)
     ]

@@ -211,6 +226,7 @@ def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
     monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(input_file))
+
     batch_ids = []

     def fake_create(requests):
         new_id = f"msgbatch_batch_{len(batch_ids)}"
         batch_ids.append(new_id)

@@ -219,8 +235,7 @@ def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
     monkeypatch.setattr(client.messages.batches, "create", fake_create)

     batch_metadata_prefix = "test_multi"
-    submit_reasoning_batches(input_path=str(input_file),
-                             batch_metadata_prefix=batch_metadata_prefix)
+    submit_reasoning_batches(input_path=str(input_file), batch_metadata_prefix=batch_metadata_prefix)

     metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
     meta_file_path = setup_directories / metadata_filename

@@ -11,6 +11,7 @@ from tqdm import tqdm

 import reasoning_gym

+
 def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
     """
     Check for duplicate word pairs in a word ladder JSONL file.

@@ -27,12 +28,12 @@ def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
     valid_entries = {}
     duplicates_found = False

-    with open(jsonl_path, 'r', encoding='utf-8') as f:
+    with open(jsonl_path, "r", encoding="utf-8") as f:
         for line_num, line in enumerate(f):
             data = json.loads(line)
-            metadata = data['metadata']
-            pair = (metadata['start_word'], metadata['end_word'])
-            reverse_pair = (metadata['end_word'], metadata['start_word'])
+            metadata = data["metadata"]
+            pair = (metadata["start_word"], metadata["end_word"])
+            reverse_pair = (metadata["end_word"], metadata["start_word"])

             # Check both orientations of the pair
             if pair in pairs_seen or reverse_pair in pairs_seen:

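Note: the check above treats a ladder and its reverse as the same entry. A sketch of an equivalent order-insensitive key (not the repo's code, just the same idea in one membership test):

    pairs_seen: set[frozenset[str]] = set()

    def is_duplicate(start_word: str, end_word: str) -> bool:
        # frozenset({"CAT", "DOG"}) == frozenset({"DOG", "CAT"}), so a single
        # lookup covers both orientations of the pair.
        key = frozenset((start_word, end_word))
        if key in pairs_seen:
            return True
        pairs_seen.add(key)
        return False
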
@@ -46,6 +47,7 @@ def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:

     return duplicates_found, valid_entries

+
 def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
     """
     Creates a word ladder dataset and writes each sample as a JSON line.

@@ -66,22 +68,22 @@ def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
     else:
         jsonl_path = Path(jsonl_path)

-    target_size = config['dataset_config']['size']
+    target_size = config["dataset_config"]["size"]
     current_size = 0
     max_attempts = 3  # Limit total regeneration attempts
     attempt = 0

     # Initial generation
-    dataset = reasoning_gym.create_dataset(config['dataset_name'], **config['dataset_config'])
-    with open(jsonl_path, 'w', encoding='utf-8') as f:
+    dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
+    with open(jsonl_path, "w", encoding="utf-8") as f:
         for item in tqdm(dataset, desc="Generating initial ladder examples"):
             row = {
-                'question': item['question'],
-                'answer': item['answer'],
-                'reasoning': None,
-                'metadata': item.get('metadata', {})
+                "question": item["question"],
+                "answer": item["answer"],
+                "reasoning": None,
+                "metadata": item.get("metadata", {}),
             }
-            f.write(json.dumps(row) + '\n')
+            f.write(json.dumps(row) + "\n")

     while attempt < max_attempts:
         # Check entire file for duplicates

@@ -98,25 +100,25 @@ def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
         print(f"\nAttempt {attempt + 1}: Regenerating {needed} examples to replace duplicates/missing entries...")

         # Generate additional examples
-        config['dataset_config']['size'] = needed
-        additional_dataset = reasoning_gym.create_dataset(config['dataset_name'], **config['dataset_config'])
+        config["dataset_config"]["size"] = needed
+        additional_dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])

         # Write all entries to a temporary file
-        temp_path = jsonl_path.with_suffix('.tmp')
-        with open(temp_path, 'w', encoding='utf-8') as f:
+        temp_path = jsonl_path.with_suffix(".tmp")
+        with open(temp_path, "w", encoding="utf-8") as f:
             # Write existing valid entries
             for data in valid_entries.values():
-                f.write(json.dumps(data) + '\n')
+                f.write(json.dumps(data) + "\n")

             # Write new entries
             for item in additional_dataset:
                 row = {
-                    'question': item['question'],
-                    'answer': item['answer'],
-                    'reasoning': None,
-                    'metadata': item.get('metadata', {})
+                    "question": item["question"],
+                    "answer": item["answer"],
+                    "reasoning": None,
+                    "metadata": item.get("metadata", {}),
                 }
-                f.write(json.dumps(row) + '\n')
+                f.write(json.dumps(row) + "\n")

         # Replace original file with temporary file
         temp_path.replace(jsonl_path)

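Note: the regeneration loop above writes the merged rows to a `.tmp` sibling and swaps it over the original. A standalone sketch of that write-then-replace pattern:

    import json
    from pathlib import Path

    def rewrite_safely(jsonl_path: Path, rows: list[dict]) -> None:
        temp_path = jsonl_path.with_suffix(".tmp")
        with open(temp_path, "w", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(row) + "\n")
        # Path.replace has os.replace semantics, so readers never see a
        # half-written file even if the process dies mid-write.
        temp_path.replace(jsonl_path)
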
@@ -14,17 +14,16 @@ In our informal testing, Sonnet was deemed best performance value.
 You can swap out to another API, but this will need a rewrite to remove anthropic-specific code.
 """

-import os
 import json
-import uuid
+import os
 import time
+import uuid
 from pathlib import Path

-from tqdm import tqdm
-
 import anthropic
 from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
 from anthropic.types.messages.batch_create_params import Request
+from tqdm import tqdm

 # Updated default output directory to use the parent directory.
 DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "output"

@@ -36,12 +35,13 @@ BATCH_SIZE = 2500
 COMMON_UUID = uuid.uuid4().hex[:8]

 # Set up the Anthropic client (ensure the API key is set in the environment)
-client = anthropic.Anthropic(api_key=os.environ['ANTHROPIC_API_KEY'])
+client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

+
 def submit_reasoning_batches(
     input_path: str = DEFAULT_INPUT_JSONL,
     batch_metadata_prefix: str = "batch_metadata",
-    system_prompt_path: str = DEFAULT_SYSTEM_PROMPT
+    system_prompt_path: str = DEFAULT_SYSTEM_PROMPT,
 ) -> None:
     """
     Reads the input JSONL file of word ladder examples, builds batch requests for any example that

@@ -59,11 +59,13 @@ def submit_reasoning_batches(

     # Read the system prompt from file (used as a preamble for every request)
     with open(system_prompt_path, "r", encoding="utf-8") as sys_file:
-        system_message = [{
-            "type": "text",
-            "text": sys_file.read(),
-            "cache_control": {"type": "ephemeral"} # Enable anthropic prompt caching
-        }]
+        system_message = [
+            {
+                "type": "text",
+                "text": sys_file.read(),
+                "cache_control": {"type": "ephemeral"},  # Enable anthropic prompt caching
+            }
+        ]
     batch_requests = []
     custom_ids = []  # List of custom_ids for the current batch
     batch_num = 0

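Note: the `cache_control` block above opts the shared system prompt into Anthropic prompt caching, so requests after the first read the common prefix at the cheaper cache-read rate. A sketch of one request body built from such a message (the model id and user content here are placeholders, not taken from this diff):

    from anthropic.types.message_create_params import MessageCreateParamsNonStreaming

    system_message = [
        {"type": "text", "text": "You explain word-ladder solutions.", "cache_control": {"type": "ephemeral"}}
    ]
    params = MessageCreateParamsNonStreaming(
        model="<model-id>",  # placeholder; the real id is set elsewhere in this file
        max_tokens=8192,
        temperature=0.5,
        system=system_message,  # the cached block is identical across the whole batch
        messages=[{"role": "user", "content": "Transform 'COLD' into 'WARM'"}],
    )
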
@@ -71,12 +73,12 @@ def submit_reasoning_batches(
     # Get the total number of lines in advance for tqdm progress bar.
     total_lines = sum(1 for _ in open(input_path))

-    with open(input_path, 'r', encoding="utf-8") as infile:
+    with open(input_path, "r", encoding="utf-8") as infile:
         for idx, line in tqdm(enumerate(infile), desc="Preparing batch requests", total=total_lines):
             data = json.loads(line)

             # Skip example if 'reasoning' already exists.
-            if not data.get('reasoning'):
+            if not data.get("reasoning"):
                 # Build a custom id. Here we use the row position and the start/end words:
                 metadata = data.get("metadata", {})
                 start = metadata.get("start_word", "unknown")

|
||||||
max_tokens=8192,
|
max_tokens=8192,
|
||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
system=system_message,
|
system=system_message,
|
||||||
messages=[
|
messages=[{"role": "user", "content": prompt}],
|
||||||
{"role": "user", "content": prompt}
|
),
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# Instead of wrapping in SimpleNamespace, simply ensure custom_id is set.
|
# Instead of wrapping in SimpleNamespace, simply ensure custom_id is set.
|
||||||
if isinstance(request_payload, dict):
|
if isinstance(request_payload, dict):
|
||||||
|
|
@@ -153,9 +153,7 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix
     while attempt < max_attempts:
         try:
             print(f"Submitting batch {batch_num} with {len(batch_requests)} requests... (attempt {attempt+1})")
-            message_batch = client.messages.batches.create(
-                requests=batch_requests
-            )
+            message_batch = client.messages.batches.create(requests=batch_requests)
             time.sleep(1)
             print(f"Batch {batch_num} submitted with ID: {message_batch.id}")
             break  # Success: exit the loop.

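Note: the hunk above collapses the `create(...)` call onto one line; the surrounding loop retries before logging the batch to the failed-batches file (the exact delay/backoff is not visible in this hunk). The retry shape as a generic sketch:

    import time

    def create_with_retries(create_fn, requests, max_attempts: int = 3):
        last_exception = None
        for attempt in range(max_attempts):
            try:
                return create_fn(requests=requests)
            except Exception as e:  # the test's fake_create raises once, then succeeds
                last_exception = e
                time.sleep(1)  # assumed fixed delay for the sketch
        raise RuntimeError(f"{last_exception} after {max_attempts} attempts")
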
@@ -169,14 +167,18 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix

     if message_batch is None:
         error_filename = output_dir / f"{COMMON_UUID}_failed_batches.jsonl"
-        error_msg = f"{str(last_exception)} after {max_attempts} attempts" if last_exception else f"Failed after {max_attempts} attempts"
+        error_msg = (
+            f"{str(last_exception)} after {max_attempts} attempts"
+            if last_exception
+            else f"Failed after {max_attempts} attempts"
+        )
         failed_info = {
             "batch_number": batch_num,
             "error": error_msg,
             "batch_requests": [extract_custom_id(req) for req in batch_requests],
             "input_file": input_path,
         }
-        with open(error_filename, 'a', encoding='utf-8') as error_file:
+        with open(error_filename, "a", encoding="utf-8") as error_file:
             error_file.write(json.dumps(failed_info) + "\n")
         print(f"Batch {batch_num} permanently failed. Logged to {error_filename}.")
         return

@@ -201,11 +203,12 @@ def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix
         "input_file": os.path.basename(input_path),
     }
     metadata_filename = output_dir / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
-    with open(metadata_filename, 'a', encoding='utf-8') as meta_file:
+    with open(metadata_filename, "a", encoding="utf-8") as meta_file:
         meta_file.write(json.dumps(batch_metadata) + "\n")

     print(f"Batch metadata for batch {batch_num} appended to {metadata_filename}.")

+
 if __name__ == "__main__":
     # When running this module directly, submit the reasoning batches.
     submit_reasoning_batches()

@@ -17,17 +17,14 @@ Usage:
     python usage_stats.py path/to/msgbatch_01X9LgZNVkLFhzrrBd9LNgWb_results.jsonl
 """

-import json
 import argparse
+import json
 from statistics import mean

+
 def main():
-    parser = argparse.ArgumentParser(
-        description="Compute usage token statistics from a JSONL file."
-    )
-    parser.add_argument(
-        "file", help="Path to the JSONL file containing usage token data."
-    )
+    parser = argparse.ArgumentParser(description="Compute usage token statistics from a JSONL file.")
+    parser.add_argument("file", help="Path to the JSONL file containing usage token data.")
     args = parser.parse_args()

     # Usage token fields that we want to track

@@ -129,9 +126,11 @@ def main():
     sum_cache_read = sum(usage_data["cache_read_input_tokens"])

     baseline_input_cost = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]
-    actual_input_cost = (sum_input) / 1_000_000 * pricing["input_tokens"] \
-        + (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
-        + (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+    actual_input_cost = (
+        (sum_input) / 1_000_000 * pricing["input_tokens"]
+        + (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
+        + (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+    )
     caching_savings = baseline_input_cost - actual_input_cost

     print(f"Caching Savings (input-related tokens): ${caching_savings:.2f}")

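Note: a worked example of the savings arithmetic above, with hypothetical token counts and per-million-token rates (the real `pricing` dict is defined elsewhere in `usage_stats.py` and is not shown in this diff):

    pricing = {
        "input_tokens": 3.00,  # $/Mtok, hypothetical
        "cache_creation_input_tokens": 3.75,
        "cache_read_input_tokens": 0.30,
    }
    sum_input, sum_cache_creation, sum_cache_read = 1_000_000, 400_000, 8_000_000

    baseline = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]  # $28.20
    actual = (
        sum_input / 1_000_000 * pricing["input_tokens"]  # $3.00
        + sum_cache_creation / 1_000_000 * pricing["cache_creation_input_tokens"]  # $1.50
        + sum_cache_read / 1_000_000 * pricing["cache_read_input_tokens"]  # $2.40
    )
    print(f"Caching savings: ${baseline - actual:.2f}")  # $21.30: cache reads dominate the totals
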
@@ -172,12 +171,16 @@ def main():
         forecast_output = avg_output_tokens * jobs

         # Forecast actual cost (with caching applied for input tokens):
-        actual_input_cost_forecast = (forecast_input) / 1_000_000 * pricing["input_tokens"] \
-            + (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"] \
-            + (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+        actual_input_cost_forecast = (
+            (forecast_input) / 1_000_000 * pricing["input_tokens"]
+            + (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
+            + (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
+        )

         # Without caching, all input-related tokens would be at base_input_rate:
-        baseline_input_cost_forecast = (forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
+        baseline_input_cost_forecast = (
+            (forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
+        )

         caching_savings_forecast = baseline_input_cost_forecast - actual_input_cost_forecast

@@ -198,5 +201,6 @@ def main():
     else:
         print("No valid jobs to forecast future costs.")

+
 if __name__ == "__main__":
     main()

@@ -52,8 +52,10 @@ class WordLadderConfig:
             return 3 <= length <= self.max_chain_length

         # Otherwise check against both min and max
-        return (self.min_chain_length <= length <=
-                (self.max_chain_length if self.max_chain_length != -1 else float('inf')))
+        return (
+            self.min_chain_length <= length <= (self.max_chain_length if self.max_chain_length != -1 else float("inf"))
+        )
+

 class WordLadderDataset(ProceduralDataset):
     """Generates word ladder transformation tasks"""

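Note: the reflowed `return` keeps the same semantics: `-1` for `max_chain_length` means no upper bound (`float("inf")`). Examples mirroring the assertions in the test hunks further down this diff:

    from reasoning_gym.algorithmic.word_ladder import WordLadderConfig

    config = WordLadderConfig(min_chain_length=4, max_chain_length=6)
    assert config.is_valid_path_length(4)      # min accepted
    assert not config.is_valid_path_length(7)  # above max rejected

    # Shortest-path mode with a cap: -1 lower bound, explicit upper bound.
    config_mixed = WordLadderConfig(min_chain_length=-1, max_chain_length=5)
    assert config_mixed.is_valid_path_length(3) and config_mixed.is_valid_path_length(5)
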
@@ -65,8 +67,7 @@ class WordLadderDataset(ProceduralDataset):

         # Load words from CSV
         self.word_sets = self._load_words_from_csv(
-            min_length=self.config.min_word_length,
-            max_length=self.config.max_word_length
+            min_length=self.config.min_word_length, max_length=self.config.max_word_length
         )

         # Precompute word graphs for all lengths

@@ -76,7 +77,6 @@ class WordLadderDataset(ProceduralDataset):
         config.validate()
         super().__init__(config=config, seed=config.seed, size=config.size)

-
     @classmethod
     def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> Dict[int, Set[str]]:
         """Load words from CSV file organized by length"""

@@ -99,8 +99,8 @@ class WordLadderDataset(ProceduralDataset):
             for row in reader:
                 # Process each word length column using config range
                 for length in range(min_length, max_length + 1):
-                    col_name = f'{length}_letter'
-                    word = row.get(col_name, '')
+                    col_name = f"{length}_letter"
+                    word = row.get(col_name, "")

                     if not word:  # Skip empty entries
                         continue

@@ -126,8 +126,8 @@ class WordLadderDataset(ProceduralDataset):
         # Fall back to computing neighbors directly for custom word sets
         neighbors = set()
         for i in range(len(word)):
-            for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
-                neighbor = word[:i] + c + word[i+1:]
+            for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
+                neighbor = word[:i] + c + word[i + 1 :]
                 if neighbor != word and neighbor in word_set:
                     neighbors.add(neighbor)
         return neighbors

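Note: `word[i + 1 :]` is just black's spelling of `word[i+1:]`; the loop enumerates every single-letter substitution and keeps the ones that are real words. As a standalone sketch:

    def one_letter_neighbors(word: str, word_set: set[str]) -> set[str]:
        neighbors = set()
        for i in range(len(word)):
            for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
                neighbor = word[:i] + c + word[i + 1 :]  # substitute position i with c
                if neighbor != word and neighbor in word_set:
                    neighbors.add(neighbor)
        return neighbors

    assert one_letter_neighbors("CAT", {"CAT", "COT", "DOG"}) == {"COT"}
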
@@ -146,8 +146,8 @@ class WordLadderDataset(ProceduralDataset):
         for word in word_set:
             neighbors = set()
             for i in range(word_length):
-                for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
-                    neighbor = word[:i] + c + word[i+1:]
+                for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
+                    neighbor = word[:i] + c + word[i + 1 :]
                     if neighbor != word and neighbor in word_set:
                         neighbors.add(neighbor)
             graph[word] = neighbors

@@ -220,4 +220,5 @@ class WordLadderDataset(ProceduralDataset):
             "metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
         }

+
 register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)

@@ -1,6 +1,7 @@
-import pytest
-from random import Random
 import time
+from random import Random

+import pytest
+
 from reasoning_gym.algorithmic.word_ladder import WordLadderConfig, WordLadderDataset

@@ -70,7 +71,7 @@ def test_word_ladder_dataset_unique_pairs():
         item = dataset[i]
         pair = (
             min(item["metadata"]["start_word"], item["metadata"]["end_word"]),
-            max(item["metadata"]["start_word"], item["metadata"]["end_word"])
+            max(item["metadata"]["start_word"], item["metadata"]["end_word"]),
         )
         assert pair not in seen_pairs, f"Duplicate pair found: {pair}"
         seen_pairs.add(pair)

@@ -80,23 +81,13 @@ def test_word_ladder_dataset_items():
     """Test basic properties of generated items"""
     # Test with specific chain length constraints
     config1 = WordLadderConfig(
-        min_word_length=3,
-        max_word_length=5,
-        min_chain_length=3,
-        max_chain_length=5,
-        size=10,
-        seed=42
+        min_word_length=3, max_word_length=5, min_chain_length=3, max_chain_length=5, size=10, seed=42
     )
     dataset1 = WordLadderDataset(config1)

     # Test with shortest path mode
     config2 = WordLadderConfig(
-        min_word_length=3,
-        max_word_length=5,
-        min_chain_length=-1,
-        max_chain_length=-1,
-        size=10,
-        seed=42
+        min_word_length=3, max_word_length=5, min_chain_length=-1, max_chain_length=-1, size=10, seed=42
     )
     dataset2 = WordLadderDataset(config2)

@@ -171,7 +162,7 @@ def test_word_ladder_path_finding():
         min_chain_length=-1,  # Shortest path mode
         max_chain_length=-1,
         size=10,
-        seed=42
+        seed=42,
     )
     dataset = WordLadderDataset(config)

@@ -186,18 +177,15 @@ def test_word_ladder_path_finding():
     assert len(path) >= 3

     # Verify each step differs by only one letter
-    for i in range(len(path)-1):
+    for i in range(len(path) - 1):
         current = path[i]
-        next_word = path[i+1]
+        next_word = path[i + 1]
         assert next_word in dataset._get_neighbors(current, word_set)


 def test_word_ladder_csv_loading():
     """Test word loading from CSV"""
-    config = WordLadderConfig(
-        min_word_length=3,
-        max_word_length=5
-    )
+    config = WordLadderConfig(min_word_length=3, max_word_length=5)
     dataset = WordLadderDataset(config)

     # Verify word sets for each length

|
||||||
|
|
||||||
def test_word_ladder_pair_generation():
|
def test_word_ladder_pair_generation():
|
||||||
"""Test word pair generation logic"""
|
"""Test word pair generation logic"""
|
||||||
config = WordLadderConfig(
|
config = WordLadderConfig(min_word_length=4, max_word_length=4, size=10, seed=42)
|
||||||
min_word_length=4,
|
|
||||||
max_word_length=4,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = WordLadderDataset(config)
|
dataset = WordLadderDataset(config)
|
||||||
|
|
||||||
# Test pair generation
|
# Test pair generation
|
||||||
|
|
@@ -269,10 +252,7 @@ def test_word_graph_caching():

 def test_word_ladder_path_validation():
     """Test path length validation logic"""
-    config = WordLadderConfig(
-        min_chain_length=4,
-        max_chain_length=6
-    )
+    config = WordLadderConfig(min_chain_length=4, max_chain_length=6)

     # Test specific length mode
     assert config.is_valid_path_length(4)  # Min length

@@ -282,20 +262,14 @@ def test_word_ladder_path_validation():
     assert not config.is_valid_path_length(7)  # Too long

     # Test shortest path mode
-    config_shortest = WordLadderConfig(
-        min_chain_length=-1,
-        max_chain_length=-1
-    )
+    config_shortest = WordLadderConfig(min_chain_length=-1, max_chain_length=-1)
     assert config_shortest.is_valid_path_length(3)
     assert config_shortest.is_valid_path_length(4)
    assert config_shortest.is_valid_path_length(10)
     assert not config_shortest.is_valid_path_length(2)

     # Test mixed mode (shortest with max limit)
-    config_mixed = WordLadderConfig(
-        min_chain_length=-1,
-        max_chain_length=5
-    )
+    config_mixed = WordLadderConfig(min_chain_length=-1, max_chain_length=5)
     assert config_mixed.is_valid_path_length(3)
     assert config_mixed.is_valid_path_length(4)
     assert config_mixed.is_valid_path_length(5)

@@ -305,12 +279,7 @@ def test_word_ladder_path_validation():

 def test_word_ladder_solution_optimality():
     """Test that generated solutions are optimal when min_chain_length=-1"""
     config = WordLadderConfig(
-        min_word_length=4,
-        max_word_length=4,
-        min_chain_length=-1,
-        max_chain_length=-1,
-        size=20,
-        seed=42
+        min_word_length=4, max_word_length=4, min_chain_length=-1, max_chain_length=-1, size=20, seed=42
     )
     dataset = WordLadderDataset(config)

@@ -325,6 +294,7 @@ def test_word_ladder_solution_optimality():

         # Build graph and use BFS to find shortest path
         from collections import deque
+
         queue = deque([(start_word, [start_word])])
         visited = {start_word}
         shortest_path = None

@@ -341,8 +311,9 @@ def test_word_ladder_solution_optimality():
                     queue.append((neighbor, path + [neighbor]))

         assert shortest_path is not None, f"No path found between {start_word} and {end_word}"
-        assert len(solution_chain) == len(shortest_path), \
-            f"Solution {solution_chain} is not optimal. Shortest path: {shortest_path}"
+        assert len(solution_chain) == len(
+            shortest_path
+        ), f"Solution {solution_chain} is not optimal. Shortest path: {shortest_path}"


 def test_word_ladder_performance():

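Note: the reformatted assertion above closes out a BFS that re-derives the optimal ladder independently of the generator. The same BFS as a self-contained sketch:

    from collections import deque

    def shortest_ladder(start: str, end: str, graph: dict[str, set[str]]) -> list[str] | None:
        queue = deque([(start, [start])])
        visited = {start}
        while queue:
            word, path = queue.popleft()
            if word == end:
                return path  # BFS: the first time we reach `end` is via a shortest path
            for neighbor in graph[word] - visited:
                visited.add(neighbor)
                queue.append((neighbor, path + [neighbor]))
        return None

    graph = {"CAT": {"COT"}, "COT": {"CAT", "DOT"}, "DOT": {"COT", "DOG"}, "DOG": {"DOT"}}
    assert shortest_ladder("CAT", "DOG", graph) == ["CAT", "COT", "DOT", "DOG"]
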
|
|
@ -371,13 +342,7 @@ def test_word_ladder_edge_cases():
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
|
|
||||||
# Test with same start/end word length but maximum distance
|
# Test with same start/end word length but maximum distance
|
||||||
config = WordLadderConfig(
|
config = WordLadderConfig(min_word_length=4, max_word_length=4, min_chain_length=-1, max_chain_length=-1, size=10)
|
||||||
min_word_length=4,
|
|
||||||
max_word_length=4,
|
|
||||||
min_chain_length=-1,
|
|
||||||
max_chain_length=-1,
|
|
||||||
size=10
|
|
||||||
)
|
|
||||||
dataset = WordLadderDataset(config)
|
dataset = WordLadderDataset(config)
|
||||||
|
|
||||||
# Find the pair with longest solution
|
# Find the pair with longest solution
|
||||||
|
|
|
||||||