mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
Revert "Restructure {reasoning_gym, tests}/{core, exercises, curricula}"
This reverts commit 10dbb374b0.
This commit is contained in:
parent
b756f26c09
commit
4c3ae0aebf
109 changed files with 0 additions and 0 deletions
|
|
@ -1,165 +0,0 @@
|
|||
"""Tests for base conversion task generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.base_conversion import BaseConversionConfig, BaseConversionDataset
|
||||
|
||||
|
||||
def test_base_conversion_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = BaseConversionConfig(min_base=1) # Too small
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = BaseConversionConfig(min_base=37) # Too large
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = BaseConversionConfig(min_base=10, max_base=5) # max < min
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = BaseConversionConfig(min_value=-1) # Negative not allowed
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_base_conversion_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = BaseConversionConfig(seed=42, size=10)
|
||||
dataset1 = BaseConversionDataset(config)
|
||||
dataset2 = BaseConversionDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_base_conversion_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = BaseConversionConfig(min_base=2, max_base=16, min_value=0, max_value=1000, size=10, seed=42)
|
||||
dataset = BaseConversionDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "decimal_value" in item["metadata"]
|
||||
assert "source_base" in item["metadata"]
|
||||
assert "target_base" in item["metadata"]
|
||||
assert "source_repr" in item["metadata"]
|
||||
assert "target_repr" in item["metadata"]
|
||||
|
||||
# Verify value range
|
||||
assert config.min_value <= item["metadata"]["decimal_value"] <= config.max_value
|
||||
|
||||
# Verify base range
|
||||
assert config.min_base <= item["metadata"]["source_base"] <= config.max_base
|
||||
assert config.min_base <= item["metadata"]["target_base"] <= config.max_base
|
||||
assert item["metadata"]["source_base"] != item["metadata"]["target_base"]
|
||||
|
||||
# Verify conversion correctness
|
||||
decimal_value = item["metadata"]["decimal_value"]
|
||||
target_base = item["metadata"]["target_base"]
|
||||
|
||||
# Use same conversion logic as implementation
|
||||
if target_base == 16:
|
||||
expected = format(decimal_value, "x")
|
||||
elif target_base == 2:
|
||||
expected = format(decimal_value, "b")
|
||||
else:
|
||||
# Manual conversion for other bases
|
||||
n = decimal_value
|
||||
digits = []
|
||||
while n:
|
||||
digits.append(int(n % target_base))
|
||||
n //= target_base
|
||||
expected = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0])
|
||||
assert item["answer"] == expected
|
||||
|
||||
|
||||
def test_base_conversion_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = BaseConversionConfig(size=5, seed=42)
|
||||
dataset = BaseConversionDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
||||
|
||||
def test_base_conversion_validity():
|
||||
"""Test that generated numbers are valid for their bases"""
|
||||
config = BaseConversionConfig(min_base=2, max_base=36, min_value=0, max_value=1000, size=100, seed=42)
|
||||
dataset = BaseConversionDataset(config)
|
||||
|
||||
def is_valid_for_base(num_str: str, base: int) -> bool:
|
||||
valid_chars = "0123456789abcdefghijklmnopqrstuvwxyz"[:base]
|
||||
return all(c in valid_chars for c in num_str.lower())
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
assert is_valid_for_base(
|
||||
item["metadata"]["source_repr"], item["metadata"]["source_base"]
|
||||
), f"Invalid source number {item['metadata']['source_repr']} for base {item['metadata']['source_base']}"
|
||||
assert is_valid_for_base(
|
||||
item["metadata"]["target_repr"], item["metadata"]["target_base"]
|
||||
), f"Invalid target number {item['metadata']['target_repr']} for base {item['metadata']['target_base']}"
|
||||
|
||||
|
||||
def test_base_conversion_special_bases():
|
||||
"""Test conversion between special bases (binary, hex)"""
|
||||
config = BaseConversionConfig(
|
||||
min_base=2,
|
||||
max_base=16,
|
||||
min_value=0,
|
||||
max_value=255, # Use small range for predictable results
|
||||
size=100,
|
||||
seed=42,
|
||||
)
|
||||
dataset = BaseConversionDataset(config)
|
||||
|
||||
binary_found = False
|
||||
hex_found = False
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
if item["metadata"]["target_base"] == 2:
|
||||
binary_found = True
|
||||
# Verify binary format
|
||||
assert all(c in "01" for c in item["answer"])
|
||||
elif item["metadata"]["target_base"] == 16:
|
||||
hex_found = True
|
||||
# Verify hex format
|
||||
assert all(c in "0123456789abcdef" for c in item["answer"])
|
||||
|
||||
assert binary_found, "No binary conversion tasks generated"
|
||||
assert hex_found, "No hexadecimal conversion tasks generated"
|
||||
|
||||
|
||||
def test_base_conversion_formatting():
|
||||
"""Test number formatting in different bases"""
|
||||
config = BaseConversionConfig(
|
||||
min_base=11, # Force bases that use letters
|
||||
max_base=36,
|
||||
min_value=10, # Ensure multi-digit numbers
|
||||
max_value=1000,
|
||||
size=10,
|
||||
seed=42,
|
||||
)
|
||||
dataset = BaseConversionDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Verify lowercase letters are used
|
||||
assert item["answer"] == item["answer"].lower()
|
||||
# Verify no whitespace in answer
|
||||
assert item["answer"].strip() == item["answer"]
|
||||
# Verify hint is included for bases > 10
|
||||
assert "use lowercase letters" in item["question"]
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
"""Tests for Caesar cipher task generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
|
||||
|
||||
|
||||
def test_caesar_cipher_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = CaesarCipherConfig(min_words=0)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = CaesarCipherConfig(min_words=10, max_words=5)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = CaesarCipherConfig(min_rotation=0)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = CaesarCipherConfig(max_rotation=26)
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_caesar_cipher_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = CaesarCipherConfig(seed=42, size=10)
|
||||
dataset1 = CaesarCipherDataset(config)
|
||||
dataset2 = CaesarCipherDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_caesar_cipher_encryption():
|
||||
"""Test the Caesar cipher encryption logic"""
|
||||
config = CaesarCipherConfig(size=1, seed=42)
|
||||
dataset = CaesarCipherDataset(config)
|
||||
|
||||
# Test with known rotation
|
||||
text = "HELLO"
|
||||
encrypted = dataset._caesar_encrypt(text, 1)
|
||||
assert encrypted == "IFMMP" # Each letter shifted by 1
|
||||
|
||||
# Test wrapping around Z
|
||||
encrypted = dataset._caesar_encrypt("XYZ", 2)
|
||||
assert encrypted == "ZAB"
|
||||
|
||||
# Test preserving spaces
|
||||
encrypted = dataset._caesar_encrypt("HELLO WORLD", 1)
|
||||
assert encrypted == "IFMMP XPSME"
|
||||
|
||||
|
||||
def test_caesar_cipher_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = CaesarCipherConfig(min_words=3, max_words=5, min_rotation=1, max_rotation=3, size=10, seed=42)
|
||||
dataset = CaesarCipherDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "rotation" in item["metadata"]
|
||||
assert "cipher_text" in item["metadata"]
|
||||
assert "clear_text" in item["metadata"]
|
||||
|
||||
# Verify rotation constraints
|
||||
rotation = item["metadata"]["rotation"]
|
||||
assert config.min_rotation <= rotation <= config.max_rotation
|
||||
|
||||
# Verify text properties
|
||||
clear_text = item["metadata"]["clear_text"]
|
||||
words = clear_text.split()
|
||||
assert config.min_words <= len(words) <= config.max_words
|
||||
assert all(word.isupper() and word.isalpha() for word in words)
|
||||
|
||||
# Verify encryption
|
||||
cipher_text = item["metadata"]["cipher_text"]
|
||||
decrypted = dataset._caesar_encrypt(cipher_text, -rotation) # Decrypt by negative rotation
|
||||
assert decrypted == clear_text
|
||||
|
||||
|
||||
def test_caesar_cipher_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = CaesarCipherConfig(size=5, seed=42)
|
||||
dataset = CaesarCipherDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
|
@ -1,78 +0,0 @@
|
|||
"""Tests for letter counting task generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.letter_counting import LetterCountingConfig, LetterCountingDataset
|
||||
|
||||
|
||||
def test_letter_counting_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = LetterCountingConfig(min_words=0)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = LetterCountingConfig(min_words=10, max_words=5)
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_letter_counting_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = LetterCountingConfig(seed=42, size=10)
|
||||
dataset1 = LetterCountingDataset(config)
|
||||
dataset2 = LetterCountingDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_letter_counting_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = LetterCountingConfig(min_words=3, max_words=6, size=10, seed=42)
|
||||
dataset = LetterCountingDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "span_length" in item["metadata"]
|
||||
assert "target_letter" in item["metadata"]
|
||||
assert "span" in item["metadata"]
|
||||
|
||||
# Verify span length constraints
|
||||
span = item["metadata"]["span"]
|
||||
assert len(span) >= config.min_words
|
||||
assert len(span) <= config.max_words
|
||||
|
||||
# Verify letter counting
|
||||
target_letter = item["metadata"]["target_letter"]
|
||||
count = sum(word.lower().count(target_letter) for word in span)
|
||||
assert str(count) == item["answer"]
|
||||
|
||||
|
||||
def test_letter_counting_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = LetterCountingConfig(size=5, seed=42)
|
||||
dataset = LetterCountingDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
||||
|
||||
def test_letter_counting_text_preprocessing():
|
||||
"""Test that text preprocessing handles edge cases"""
|
||||
config = LetterCountingConfig(size=1, seed=42)
|
||||
dataset = LetterCountingDataset(config)
|
||||
|
||||
# Verify words were extracted from text
|
||||
assert len(dataset.words) > 0
|
||||
# Verify words contain only word characters
|
||||
assert all(word.isalnum() for word in dataset.words)
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
"""Tests for letter jumbling task generation"""
|
||||
|
||||
from random import Random
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.letter_jumble import LetterJumbleConfig, LetterJumbleDataset
|
||||
|
||||
|
||||
def test_letter_jumble_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = LetterJumbleConfig(min_word_len=0)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = LetterJumbleConfig(min_words=10, max_words=5)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = LetterJumbleConfig(min_corruption_level=-0.1)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = LetterJumbleConfig(max_corruption_level=1.1)
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_letter_jumble_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = LetterJumbleConfig(seed=42, size=10)
|
||||
dataset1 = LetterJumbleDataset(config)
|
||||
dataset2 = LetterJumbleDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_letter_jumble_scrambling():
|
||||
"""Test the word scrambling logic"""
|
||||
config = LetterJumbleConfig(
|
||||
min_word_len=4,
|
||||
max_word_len=8,
|
||||
min_words=1,
|
||||
max_words=1,
|
||||
min_corruption_level=0.5,
|
||||
max_corruption_level=0.5,
|
||||
size=1,
|
||||
seed=42,
|
||||
)
|
||||
dataset = LetterJumbleDataset(config)
|
||||
|
||||
# Test with known word
|
||||
word = "testing"
|
||||
rng = Random(42)
|
||||
scrambled = dataset._scramble_word(word, 0.5, rng)
|
||||
|
||||
# Verify scrambled word:
|
||||
# - Has same length as original
|
||||
assert len(scrambled) == len(word)
|
||||
# - Contains same characters
|
||||
assert sorted(scrambled) == sorted(word)
|
||||
# - Is different from original (with high probability given 0.5 corruption)
|
||||
assert scrambled != word
|
||||
|
||||
|
||||
def test_letter_jumble_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = LetterJumbleConfig(
|
||||
min_word_len=4,
|
||||
max_word_len=8,
|
||||
min_words=3,
|
||||
max_words=5,
|
||||
min_corruption_level=0.1,
|
||||
max_corruption_level=0.3,
|
||||
size=50,
|
||||
seed=42,
|
||||
)
|
||||
dataset = LetterJumbleDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
metadata = item["metadata"]
|
||||
assert "num_words" in metadata
|
||||
assert "corruption_level" in metadata
|
||||
assert "scrambled_words" in metadata
|
||||
assert "original_words" in metadata
|
||||
|
||||
# Verify word counts
|
||||
num_words = metadata["num_words"]
|
||||
assert config.min_words <= num_words <= config.max_words
|
||||
assert len(metadata["scrambled_words"]) == num_words
|
||||
assert len(metadata["original_words"]) == num_words
|
||||
|
||||
# Verify corruption level
|
||||
assert config.min_corruption_level <= metadata["corruption_level"] <= config.max_corruption_level
|
||||
|
||||
# Verify word properties
|
||||
for word in metadata["original_words"]:
|
||||
assert config.min_word_len <= len(word) <= config.max_word_len
|
||||
assert word.isalpha()
|
||||
|
||||
|
||||
def test_letter_jumble_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = LetterJumbleConfig(size=5, seed=42)
|
||||
dataset = LetterJumbleDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
"""Tests for number filtering task generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||
|
||||
|
||||
def test_number_filtering_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberFilteringConfig(min_numbers=0)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberFilteringConfig(min_numbers=10, max_numbers=5)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberFilteringConfig(min_decimals=-1)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberFilteringConfig(min_value=100, max_value=0)
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_number_filtering_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = NumberFilteringConfig(seed=42, size=10)
|
||||
dataset1 = NumberFilteringDataset(config)
|
||||
dataset2 = NumberFilteringDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_number_filtering_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = NumberFilteringConfig(
|
||||
min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
|
||||
)
|
||||
dataset = NumberFilteringDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "original_numbers" in item["metadata"]
|
||||
assert "filter_value" in item["metadata"]
|
||||
assert "operation" in item["metadata"]
|
||||
assert "result" in item["metadata"]
|
||||
|
||||
# Verify number count constraints
|
||||
numbers = item["metadata"]["original_numbers"]
|
||||
assert len(numbers) >= config.min_numbers
|
||||
assert len(numbers) <= config.max_numbers
|
||||
|
||||
# Verify decimal places
|
||||
for num in numbers:
|
||||
decimal_places = len(num.split(".")[-1]) if "." in num else 0
|
||||
assert decimal_places >= config.min_decimals
|
||||
assert decimal_places <= config.max_decimals
|
||||
|
||||
# Verify value range
|
||||
for num in numbers:
|
||||
value = float(num)
|
||||
assert config.min_value <= value <= config.max_value
|
||||
|
||||
# Verify filtering operation
|
||||
operation = item["metadata"]["operation"]
|
||||
filter_value = float(item["metadata"]["filter_value"])
|
||||
result = [float(x) for x in eval(item["answer"])] if item["answer"] != "[]" else []
|
||||
|
||||
if operation == "keep_larger":
|
||||
assert all(x > filter_value for x in result)
|
||||
elif operation == "keep_smaller":
|
||||
assert all(x < filter_value for x in result)
|
||||
elif operation == "remove_larger":
|
||||
assert all(x <= filter_value for x in result)
|
||||
elif operation == "remove_smaller":
|
||||
assert all(x >= filter_value for x in result)
|
||||
|
||||
|
||||
def test_number_filtering_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = NumberFilteringConfig(size=5, seed=42)
|
||||
dataset = NumberFilteringDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
||||
|
||||
def test_number_filtering_precision():
|
||||
"""Test that number formatting and precision handling works correctly"""
|
||||
config = NumberFilteringConfig(
|
||||
min_numbers=3,
|
||||
max_numbers=3, # Fixed size for predictability
|
||||
min_decimals=2,
|
||||
max_decimals=2, # Fixed decimals for predictability
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
size=1,
|
||||
seed=42,
|
||||
)
|
||||
dataset = NumberFilteringDataset(config)
|
||||
item = dataset[0]
|
||||
|
||||
# Check that string representations maintain precision
|
||||
for num in item["metadata"]["original_numbers"]:
|
||||
assert len(num.split(".")[-1]) == 2
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
"""Tests for number sorting task generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||
|
||||
|
||||
def test_number_sorting_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberSortingConfig(min_numbers=0)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberSortingConfig(min_numbers=10, max_numbers=5)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberSortingConfig(min_decimals=-1)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberSortingConfig(min_value=100, max_value=0)
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_number_sorting_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = NumberSortingConfig(seed=42, size=10)
|
||||
dataset1 = NumberSortingDataset(config)
|
||||
dataset2 = NumberSortingDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_number_sorting_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = NumberSortingConfig(
|
||||
min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
|
||||
)
|
||||
dataset = NumberSortingDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "original_numbers" in item["metadata"]
|
||||
assert "direction" in item["metadata"]
|
||||
assert "sorted_numbers" in item["metadata"]
|
||||
|
||||
# Verify number count constraints
|
||||
numbers = item["metadata"]["original_numbers"]
|
||||
assert len(numbers) >= config.min_numbers
|
||||
assert len(numbers) <= config.max_numbers
|
||||
|
||||
# Verify decimal places
|
||||
for num in numbers:
|
||||
decimal_places = len(num.split(".")[-1]) if "." in num else 0
|
||||
assert decimal_places >= config.min_decimals
|
||||
assert decimal_places <= config.max_decimals
|
||||
|
||||
# Verify value range
|
||||
for num in numbers:
|
||||
value = float(num)
|
||||
assert config.min_value <= value <= config.max_value
|
||||
|
||||
# Verify sorting
|
||||
direction = item["metadata"]["direction"]
|
||||
sorted_numbers = [float(x) for x in eval(item["answer"])]
|
||||
if direction == "ascending":
|
||||
assert sorted_numbers == sorted(sorted_numbers)
|
||||
else:
|
||||
assert sorted_numbers == sorted(sorted_numbers, reverse=True)
|
||||
|
||||
|
||||
def test_number_sorting_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = NumberSortingConfig(size=5, seed=42)
|
||||
dataset = NumberSortingDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
import pytest

from reasoning_gym.algorithmic.sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset


@pytest.fixture
def config():
    """Fixed-size (5-word) config with a deterministic seed."""
    return SentenceReorderingConfig(min_words_in_sentence=5, max_words_in_sentence=5, seed=42, size=10)


@pytest.fixture
def dataset(config):
    """Dataset built from the shared config fixture."""
    return SentenceReorderingDataset(config=config)


def test_config_validation(config):
    """A well-formed config must validate without raising."""
    try:
        config.validate()
    except Exception as e:
        pytest.fail(f"Config validation raised an exception: {e}")


def test_generate_sentence_dataset(dataset):
    """Shuffling keeps the word multiset but changes the order."""
    result = dataset._generate_sentence_dataset(
        "This is a test sentence for reordering", seed=42, idx=0, shuffle=True
    )

    assert "input" in result
    assert "goal" in result
    # Shuffled input differs from the goal yet contains the same words.
    assert result["input"] != result["goal"]
    assert sorted(result["input"].split()) == sorted(result["goal"].split())


def test_getitem(dataset, config):
    """Items expose the expected keys and honour the word-count bounds."""
    item = dataset[0]

    for key in ("question", "answer", "metadata"):
        assert key in item

    word_count = item["metadata"]["word_count"]
    assert config.min_words_in_sentence <= word_count <= config.max_words_in_sentence


def test_key_error_in_getitem(dataset):
    """An unexpected key from the generator must surface as KeyError."""
    def fake_generate(*args, **kwargs):
        return {"input": "mock input", "goal": "mock goal", "extra": "extra key"}

    # Patch the generator so it returns a dict with a stray key.
    dataset._generate_sentence_dataset = fake_generate

    with pytest.raises(KeyError):
        dataset[0]
|
|
@ -1,59 +0,0 @@
|
|||
"""Tests for spell backward task generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||
|
||||
|
||||
def test_spell_backward_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = SpellBackwardConfig(min_word_len=0)
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_spell_backward_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = SpellBackwardConfig(seed=42, size=10)
|
||||
dataset1 = SpellBackwardDataset(config)
|
||||
dataset2 = SpellBackwardDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_spell_backward_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = SpellBackwardConfig(min_word_len=3, size=10, seed=42)
|
||||
dataset = SpellBackwardDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "word" in item["metadata"]
|
||||
assert "word_len" in item["metadata"]
|
||||
|
||||
# Verify word length constraint
|
||||
word = item["metadata"]["word"]
|
||||
assert len(word) >= config.min_word_len
|
||||
|
||||
# Verify answer is correct
|
||||
assert item["answer"] == word[::-1]
|
||||
|
||||
|
||||
def test_spell_backward_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = SpellBackwardConfig(size=5, seed=42)
|
||||
dataset = SpellBackwardDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.word_ladder import WordLadderConfig, WordLadderDataset
|
||||
|
||||
|
||||
def test_word_ladder_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
# Test min_word_length validation
|
||||
with pytest.raises(AssertionError):
|
||||
config = WordLadderConfig(min_word_length=2)
|
||||
config.validate()
|
||||
|
||||
# Test max_word_length validation
|
||||
with pytest.raises(AssertionError):
|
||||
config = WordLadderConfig(max_word_length=6)
|
||||
config.validate()
|
||||
|
||||
# Test word length relationship
|
||||
with pytest.raises(AssertionError):
|
||||
config = WordLadderConfig(min_word_length=5, max_word_length=3)
|
||||
config.validate()
|
||||
|
||||
# Test min_chain_length validation
|
||||
with pytest.raises(AssertionError):
|
||||
config = WordLadderConfig(min_chain_length=2)
|
||||
config.validate()
|
||||
|
||||
# Test chain length relationship
|
||||
with pytest.raises(AssertionError):
|
||||
config = WordLadderConfig(min_chain_length=5, max_chain_length=3)
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_word_ladder_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = WordLadderConfig(seed=42, size=10)
|
||||
dataset1 = WordLadderDataset(config)
|
||||
dataset2 = WordLadderDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_word_ladder_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = WordLadderConfig(
|
||||
min_word_length=3, max_word_length=5, min_chain_length=3, max_chain_length=5, size=10, seed=42
|
||||
)
|
||||
dataset = WordLadderDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
metadata = item["metadata"]
|
||||
assert "start_word" in metadata
|
||||
assert "end_word" in metadata
|
||||
assert "word_length" in metadata
|
||||
assert "chain_length" in metadata
|
||||
|
||||
# Verify word length constraints
|
||||
word_length = metadata["word_length"]
|
||||
assert config.min_word_length <= word_length <= config.max_word_length
|
||||
assert len(metadata["start_word"]) == word_length
|
||||
assert len(metadata["end_word"]) == word_length
|
||||
|
||||
# Verify solution chain from answer
|
||||
solution_chain = item["answer"].split(",")
|
||||
|
||||
# Handle chain length validation based on whether it's shortest path (-1) or specified length
|
||||
if metadata["chain_length"] == -1:
|
||||
# For shortest path, just ensure it's a valid path (we can't predict exact length)
|
||||
assert len(solution_chain) >= 2 # Must have at least start and end words
|
||||
else:
|
||||
# For specified length, ensure it matches config constraints
|
||||
assert config.min_chain_length <= len(solution_chain) <= config.max_chain_length
|
||||
assert len(solution_chain) == metadata["chain_length"]
|
||||
|
||||
assert solution_chain[0] == metadata["start_word"]
|
||||
assert solution_chain[-1] == metadata["end_word"]
|
||||
assert all(len(word) == word_length for word in solution_chain)
|
||||
|
||||
# Verify each step differs by only one letter
|
||||
for j in range(len(solution_chain) - 1):
|
||||
differences = sum(1 for a, b in zip(solution_chain[j], solution_chain[j + 1]) if a != b)
|
||||
assert differences == 1
|
||||
|
||||
|
||||
def test_word_ladder_differs_by_one():
    """Test the _differs_by_one helper method"""
    dataset = WordLadderDataset(WordLadderConfig())

    # Pairs that are exactly one substitution apart
    for left, right in (("CAT", "BAT"), ("DOG", "LOG"), ("WORD", "WARD")):
        assert dataset._differs_by_one(left, right)

    # Pairs differing in more than one position
    for left, right in (("CAT", "DOG"), ("WORD", "WAND")):
        assert not dataset._differs_by_one(left, right)

    # Words of unequal length can never differ by exactly one letter
    for left, right in (("CAT", "CATS"), ("DOG", "DO")):
        assert not dataset._differs_by_one(left, right)

    # Identical words differ in zero positions
    assert not dataset._differs_by_one("CAT", "CAT")
