Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2026-05-02 17:45:58 +00:00
Merge branch 'rich/decimalmath' of github.com:open-thought/reasoning-gym into rich/decimalmath
Commit 0cd2eb50d7
62 changed files with 4012 additions and 478 deletions
@@ -137,3 +137,32 @@ def test_arc_agi_dataset_modes():
    both_ds = ArcAgiDataset(both_config)
    assert len(both_ds._task_ids) > len(train_ds._task_ids)
    assert len(both_ds._task_ids) > len(eval_ds._task_ids)


def test_arc_agi_shuffled_order():
    config_unshuffled = ArcAgiConfig(
        shuffle_example_order=False,
        use_train=True,
        use_eval=False,
        rotations=[],
        mirrors=[],
        use_color_permutation=False,
        size=3,
        seed=42,
    )
    config_shuffled = ArcAgiConfig(
        shuffle_example_order=True,
        use_train=True,
        use_eval=False,
        rotations=[],
        mirrors=[],
        use_color_permutation=False,
        size=3,
        seed=42,
    )
    unshuffled = ArcAgiDataset(config_unshuffled)
    shuffled = ArcAgiDataset(config_shuffled)

    for a, b in zip(shuffled, unshuffled):
        assert a["question"] != b["question"]
        assert a["answer"] == b["answer"]
@@ -1,5 +1,3 @@
from random import Random

import pytest

from reasoning_gym.arithmetic.basic_arithmetic import (

@@ -64,11 +62,19 @@ def test_arithmetic_dataset_format_styles():
        max_digits=2,
    )
    dataset = BasicArithmeticDataset(config)
    assert all(item["question"].endswith("=") for item in dataset)
    assert all(item["question"].strip().endswith(".") for item in dataset)

    config.format_style = "natural"
    config = BasicArithmeticDatasetConfig(
        size=10,
        seed=42,
        format_style="natural",
        min_terms=2,
        max_terms=3,  # Keep expressions simple for testing
        min_digits=1,
        max_digits=2,
    )
    dataset = BasicArithmeticDataset(config)
    assert all("=" not in item["question"] for item in dataset)
    assert all(item["question"].strip().endswith(".") for item in dataset)


def test_arithmetic_dataset_iteration():
@@ -15,6 +15,14 @@ def test_binary_matrix_config_validation():
        config = BinaryMatrixConfig(max_n=0)  # Zero not allowed
        config.validate()

    with pytest.raises(AssertionError):
        config = BinaryMatrixConfig(min_n=-1)  # Negative not allowed
        config.validate()

    with pytest.raises(AssertionError):
        config = BinaryMatrixConfig(min_n=0)  # Zero not allowed
        config.validate()

    with pytest.raises(AssertionError):
        config = BinaryMatrixConfig(p_zero=0)  # <= 0 not allowed
        config.validate()

@@ -98,3 +106,18 @@ def test_binary_matrix_answer():
    # Empty matrix
    matrix = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
    assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]]

    # String representation of answer
    answer = "0 0 0\n0 1 0\n1 2 1"
    entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
    assert dataset.score_answer(answer, entry) == 1.0

    # Answer is a python list (partially correct answer)
    answer = "[[0, 0, 0], [0, 1, 0], [1, 2, 1]]"
    entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
    assert dataset.score_answer(answer, entry) == 0.5

    # Answer is null
    answer = None
    entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
    assert dataset.score_answer(answer, entry) == 0.0
224 tests/test_circuit_logic.py Normal file
@@ -0,0 +1,224 @@
import pytest

from reasoning_gym.logic import CircuitLogicConfig, CircuitLogicDataset


def test_circuit_logic_config_validation():
    """Test that invalid configs raise appropriate errors"""
    with pytest.raises(AssertionError):
        config = CircuitLogicConfig(min_inputs=3, max_inputs=2)
        config.validate()

    with pytest.raises(AssertionError):
        config = CircuitLogicConfig(num_terms=0)
        config.validate()

    with pytest.raises(AssertionError):
        config = CircuitLogicConfig(neg_prob=-0.1)
        config.validate()

    with pytest.raises(AssertionError):
        config = CircuitLogicConfig(neg_prob=1.1)
        config.validate()


def test_circuit_logic_deterministic():
    """Test that dataset generates same items with same seed"""
    config = CircuitLogicConfig(seed=42, size=10)
    dataset1 = CircuitLogicDataset(config)
    dataset2 = CircuitLogicDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_circuit_logic_items():
    """Test basic properties of generated items"""
    config = CircuitLogicConfig(num_terms=3, min_inputs=2, max_inputs=3, neg_prob=0.3, size=50, seed=42)
    dataset = CircuitLogicDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Verify metadata contents
        metadata = item["metadata"]
        assert "expression" in metadata
        assert "assignments" in metadata
        assert "final_gate" in metadata
        assert "inputs" in metadata

        # Verify answer is binary
        assert item["answer"] in ("0", "1")

        # Verify assignments are binary
        for input_name, value in metadata["assignments"].items():
            assert value in (0, 1)

        # Verify final gate is valid
        assert metadata["final_gate"] in ("OR", "NOR", "XOR", "AND")

        # Verify inputs list matches assignments
        assert set(metadata["inputs"]) == set(metadata["assignments"].keys())


def test_circuit_logic_expression_validity():
    """Test that generated expressions follow logical circuit rules"""
    config = CircuitLogicConfig(
        num_terms=2, min_inputs=2, max_inputs=2, neg_prob=0.0, size=20, seed=42  # Disable negation for simpler testing
    )
    dataset = CircuitLogicDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]

        # Expression should contain valid operators
        expr = metadata["expression"]
        assert any(op in expr for op in ("&", "↑", "⊕", "+", "↓"))

        # Input names should be valid Excel-style names
        for input_name in metadata["inputs"]:
            assert input_name.isalpha()
            assert input_name.isupper()


def test_circuit_logic_answer_verification():
    """Test that answers match logical evaluation of circuits"""
    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=20, seed=42)
    dataset = CircuitLogicDataset(config)

    def evaluate_term(term: str, assignments: dict) -> int:
        """Evaluate a single term with given assignments"""
        if "↑" in term:  # NAND
            parts = term.split("↑")
            values = []
            for p in parts:
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            return 0 if all(v == 1 for v in values) else 1
        elif "&" in term:  # AND
            parts = term.split("&")
            values = []
            for p in parts:
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            return 1 if all(v == 1 for v in values) else 0
        elif "⊕" in term:  # XOR
            parts = term.split("⊕")
            values = []
            for p in parts:
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            return sum(values) % 2
        else:
            raise ValueError(f"Unknown operator in term: {term}")

    def evaluate_final_gate(gate_type: str, term_values: list) -> int:
        """Evaluate the final gate with given term values"""
        if gate_type == "AND":
            return 1 if all(v == 1 for v in term_values) else 0
        elif gate_type == "OR":
            return 1 if any(v == 1 for v in term_values) else 0
        elif gate_type == "XOR":
            return sum(term_values) % 2
        elif gate_type == "NOR":
            return 0 if any(v == 1 for v in term_values) else 1
        else:
            raise ValueError(f"Unknown gate type: {gate_type}")

    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]
        assignments = metadata["assignments"]
        final_gate = metadata["final_gate"]
        term_strings = metadata["term_strings"]

        # First evaluate each term
        term_values = [evaluate_term(term, assignments) for term in term_strings]

        # Then combine terms with final gate
        expected = evaluate_final_gate(final_gate, term_values)

        # Compare with actual result
        result = int(item["answer"])
        assert (
            result == expected
        ), f"Item {i}: Expected {expected} but got {result} for terms {term_strings} with assignments {assignments} and final gate {final_gate}"


def test_circuit_logic_ascii_diagram():
    """Test properties of the ASCII circuit diagram"""
    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=10, seed=42)
    dataset = CircuitLogicDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]

        # Split question to get diagram
        parts = item["question"].split("\n")
        diagram_start = parts.index("Below is a randomly generated logic circuit.") + 2
        diagram_end = parts.index("", diagram_start)
        diagram = parts[diagram_start:diagram_end]

        # Basic diagram validation
        assert len(diagram) > 0
        assert all(len(row) > 0 for row in diagram)

        # Check for required circuit elements
        diagram_str = "\n".join(diagram)
        assert "OUT" in diagram_str
        assert any(gate in diagram_str for gate in ("&", "↑", "⊕"))

        # Verify input labels
        for input_name in item["metadata"]["inputs"]:
            assert f"{input_name}:" in diagram_str