|
||||
|
||||
|
||||
def test_word_ladder_find_path():
    """Test the _find_path helper method"""
    dataset = WordLadderDataset(WordLadderConfig())

    # A connected vocabulary: CAT can reach BAR via BAT or CAR
    vocabulary = {"CAT", "BAT", "BAR", "CAR"}
    found = dataset._find_path("CAT", "BAR", vocabulary)
    assert found is not None
    assert found[0] == "CAT"
    assert found[-1] == "BAR"
    assert all(step in vocabulary for step in found)

    # Disconnected words yield no path
    vocabulary = {"CAT", "DOG"}
    assert dataset._find_path("CAT", "DOG", vocabulary) is None

    # A word trivially reaches itself in a single-step path
    assert dataset._find_path("CAT", "CAT", vocabulary) == ["CAT"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly instead of via the pytest CLI
    pytest.main(args=[__file__])
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
||||
|
||||
|
||||
def test_word_sequence_reversal_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Each keyword set violates one validation constraint; construction stays
    # inside the raises block so any constructor-time assertion is also caught.
    for bad_kwargs in ({"min_words": 0}, {"min_words": 10, "max_words": 5}):
        with pytest.raises(AssertionError):
            WordSequenceReversalConfig(**bad_kwargs).validate()
|
||||
|
||||
|
||||
def test_word_sequence_reversal_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    cfg = WordSequenceReversalConfig(seed=42, size=10)
    first = WordSequenceReversalDataset(cfg)
    second = WordSequenceReversalDataset(cfg)

    # Same seed and size must reproduce every item exactly
    for idx in range(len(first)):
        assert first[idx] == second[idx]
|
||||
|
||||
|
||||
def test_word_sequence_reversal_dataset_items():
    """Test basic properties of generated items"""
    cfg = WordSequenceReversalConfig(min_words=3, max_words=6, size=10, seed=42)
    dataset = WordSequenceReversalDataset(cfg)

    for idx in range(len(dataset)):
        entry = dataset[idx]

        # Every item is a dict exposing the standard keys
        assert isinstance(entry, dict)
        for key in ("question", "answer", "metadata"):
            assert key in entry

        # Metadata carries the word list and its size
        meta = entry["metadata"]
        assert "num_words" in meta
        assert "words" in meta

        # Word count stays within the configured bounds
        words = meta["words"]
        assert cfg.min_words <= len(words) <= cfg.max_words

        # The answer must be the question's word list, reversed
        question_words = [w.strip() for w in entry["question"].split(":")[1].strip().split(",")]
        answer_words = entry["answer"].split(", ")
        assert answer_words == list(reversed(question_words))
|
||||
|
||||
|
||||
def test_word_sequence_reversal_dataset_iteration():
    """Test that iteration respects dataset size"""
    cfg = WordSequenceReversalConfig(size=5, seed=42)
    dataset = WordSequenceReversalDataset(cfg)

    collected = list(dataset)
    assert len(collected) == cfg.size

    # Iterating a second time reproduces the exact same items
    assert collected == list(dataset)
|
||||
|
||||
|
||||
def test_word_sequence_reversal_text_preprocessing():
    """Test that text preprocessing handles edge cases"""
    dataset = WordSequenceReversalDataset(WordSequenceReversalConfig(size=1, seed=42))

    # The source text must yield a non-empty vocabulary...
    assert len(dataset.words) > 0
    # ...containing only alphanumeric tokens
    assert all(word.isalnum() for word in dataset.words)
|
||||
|
|
@ -1,113 +0,0 @@
|
|||
"""Tests for word sorting task generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
|
||||
|
||||
|
||||
def test_word_sorting_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Each keyword set violates one validation constraint; construction stays
    # inside the raises block so any constructor-time assertion is also caught.
    for bad_kwargs in (
        {"min_words": 0},
        {"min_words": 10, "max_words": 5},
        {"min_word_length": 0},
        {"min_word_length": 10, "max_word_length": 5},
    ):
        with pytest.raises(AssertionError):
            WordSortingConfig(**bad_kwargs).validate()
|
||||
|
||||
|
||||
def test_word_sorting_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    cfg = WordSortingConfig(seed=42, size=10)
    first = WordSortingDataset(cfg)
    second = WordSortingDataset(cfg)

    # Same seed and size must reproduce every item exactly
    for idx in range(len(first)):
        assert first[idx] == second[idx]
|
||||
|
||||
|
||||
def test_word_sorting_transformations():
    """Test different text transformations"""
    common = {"seed": 42, "size": 5}

    # LOWERCASE: alphabetic words must come out all lower-case
    dataset = WordSortingDataset(WordSortingConfig(transformation=TextTransformation.LOWERCASE, **common))
    for entry in dataset:
        for word in entry["metadata"]["transformed_words"]:
            if word.isalpha():  # Only test alphabetic strings
                assert word.islower()

    # UPPERCASE: alphabetic words must come out all upper-case
    dataset = WordSortingDataset(WordSortingConfig(transformation=TextTransformation.UPPERCASE, **common))
    for entry in dataset:
        for word in entry["metadata"]["transformed_words"]:
            if word.isalpha():  # Only test alphabetic strings
                assert word.isupper()

    # ORIGINAL: words must pass through untouched
    dataset = WordSortingDataset(WordSortingConfig(transformation=TextTransformation.ORIGINAL, **common))
    for entry in dataset:
        assert entry["metadata"]["original_words"] == entry["metadata"]["transformed_words"]
|
||||
|
||||
|
||||
def test_word_sorting_dataset_items():
    """Test basic properties of generated items"""
    cfg = WordSortingConfig(min_words=3, max_words=6, min_word_length=3, max_word_length=8, size=10, seed=42)
    dataset = WordSortingDataset(cfg)

    for idx in range(len(dataset)):
        entry = dataset[idx]

        # Every item is a dict exposing the standard keys
        assert isinstance(entry, dict)
        for key in ("question", "answer", "metadata"):
            assert key in entry

        # All expected metadata fields are present
        meta = entry["metadata"]
        for key in ("original_words", "transformed_words", "direction", "transformation", "sorted_words"):
            assert key in meta

        # Word count stays within the configured bounds
        words = meta["transformed_words"]
        assert cfg.min_words <= len(words) <= cfg.max_words

        # Every word respects the per-word length bounds
        for word in words:
            assert cfg.min_word_length <= len(word) <= cfg.max_word_length

        # The answer must be sorted in the direction metadata declares
        ordered = entry["answer"].split(", ")
        if meta["direction"] == "ascending":
            assert ordered == sorted(ordered)
        else:
            assert ordered == sorted(ordered, reverse=True)
|
||||
|
||||
|
||||
def test_word_sorting_dataset_iteration():
    """Test that iteration respects dataset size"""
    cfg = WordSortingConfig(size=5, seed=42)
    dataset = WordSortingDataset(cfg)

    collected = list(dataset)
    assert len(collected) == cfg.size

    # Iterating a second time reproduces the exact same items
    assert collected == list(dataset)
|
||||
Loading…
Add table
Add a link
Reference in a new issue