def test_circuit_logic_scoring():
    """Test the answer scoring mechanism"""
    config = CircuitLogicConfig(size=5, seed=42)
    dataset = CircuitLogicDataset(config)

    item = dataset[0]

    # Correct answer should score 1.0
    assert dataset.score_answer(item["answer"], item) == 1.0

    # Wrong answer should score lower
    wrong_answer = "1" if item["answer"] == "0" else "0"
    assert dataset.score_answer(wrong_answer, item) < 1.0

    # None or empty answer should score 0.0
    assert dataset.score_answer(None, item) == 0.0
    assert dataset.score_answer("", item) == 0.0  # Empty string should score 0.0 like None


def test_circuit_logic_iteration():
    """Test that iteration works correctly"""
    config = CircuitLogicConfig(size=5, seed=42)
    dataset = CircuitLogicDataset(config)

    # Test manual iteration
    items = []
    for item in dataset:
        items.append(item)
    assert len(items) == config.size

    # Test list conversion
    items = list(dataset)
    assert len(items) == config.size

    # Test multiple iterations yield same items
    first_items = list(dataset)
    second_items = list(dataset)
    assert first_items == second_items
105 tests/test_cryptarithm.py Normal file
@@ -0,0 +1,105 @@
import pytest

from reasoning_gym import create_dataset
from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmDataset


def test_cryptarithm_generation():
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    assert isinstance(dataset, CryptarithmDataset)
    unique_number = set()
    for item in dataset:
        # Check required keys exist
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Validate question format
        question = item["question"]
        assert "Solve this cryptarithm:" in question
        assert "Each letter stands for a unique digit (0-9)" in question

        # Validate metadata structure
        metadata = item["metadata"]
        assert "letters" in metadata
        assert "letter_to_digit" in metadata
        assert "words_letters" in metadata
        assert "result_letters" in metadata
        assert "word_values" in metadata
        assert "sum_number" in metadata

        # Validate letter to digit mapping
        letter_to_digit = metadata["letter_to_digit"]
        used_digits = set(letter_to_digit.values())
        assert len(used_digits) == len(letter_to_digit), "Each letter should map to a unique digit"
        assert all(0 <= digit <= 9 for digit in used_digits), "All digits should be between 0 and 9"

        # Validate the arithmetic
        word_values = metadata["word_values"]
        result_value = metadata["sum_number"]
        assert sum(word_values) == result_value, "Sum of word values should equal result value"
        unique_number.add(result_value)

    assert len(unique_number) == len(dataset)


def test_cryptarithm_config():
    # Test invalid configs raise assertions
    with pytest.raises(AssertionError):
        dataset = create_dataset("cryptarithm", min_words=1)  # min_words must be >= 2

    with pytest.raises(AssertionError):
        dataset = create_dataset("cryptarithm", min_words=4, max_words=3)  # min must be <= max

    with pytest.raises(AssertionError):
        dataset = create_dataset("cryptarithm", size=0)  # size must be positive


def test_leading_zero_constraint():
    # Test with leading zeros not allowed
    dataset = create_dataset("cryptarithm", seed=42, size=5, allow_leading_zero=False, max_words=10, min_words=5)

    for item in dataset:
        # print(item['question'])
        metadata = item["metadata"]
        letter_to_digit = metadata["letter_to_digit"]
        words_letters = metadata["words_letters"]
        result_letters = metadata["result_letters"]

        # Check leading letters of all words and result
        leading_letters = [word[0] for word in words_letters] + [result_letters[0]]
        for letter in leading_letters:
            assert letter_to_digit[letter] != 0, "Leading letters cannot be zero when allow_leading_zero=False"


def test_deterministic_generation():
    dataset1 = create_dataset("cryptarithm", seed=42, size=5)
    dataset2 = create_dataset("cryptarithm", seed=42, size=5)

    for i in range(5):
        assert dataset1[i]["question"] == dataset2[i]["question"]
        assert dataset1[i]["answer"] == dataset2[i]["answer"]
        assert dataset1[i]["metadata"] == dataset2[i]["metadata"]


def test_word_length_constraints():
    dataset = create_dataset("cryptarithm", seed=42, size=10)

    for item in dataset:
        metadata = item["metadata"]
        words_letters = metadata["words_letters"]

        # Check each word is between 3-5 letters as specified in the code
        for word in words_letters:
            assert 3 <= len(word) <= 5, "Each word should be between 3 and 5 letters long"


def test_max_letters_constraint():
    dataset = create_dataset("cryptarithm", seed=42, size=10)

    for item in dataset:
        metadata = item["metadata"]
        letter_to_digit = metadata["letter_to_digit"]

        # Check total unique letters doesn't exceed 10 (digits 0-9)
        assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10"
188 tests/test_futoshiki.py Normal file
@@ -0,0 +1,188 @@
import pytest

from reasoning_gym.games import FutoshikiConfig, FutoshikiDataset


def test_futoshiki_config_validation():
    """Test that invalid configs raise appropriate errors"""
    with pytest.raises(AssertionError):
        config = FutoshikiConfig(board_size=3)  # Too small
        config.validate()

    with pytest.raises(AssertionError):
        config = FutoshikiConfig(board_size=10)  # Too large
        config.validate()

    with pytest.raises(AssertionError):
        config = FutoshikiConfig(difficulty=-1)  # Invalid difficulty
        config.validate()

    with pytest.raises(AssertionError):
        config = FutoshikiConfig(difficulty=4)  # Invalid difficulty
        config.validate()


def test_futoshiki_deterministic():
    """Test that dataset generates same puzzles with same seed"""
    config = FutoshikiConfig(seed=42, size=10)
    dataset1 = FutoshikiDataset(config)
    dataset2 = FutoshikiDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_futoshiki_items():
    """Test basic properties of generated items"""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=10, seed=42)
    dataset = FutoshikiDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Verify metadata contents
        metadata = item["metadata"]
        assert "puzzle" in metadata
        assert "solution" in metadata
        assert "constraints" in metadata
        assert "board_size" in metadata
        assert "difficulty" in metadata

        # Verify board dimensions
        puzzle = metadata["puzzle"]
        solution = metadata["solution"]
        assert len(puzzle) == config.board_size
        assert len(solution) == config.board_size
        for row in puzzle:
            assert len(row) == config.board_size
        for row in solution:
            assert len(row) == config.board_size

        # Verify constraints format
        constraints = metadata["constraints"]
        for ((r1, c1), (r2, c2)), rel in constraints.items():
            assert 0 <= r1 < config.board_size
            assert 0 <= c1 < config.board_size
            assert 0 <= r2 < config.board_size
            assert 0 <= c2 < config.board_size
            assert rel in ("<", ">")


def test_futoshiki_solution_validity():
    """Test that solutions are valid according to Futoshiki rules"""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=10, seed=42)
    dataset = FutoshikiDataset(config)

    def is_valid_solution(solution, board_size, constraints):
        # Check rows
        for row in solution:
            if sorted(row) != list(range(1, board_size + 1)):
                return False

        # Check columns
        for col in range(board_size):
            column = [solution[row][col] for row in range(board_size)]
            if sorted(column) != list(range(1, board_size + 1)):
                return False

        # Check constraints
        for ((r1, c1), (r2, c2)), rel in constraints.items():
            v1, v2 = solution[r1][c1], solution[r2][c2]
            if rel == "<" and not (v1 < v2):
                return False
            if rel == ">" and not (v1 > v2):
                return False

        return True

    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]
        solution = metadata["solution"]
        constraints = metadata["constraints"]

        assert is_valid_solution(solution, config.board_size, constraints)


def test_futoshiki_puzzle_solvability():
    """Test that generated puzzles are solvable and have unique solutions"""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=5, seed=42)
    dataset = FutoshikiDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]
        puzzle = metadata["puzzle"]
        constraints = metadata["constraints"]

        # Verify puzzle has exactly one solution
        assert dataset.count_solutions(puzzle, constraints, limit=2) == 1


def test_futoshiki_difficulty_levels():
    """Test that different difficulty levels affect puzzle complexity"""
    size = 5
    board_size = 4
    seeds = [42, 43, 44]  # Test multiple seeds for robustness

    def count_clues(puzzle):
        return sum(cell != 0 for row in puzzle for cell in row)

    def count_constraints(constraints):
        return len(constraints)

    for seed in seeds:
        clues_by_difficulty = []
        constraints_by_difficulty = []

        for difficulty in range(4):  # 0 to 3
            config = FutoshikiConfig(board_size=board_size, difficulty=difficulty, size=size, seed=seed)
            dataset = FutoshikiDataset(config)

            avg_clues = sum(count_clues(item["metadata"]["puzzle"]) for item in dataset) / size
            avg_constraints = sum(count_constraints(item["metadata"]["constraints"]) for item in dataset) / size

            clues_by_difficulty.append(avg_clues)
            constraints_by_difficulty.append(avg_constraints)

        # Higher difficulty should generally mean fewer clues and/or more constraints
        assert all(clues_by_difficulty[i] >= clues_by_difficulty[i + 1] for i in range(len(clues_by_difficulty) - 1))
        assert all(
            constraints_by_difficulty[i] <= constraints_by_difficulty[i + 1]
            for i in range(len(constraints_by_difficulty) - 1)
        )


def test_futoshiki_answer_scoring():
    """Test the answer scoring mechanism"""
    config = FutoshikiConfig(board_size=4, difficulty=0, size=5, seed=42)
    dataset = FutoshikiDataset(config)

    for item in dataset:
        # Correct answer should score 1.0
        assert dataset.score_answer(item["answer"], item) == 1.0

        # Wrong answer should score lower
        wrong_answer = item["answer"].replace("1", "2")
        assert dataset.score_answer(wrong_answer, item) < 1.0

        # None or empty answer should score 0.0
        assert dataset.score_answer(None, item) == 0.0
        assert dataset.score_answer("", item) == 0.0

answer = item["answer"]
|
||||
white_space_mismatch = answer.replace(" ", " ")
|
||||
assert dataset.score_answer(white_space_mismatch, item) == 0.9
|
||||
|
||||
anwser_with_additional_text = "This is an anwser " + answer + "\nwith surrounding text."
|
||||
assert 0 < dataset.score_answer(anwser_with_additional_text, item) < 0.9
|
||||
|
||||
partially_correct = anwser_with_additional_text.replace("1", "2")
|
||||
assert dataset.score_answer(partially_correct, item) > 0.1
|
||||
|
||||
bad_answer = "\n".join(anwser_with_additional_text.split("\n")[::-1])
|
||||
assert dataset.score_answer(bad_answer, item) < 0.1
|
||||
|
|
@@ -122,6 +122,11 @@ def test_nqueens_score_answer():
    # Test None answer gets score 0.0
    assert dataset.score_answer(None, item) == 0.0

    # Test python list representation of board (partial solution)
    answer = "[['_', 'Q', '_', '_'], ['_', '_', '_', 'Q'], ['Q', '_', '_', '_'], ['_', '_', 'Q', '_']]"
    entry = {"metadata": {"valid_answers": {"_ Q _ _\n_ _ _ Q\nQ _ _ _\n_ _ Q _"}}}
    assert dataset.score_answer(answer, entry) == 0.5


def is_valid_solution(board: list[list[str]]) -> bool:
    """Helper function to verify N Queens solution validity"""
111 tests/test_palindrome_partitioning.py Normal file
@@ -0,0 +1,111 @@
"""Tests for Palindrome Partitioning questions generation"""

import json

from reasoning_gym.algorithmic.palindrome_partitioning import (
    PalindromePartitioningConfig,
    PalindromePartitioningDataset,
)


def test_palindrome_partitioning_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = PalindromePartitioningConfig(seed=42, size=10)
    dataset1 = PalindromePartitioningDataset(config)
    dataset2 = PalindromePartitioningDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_palindrome_partitioning_dataset_items():
    """Test basic properties of generated items"""
    config = PalindromePartitioningConfig(size=10, seed=42)
    dataset = PalindromePartitioningDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        # Check item structure
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata
        assert "string" in item["metadata"]
        assert "solution" in item["metadata"]
        string = item["metadata"]["string"]
        solution = item["metadata"]["solution"]

        # Verify string is not empty
        assert len(string) > 0

        # At least one partitioning exists (each letter is a palindrome)
        assert len(solution) >= 1

        # Verify each partitioning reconstructs the original string
        assert all(len(partitioning) > 0 for partitioning in solution)
        assert all("".join(partitioning) == string for partitioning in solution)


def test_palindrome_partitioning_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = PalindromePartitioningConfig(size=5, seed=42)
    dataset = PalindromePartitioningDataset(config)

    items = list(dataset)
    assert len(items) == config.size

    # Test multiple iterations yield same items
    assert items == list(dataset)


def test_palindrome_partitioning_answer():
    """Test the _palindrome_partitioning method"""
    config = PalindromePartitioningConfig(seed=42)
    dataset = PalindromePartitioningDataset(config)

    # General use case
    word = "afternoon"
    correct = [
        ["a", "f", "t", "e", "r", "n", "o", "o", "n"],
        ["a", "f", "t", "e", "r", "n", "oo", "n"],
        ["a", "f", "t", "e", "r", "noon"],
    ]
    assert json.dumps(dataset._palindrome_partitioning(word)) == json.dumps(correct)

    # Single letter word
    word = "a"
    correct = [["a"]]
    assert json.dumps(dataset._palindrome_partitioning(word)) == json.dumps(correct)

    # Empty string
    word = ""
    correct = []
    assert json.dumps(dataset._palindrome_partitioning(word)) == json.dumps(correct)


def test_palindrome_partitioning_score_answer():
    """Test the score_answer method"""
    config = PalindromePartitioningConfig(seed=42)
    dataset = PalindromePartitioningDataset(config)

    # Verify the scoring function is permutation invariant
    answer = json.dumps([["n", "o", "o", "n"], ["no", "on"], ["noon"]])
    item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    assert dataset.score_answer(answer, item) == 1

    # Verify the score is 0.01 when incorrect
    answer = json.dumps([["n", "o", "o", "n"], ["no", "on"]])
    item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    assert dataset.score_answer(answer, item) == 0.01

    # Verify the score is 0 when answer is None
    answer = None
    item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    assert dataset.score_answer(answer, item) == 0

    # Verify the score is 0 when answer is malformed JSON
    answer = '["n", "o", "o", "n"], ["no", "on"], ["noon"]'
    item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    assert dataset.score_answer(answer, item) == 0
@@ -1,3 +1,5 @@
import string

import pytest
import sympy as sp

@@ -17,7 +19,7 @@ def test_polynomial_config_validation():
        PolynomialMultiplicationConfig(min_value=0).validate()

    with pytest.raises(AssertionError):
        PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
        PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()

    with pytest.raises(AssertionError):
        PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()

@@ -28,6 +30,17 @@ def test_polynomial_config_validation():
    with pytest.raises(AssertionError):
        PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()

    with pytest.raises(AssertionError):
        PolynomialMultiplicationConfig(variables="").validate()

    with pytest.raises(AssertionError):
        PolynomialMultiplicationConfig(
            allow_cross_variable_product=False, allow_multivariate_polynomials=True
        ).validate()

    with pytest.raises(AssertionError):
        PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()


def test_polynomial_multiplication_dataset_basic():
    """Test dataset creation and length"""

@@ -41,7 +54,9 @@ def test_polynomial_multiplication_dataset_basic():
        max_degree=2,
        min_polynomials=2,
        max_polynomials=3,
        single_variable=True,
        variables=tuple(string.ascii_lowercase),
        allow_cross_variable_product=False,
        allow_multivariate_polynomials=False,
        seed=42,
        size=dataset_size,
    )

@@ -63,7 +78,9 @@ def test_polynomial_equations_dataset_items():
        max_degree=2,
        min_polynomials=2,
        max_polynomials=5,
        single_variable=False,
        variables=tuple("xyz"),
        allow_cross_variable_product=False,
        allow_multivariate_polynomials=False,
        size=3,
        seed=100,
    )
@@ -75,7 +92,113 @@ def test_polynomial_equations_dataset_items():

        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
        assert isinstance(item["metadata"]["single_variable"], bool)
        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)

        # Check polynomial_expr existence
        poly_str = item["metadata"]["polynomial_expr"]
        # Ensure it can parse with sympy
        sp.sympify(poly_str)


def test_cross_polynomial_equations_dataset_items():
    """Test that generated items have correct structure"""
    ds = create_dataset(
        "polynomial_multiplication",
        min_terms=2,
        max_terms=3,
        min_value=1,
        max_value=5,
        min_degree=1,
        max_degree=2,
        min_polynomials=2,
        max_polynomials=5,
        variables=tuple("xyz"),
        allow_cross_variable_product=True,
        allow_multivariate_polynomials=False,
        size=3,
        seed=100,
    )

    for item in ds:
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)

        # Check polynomial_expr existence
        poly_str = item["metadata"]["polynomial_expr"]
        # Ensure it can parse with sympy
        sp.sympify(poly_str)


def test_cross_polynomial_equations_dataset_items():
    """Test that generated items have correct structure"""
    ds = create_dataset(
        "polynomial_multiplication",
        min_terms=2,
        max_terms=3,
        min_value=1,
        max_value=5,
        min_degree=1,
        max_degree=2,
        min_polynomials=2,
        max_polynomials=5,
        variables=tuple("xyz"),
        allow_cross_variable_product=True,
        allow_multivariate_polynomials=False,
        size=3,
        seed=100,
    )

    for item in ds:
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)

        # Check polynomial_expr existence
        poly_str = item["metadata"]["polynomial_expr"]
        # Ensure it can parse with sympy
        sp.sympify(poly_str)


def test_multivariate_polynomial_equations_dataset_items():
    """Test that generated items have correct structure"""
    ds = create_dataset(
        "polynomial_multiplication",
        min_terms=2,
        max_terms=3,
        min_value=1,
        max_value=5,
        min_degree=1,
        max_degree=2,
        min_polynomials=2,
        max_polynomials=5,
        variables=tuple(["x", "y"]),
        allow_cross_variable_product=True,
        allow_multivariate_polynomials=True,
        size=3,
        seed=100,
    )

    for item in ds:
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)

        # Check polynomial_expr existence
        poly_str = item["metadata"]["polynomial_expr"]

@@ -105,7 +228,9 @@ def test_polynomial_solutions_evaluation():
        max_degree=3,
        min_polynomials=2,
        max_polynomials=5,
        single_variable=False,
        variables=tuple(["x", "y"]),
        allow_cross_variable_product=True,
        allow_multivariate_polynomials=True,
        size=5,
        seed=42,
    )
@@ -125,42 +250,27 @@ def test_score_function():
    ds = create_dataset(
        "polynomial_multiplication",
        min_terms=2,
        max_terms=4,
        max_terms=3,
        min_value=1,
        max_value=10,
        max_value=3,
        min_degree=1,
        max_degree=3,
        min_polynomials=2,
        max_polynomials=5,
        single_variable=True,
        size=1,
        min_polynomials=3,
        max_polynomials=3,
        variables=tuple(["x", "y"]),
        allow_cross_variable_product=True,
        allow_multivariate_polynomials=True,
        size=3,
        seed=42,
    )

    assert ds.score_answer(None, ds[0]) == 0.00
    assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]) == 1
    assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
    assert ds.score_answer("x**4", ds[0]) == 0.05
    for item in ds:
        poly_str = item["metadata"]["polynomial_expr"]
        assert ds.score_answer(poly_str, item) == 0.05

        poly_expr = str(sp.expand(poly_str))
        assert ds.score_answer(poly_expr, item) == 1.0

def test_multivariate_score_function():
    """Test that solution satisfy the polynomial multiplication."""
    ds = create_dataset(
        "polynomial_multiplication",
        min_terms=2,
        max_terms=4,
        min_value=1,
        max_value=10,
        min_degree=1,
        max_degree=3,
        min_polynomials=2,
        max_polynomials=5,
        single_variable=False,
        size=1,
        seed=42,
    )

    assert ds.score_answer(None, ds[0]) == 0.00
    assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]) == 1
    assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
    assert ds.score_answer("x**4", ds[0]) == 0.05
    assert ds.score_answer(None, item) == 0.00
    assert ds.score_answer("Not a polynomial", item) == 0.01
    assert ds.score_answer("x**4", item) == 0.05
@@ -37,6 +37,7 @@ def test_getitem(dataset, config):
    assert "metadata" in item
    assert item["metadata"]["word_count"] >= config.min_words_in_sentence
    assert item["metadata"]["word_count"] <= config.max_words_in_sentence
    assert len(item["answer"].split()) == item["metadata"]["word_count"]


def test_key_error_in_getitem(dataset):
@@ -15,6 +15,10 @@ def test_spiral_matrix_config_validation():
        config = SpiralMatrixConfig(max_n=0)  # Zero not allowed
        config.validate()

    with pytest.raises(AssertionError):
        config = SpiralMatrixConfig(max_n=1)  # One not allowed
        config.validate()


def test_spiral_matrix_dataset_deterministic():
    """Test that dataset generates same items with same seed"""

@@ -69,18 +73,26 @@ def test_spiral_matrix_answer():
    config = SpiralMatrixConfig(seed=42)
    dataset = SpiralMatrixDataset(config)

    # One element
    matrix = [[0]]
    assert dataset._get_spiral(matrix) == [0]

    # One row
    matrix = [[0, 1, 2]]
    assert dataset._get_spiral(matrix) == [0, 1, 2]

    # One column
    matrix = [[0], [1], [2]]
    assert dataset._get_spiral(matrix) == [0, 1, 2]

    # 2D grid
    matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5]

    # Answer is identical (up to trimming)
    entry = {"answer": "1 2 3 6 9 8 7 4 5"}
    answer = "\n\n1 2 3 6 9 8 7 4 5\n"
    assert dataset.score_answer(answer, entry) == 1.0

    # Score answer in list format (partially correct)
    entry = {"answer": "1 2 3 6 9 8 7 4 5"}
    answer = "[1, 2, 3, 6, 9, 8, 7, 4, 5]"
    assert dataset.score_answer(answer, entry) == 0.5

    # Answer is incorrect
    entry = {"answer": "1 2 3 6 9 8 7 4 5"}
    answer = "1 2 3"
    assert dataset.score_answer(answer, entry) == 0.01

    # Answer is none
    entry = {"answer": "1 2 3 6 9 8 7 4 5"}
    answer = None
    assert dataset.score_answer(answer, entry) == 0.0
@@ -92,3 +92,13 @@ def test_string_insertion_answer():

    # No reuse of newly inserted characters
    assert dataset._get_answer("ABCDBCD") == "ABCDABCD"

    # Test score_answer with correct answer
    answer = "AABCDAEEEEEEEBCDEBAAAAA"
    entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"}
    assert dataset.score_answer(answer, entry) == 1.0

    # Test score_answer with correct answer as python list of characters (partial correct)
    answer = "['A', 'A', 'B', 'C', 'D', 'A', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'B', 'C', 'D', 'E', 'B', 'A', 'A', 'A', 'A', 'A']"
    entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"}
    assert dataset.score_answer(answer, entry) == 0.5
@@ -116,3 +116,35 @@ def test_word_sorting_dataset_iteration():

    # Test multiple iterations yield same items
    assert items == list(dataset)


def test_word_sorting_scoring():
    """Test scoring function"""
    config = WordSortingConfig(size=1, seed=42)
    dataset = WordSortingDataset(config)

    item = {
        "metadata": {
            "sorted_words": ["apple", "banana", "cherry"],
        }
    }

    # Correct answer
    answer = "apple, banana, cherry"
    assert dataset.score_answer(answer, item) == 1.0

    # Correct answer, with incorrect spaces
    answer = "apple,banana, cherry"
    assert dataset.score_answer(answer, item) == 1.0

    # All words present, but not sorted
    answer = "banana, cherry, apple"
    assert dataset.score_answer(answer, item) == 0.2

    # Garbage
    answer = "gibberish"
    assert dataset.score_answer(answer, item) == 0.01

    # Empty answer
    answer = None
    assert dataset.score_answer(answer, item) == 0.0