Merge branch 'rich/decimalmath' of github.com:open-thought/reasoning-gym into rich/decimalmath

This commit is contained in:
Rich Jones 2025-02-19 03:34:57 +01:00
commit 0cd2eb50d7
62 changed files with 4012 additions and 478 deletions

View file

@ -137,3 +137,32 @@ def test_arc_agi_dataset_modes():
both_ds = ArcAgiDataset(both_config)
assert len(both_ds._task_ids) > len(train_ds._task_ids)
assert len(both_ds._task_ids) > len(eval_ds._task_ids)
def test_arc_agi_shuffled_order():
    """Enabling shuffle_example_order must change the question text while
    leaving the expected answer untouched."""
    shared = dict(
        use_train=True,
        use_eval=False,
        rotations=[],
        mirrors=[],
        use_color_permutation=False,
        size=3,
        seed=42,
    )
    plain_ds = ArcAgiDataset(ArcAgiConfig(shuffle_example_order=False, **shared))
    mixed_ds = ArcAgiDataset(ArcAgiConfig(shuffle_example_order=True, **shared))
    for mixed_item, plain_item in zip(mixed_ds, plain_ds):
        assert mixed_item["question"] != plain_item["question"]
        assert mixed_item["answer"] == plain_item["answer"]

View file

@ -1,5 +1,3 @@
from random import Random
import pytest
from reasoning_gym.arithmetic.basic_arithmetic import (
@ -64,11 +62,19 @@ def test_arithmetic_dataset_format_styles():
max_digits=2,
)
dataset = BasicArithmeticDataset(config)
assert all(item["question"].endswith("=") for item in dataset)
assert all(item["question"].strip().endswith(".") for item in dataset)
config.format_style = "natural"
config = BasicArithmeticDatasetConfig(
size=10,
seed=42,
format_style="natural",
min_terms=2,
max_terms=3, # Keep expressions simple for testing
min_digits=1,
max_digits=2,
)
dataset = BasicArithmeticDataset(config)
assert all("=" not in item["question"] for item in dataset)
assert all(item["question"].strip().endswith(".") for item in dataset)
def test_arithmetic_dataset_iteration():

View file

@ -15,6 +15,14 @@ def test_binary_matrix_config_validation():
config = BinaryMatrixConfig(max_n=0) # Zero not allowed
config.validate()
with pytest.raises(AssertionError):
config = BinaryMatrixConfig(min_n=-1) # Negative not allowed
config.validate()
with pytest.raises(AssertionError):
config = BinaryMatrixConfig(min_n=0) # Zero not allowed
config.validate()
with pytest.raises(AssertionError):
config = BinaryMatrixConfig(p_zero=0) # <= 0 not allowed
config.validate()
@ -98,3 +106,18 @@ def test_binary_matrix_answer():
# Empty matrix
matrix = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
# String representation of answer
answer = "0 0 0\n0 1 0\n1 2 1"
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
assert dataset.score_answer(answer, entry) == 1.0
# Answer is a python list (partially correct answer)
answer = "[[0, 0, 0], [0, 1, 0], [1, 2, 1]]"
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
assert dataset.score_answer(answer, entry) == 0.5
# Answer is null
answer = None
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
assert dataset.score_answer(answer, entry) == 0.0

224
tests/test_circuit_logic.py Normal file
View file

@ -0,0 +1,224 @@
import pytest
from reasoning_gym.logic import CircuitLogicConfig, CircuitLogicDataset
def test_circuit_logic_config_validation():
    """Invalid configurations must fail validation with AssertionError."""
    invalid_kwargs = [
        dict(min_inputs=3, max_inputs=2),  # min above max
        dict(num_terms=0),                 # must be positive
        dict(neg_prob=-0.1),               # probability below 0
        dict(neg_prob=1.1),                # probability above 1
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            CircuitLogicConfig(**kwargs).validate()
def test_circuit_logic_deterministic():
    """Two datasets built from the same config must yield identical items."""
    cfg = CircuitLogicConfig(seed=42, size=10)
    for first, second in zip(CircuitLogicDataset(cfg), CircuitLogicDataset(cfg)):
        assert first == second
def test_circuit_logic_items():
    """Generated items expose the expected structure and binary values."""
    config = CircuitLogicConfig(num_terms=3, min_inputs=2, max_inputs=3, neg_prob=0.3, size=50, seed=42)
    dataset = CircuitLogicDataset(config)
    for item in dataset:
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        meta = item["metadata"]
        for key in ("expression", "assignments", "final_gate", "inputs"):
            assert key in meta
        # The circuit output and every input assignment are strictly binary.
        assert item["answer"] in ("0", "1")
        for value in meta["assignments"].values():
            assert value in (0, 1)
        # The output stage is one of the supported gate types.
        assert meta["final_gate"] in ("OR", "NOR", "XOR", "AND")
        # Every declared input has an assignment and vice versa.
        assert set(meta["inputs"]) == set(meta["assignments"].keys())
def test_circuit_logic_expression_validity():
    """Test that generated expressions follow logical circuit rules"""
    # NOTE(review): several operator literals below are empty strings ("") —
    # almost certainly non-ASCII gate glyphs lost in encoding. Because "" is
    # a substring of every string, the operator check is vacuously true as
    # written; restore the original unicode symbols before trusting it.
    config = CircuitLogicConfig(
        num_terms=2, min_inputs=2, max_inputs=2, neg_prob=0.0, size=20, seed=42  # Disable negation for simpler testing
    )
    dataset = CircuitLogicDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]
        # Expression should contain valid operators
        expr = metadata["expression"]
        assert any(op in expr for op in ("&", "", "", "+", ""))
        # Input names should be valid Excel-style names
        for input_name in metadata["inputs"]:
            assert input_name.isalpha()
            assert input_name.isupper()
def test_circuit_logic_answer_verification():
    """Test that answers match logical evaluation of circuits"""
    # NOTE(review): operator literals below include empty strings ("") where
    # non-ASCII gate glyphs (presumably NAND and XOR symbols) were lost in
    # encoding. As written, `"" in term` is always True and `term.split("")`
    # raises ValueError — restore the original symbols before relying on
    # this evaluator.
    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=20, seed=42)
    dataset = CircuitLogicDataset(config)

    def evaluate_term(term: str, assignments: dict) -> int:
        """Evaluate a single term with given assignments"""
        if "" in term:  # NAND
            parts = term.split("")
            values = []
            for p in parts:
                # A trailing apostrophe marks a negated input.
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            # NAND: 0 only when every input is 1.
            return 0 if all(v == 1 for v in values) else 1
        elif "&" in term:  # AND
            parts = term.split("&")
            values = []
            for p in parts:
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            # AND: 1 only when every input is 1.
            return 1 if all(v == 1 for v in values) else 0
        elif "" in term:  # XOR
            parts = term.split("")
            values = []
            for p in parts:
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            # XOR: parity of the inputs.
            return sum(values) % 2
        else:
            raise ValueError(f"Unknown operator in term: {term}")

    def evaluate_final_gate(gate_type: str, term_values: list) -> int:
        """Evaluate the final gate with given term values"""
        if gate_type == "AND":
            return 1 if all(v == 1 for v in term_values) else 0
        elif gate_type == "OR":
            return 1 if any(v == 1 for v in term_values) else 0
        elif gate_type == "XOR":
            return sum(term_values) % 2
        elif gate_type == "NOR":
            return 0 if any(v == 1 for v in term_values) else 1
        else:
            raise ValueError(f"Unknown gate type: {gate_type}")

    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]
        assignments = metadata["assignments"]
        final_gate = metadata["final_gate"]
        term_strings = metadata["term_strings"]
        # First evaluate each term
        term_values = [evaluate_term(term, assignments) for term in term_strings]
        # Then combine terms with final gate
        expected = evaluate_final_gate(final_gate, term_values)
        # Compare with actual result
        result = int(item["answer"])
        assert (
            result == expected
        ), f"Item {i}: Expected {expected} but got {result} for terms {term_strings} with assignments {assignments} and final gate {final_gate}"
def test_circuit_logic_ascii_diagram():
    """Test properties of the ASCII circuit diagram"""
    # NOTE(review): the gate-symbol tuple below contains empty strings ("")
    # where unicode gate glyphs were apparently lost in encoding; `"" in s`
    # is vacuously true, so the gate check currently always passes.
    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=10, seed=42)
    dataset = CircuitLogicDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        # Split question to get diagram
        parts = item["question"].split("\n")
        # The diagram starts two lines after the fixed intro sentence and
        # runs up to the first blank line after it.
        diagram_start = parts.index("Below is a randomly generated logic circuit.") + 2
        diagram_end = parts.index("", diagram_start)
        diagram = parts[diagram_start:diagram_end]
        # Basic diagram validation
        assert len(diagram) > 0
        assert all(len(row) > 0 for row in diagram)
        # Check for required circuit elements
        diagram_str = "\n".join(diagram)
        assert "OUT" in diagram_str
        assert any(gate in diagram_str for gate in ("&", "", ""))
        # Verify input labels
        for input_name in item["metadata"]["inputs"]:
            assert f"{input_name}:" in diagram_str
def test_circuit_logic_scoring():
    """Scoring: 1.0 for exact answers, lower for wrong, 0.0 for missing."""
    dataset = CircuitLogicDataset(CircuitLogicConfig(size=5, seed=42))
    item = dataset[0]
    # The reference answer earns full marks.
    assert dataset.score_answer(item["answer"], item) == 1.0
    # Flipping the binary answer gives a guaranteed-wrong response.
    flipped = "1" if item["answer"] == "0" else "0"
    assert dataset.score_answer(flipped, item) < 1.0
    # Missing answers (None or the empty string) score zero.
    assert dataset.score_answer(None, item) == 0.0
    assert dataset.score_answer("", item) == 0.0  # Empty string should score 0.0 like None
def test_circuit_logic_iteration():
    """Iteration honours the configured size and is repeatable."""
    config = CircuitLogicConfig(size=5, seed=42)
    dataset = CircuitLogicDataset(config)
    # Manual iteration yields exactly `size` items.
    collected = []
    for entry in dataset:
        collected.append(entry)
    assert len(collected) == config.size
    # So does materialising via list().
    assert len(list(dataset)) == config.size
    # Two full passes produce identical items.
    assert list(dataset) == list(dataset)

105
tests/test_cryptarithm.py Normal file
View file

@ -0,0 +1,105 @@
import pytest
from reasoning_gym import create_dataset
from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmDataset
def test_cryptarithm_generation():
    """Generated cryptarithms have valid structure, mappings and arithmetic."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    assert isinstance(dataset, CryptarithmDataset)
    seen_sums = set()
    for item in dataset:
        # Required top-level keys.
        for key in ("question", "answer", "metadata"):
            assert key in item
        # Fixed phrases in the question text.
        question = item["question"]
        assert "Solve this cryptarithm:" in question
        assert "Each letter stands for a unique digit (0-9)" in question
        # Required metadata fields.
        meta = item["metadata"]
        for key in ("letters", "letter_to_digit", "words_letters", "result_letters", "word_values", "sum_number"):
            assert key in meta
        # Letter -> digit must be injective into 0..9.
        mapping = meta["letter_to_digit"]
        digits = set(mapping.values())
        assert len(digits) == len(mapping), "Each letter should map to a unique digit"
        assert all(0 <= d <= 9 for d in digits), "All digits should be between 0 and 9"
        # The stated addition must actually hold.
        assert sum(meta["word_values"]) == meta["sum_number"], "Sum of word values should equal result value"
        seen_sums.add(meta["sum_number"])
    # Every puzzle in the sample has a distinct result value.
    assert len(seen_sums) == len(dataset)
def test_cryptarithm_config():
    """Invalid configuration values must raise AssertionError.

    The previous version bound each create_dataset result to an unused
    `dataset` local (F841); the assignments are dropped since only the
    raised exception matters.
    """
    # min_words must be >= 2
    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", min_words=1)
    # min_words must not exceed max_words
    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", min_words=4, max_words=3)
    # size must be positive
    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", size=0)
def test_leading_zero_constraint():
    """With allow_leading_zero=False no word or result may start with 0."""
    dataset = create_dataset("cryptarithm", seed=42, size=5, allow_leading_zero=False, max_words=10, min_words=5)
    for item in dataset:
        meta = item["metadata"]
        mapping = meta["letter_to_digit"]
        # First letter of every addend, plus the first letter of the result.
        heads = [word[0] for word in meta["words_letters"]]
        heads.append(meta["result_letters"][0])
        for letter in heads:
            assert mapping[letter] != 0, "Leading letters cannot be zero when allow_leading_zero=False"
def test_deterministic_generation():
    """Identical seeds must reproduce identical datasets."""
    first = create_dataset("cryptarithm", seed=42, size=5)
    second = create_dataset("cryptarithm", seed=42, size=5)
    for index in range(5):
        for field in ("question", "answer", "metadata"):
            assert first[index][field] == second[index][field]
def test_word_length_constraints():
    """Every generated word stays within the 3-5 letter range."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    for item in dataset:
        for word in item["metadata"]["words_letters"]:
            assert 3 <= len(word) <= 5, "Each word should be between 3 and 5 letters long"
def test_max_letters_constraint():
    """A puzzle can use at most ten distinct letters (digits 0-9)."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    for item in dataset:
        mapping = item["metadata"]["letter_to_digit"]
        assert len(mapping) <= 10, "Total unique letters should not exceed 10"

188
tests/test_futoshiki.py Normal file
View file

@ -0,0 +1,188 @@
import pytest
from reasoning_gym.games import FutoshikiConfig, FutoshikiDataset
def test_futoshiki_config_validation():
    """Out-of-range board sizes and difficulties must fail validation."""
    invalid_kwargs = [
        dict(board_size=3),   # too small
        dict(board_size=10),  # too large
        dict(difficulty=-1),  # below valid range
        dict(difficulty=4),   # above valid range
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            FutoshikiConfig(**kwargs).validate()
def test_futoshiki_deterministic():
    """Identical configs (same seed) must generate identical puzzles."""
    cfg = FutoshikiConfig(seed=42, size=10)
    for first, second in zip(FutoshikiDataset(cfg), FutoshikiDataset(cfg)):
        assert first == second
def test_futoshiki_items():
    """Generated items expose a well-formed puzzle, solution and constraints."""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=10, seed=42)
    dataset = FutoshikiDataset(config)
    n = config.board_size
    for item in dataset:
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        meta = item["metadata"]
        for key in ("puzzle", "solution", "constraints", "board_size", "difficulty"):
            assert key in meta
        # Puzzle and solution are both n x n grids.
        for grid in (meta["puzzle"], meta["solution"]):
            assert len(grid) == n
            for row in grid:
                assert len(row) == n
        # Constraint keys are in-bounds cell pairs with a '<' or '>' relation.
        for ((r1, c1), (r2, c2)), rel in meta["constraints"].items():
            for coord in (r1, c1, r2, c2):
                assert 0 <= coord < n
            assert rel in ("<", ">")
def test_futoshiki_solution_validity():
    """Solutions are Latin squares that respect all inequality constraints."""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=10, seed=42)
    dataset = FutoshikiDataset(config)
    expected_digits = list(range(1, config.board_size + 1))

    def violates(solution, constraints):
        """Return True on the first broken Futoshiki rule."""
        # Each row must be a permutation of 1..n.
        for row in solution:
            if sorted(row) != expected_digits:
                return True
        # Each column must be a permutation of 1..n.
        for column in zip(*solution):
            if sorted(column) != expected_digits:
                return True
        # Every inequality must hold between its two cells.
        for ((r1, c1), (r2, c2)), rel in constraints.items():
            a, b = solution[r1][c1], solution[r2][c2]
            if (rel == "<" and not a < b) or (rel == ">" and not a > b):
                return True
        return False

    for item in dataset:
        meta = item["metadata"]
        assert not violates(meta["solution"], meta["constraints"])
def test_futoshiki_puzzle_solvability():
    """Each generated puzzle must admit exactly one solution."""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=5, seed=42)
    dataset = FutoshikiDataset(config)
    for item in dataset:
        meta = item["metadata"]
        # Ask the solver for up to two solutions; exactly one must exist.
        assert dataset.count_solutions(meta["puzzle"], meta["constraints"], limit=2) == 1
def test_futoshiki_difficulty_levels():
    """Higher difficulty means no more clues and no fewer constraints."""
    size = 5
    board_size = 4
    for seed in (42, 43, 44):  # several seeds for robustness
        avg_clues = []
        avg_constraints = []
        for difficulty in range(4):  # levels 0 through 3
            config = FutoshikiConfig(board_size=board_size, difficulty=difficulty, size=size, seed=seed)
            dataset = FutoshikiDataset(config)
            clue_total = 0
            constraint_total = 0
            for item in dataset:
                # A clue is any pre-filled (non-zero) cell.
                clue_total += sum(cell != 0 for row in item["metadata"]["puzzle"] for cell in row)
                constraint_total += len(item["metadata"]["constraints"])
            avg_clues.append(clue_total / size)
            avg_constraints.append(constraint_total / size)
        # Clue counts are non-increasing; constraint counts non-decreasing.
        assert all(a >= b for a, b in zip(avg_clues, avg_clues[1:]))
        assert all(a <= b for a, b in zip(avg_constraints, avg_constraints[1:]))
def test_futoshiki_answer_scoring():
    """Test the answer scoring mechanism"""
    config = FutoshikiConfig(board_size=4, difficulty=0, size=5, seed=42)
    dataset = FutoshikiDataset(config)
    for item in dataset:
        # Correct answer should score 1.0
        assert dataset.score_answer(item["answer"], item) == 1.0
        # Wrong answer should score lower
        wrong_answer = item["answer"].replace("1", "2")
        assert dataset.score_answer(wrong_answer, item) < 1.0
        # None or empty answer should score 0.0
        assert dataset.score_answer(None, item) == 0.0
        assert dataset.score_answer("", item) == 0.0
        answer = item["answer"]
        # NOTE(review): this replace() looks like a no-op (space -> space);
        # presumably the replacement string originally contained two spaces
        # and one was lost in transit — confirm, otherwise the 0.9 assertion
        # below is really scoring an unchanged answer.
        white_space_mismatch = answer.replace(" ", " ")
        assert dataset.score_answer(white_space_mismatch, item) == 0.9
        # "anwser" is an existing typo in local names, kept as-is; the
        # surrounding prose is arbitrary noise wrapped around the answer.
        anwser_with_additional_text = "This is an anwser " + answer + "\nwith surrounding text."
        assert 0 < dataset.score_answer(anwser_with_additional_text, item) < 0.9
        # Corrupting digits inside the wrapped answer keeps some credit.
        partially_correct = anwser_with_additional_text.replace("1", "2")
        assert dataset.score_answer(partially_correct, item) > 0.1
        # Reversing the lines destroys the grid ordering entirely.
        bad_answer = "\n".join(anwser_with_additional_text.split("\n")[::-1])
        assert dataset.score_answer(bad_answer, item) < 0.1

View file

@ -122,6 +122,11 @@ def test_nqueens_score_answer():
# Test None answer gets score 0.0
assert dataset.score_answer(None, item) == 0.0
# Test python list representation of board (partial solution)
answer = "[['_', 'Q', '_', '_'], ['_', '_', '_', 'Q'], ['Q', '_', '_', '_'], ['_', '_', 'Q', '_']]"
entry = {"metadata": {"valid_answers": {"_ Q _ _\n_ _ _ Q\nQ _ _ _\n_ _ Q _"}}}
assert dataset.score_answer(answer, entry) == 0.5
def is_valid_solution(board: list[list[str]]) -> bool:
"""Helper function to verify N Queens solution validity"""

View file

@ -0,0 +1,111 @@
"""Tests for Palindrome Partitioning questions generation"""
import json
from reasoning_gym.algorithmic.palindrome_partitioning import (
PalindromePartitioningConfig,
PalindromePartitioningDataset,
)
def test_palindrome_partitioning_dataset_deterministic():
    """Datasets built from the same seed produce identical items."""
    cfg = PalindromePartitioningConfig(seed=42, size=10)
    for first, second in zip(PalindromePartitioningDataset(cfg), PalindromePartitioningDataset(cfg)):
        assert first == second
def test_palindrome_partitioning_dataset_items():
    """Items carry a non-empty string plus at least one valid partitioning."""
    config = PalindromePartitioningConfig(size=10, seed=42)
    dataset = PalindromePartitioningDataset(config)
    for item in dataset:
        # Structural checks.
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        meta = item["metadata"]
        assert "string" in meta
        assert "solution" in meta
        text = meta["string"]
        partitionings = meta["solution"]
        # The source string is never empty.
        assert len(text) > 0
        # Single letters are palindromes, so a partitioning always exists.
        assert len(partitionings) >= 1
        # Each partitioning is non-empty and rebuilds the source string.
        for parts in partitionings:
            assert len(parts) > 0
            assert "".join(parts) == text
def test_palindrome_partitioning_dataset_iteration():
    """Iteration yields exactly `size` items and is repeatable."""
    config = PalindromePartitioningConfig(size=5, seed=42)
    dataset = PalindromePartitioningDataset(config)
    first_pass = list(dataset)
    assert len(first_pass) == config.size
    # A second pass must reproduce the same items.
    assert first_pass == list(dataset)
def test_palindrome_partitioning_answer():
    """Test the _palindrome_partitioning helper on typical and edge inputs.

    Comparisons are direct list equality instead of the previous
    json.dumps round-trip: for nested lists of strings the two checks are
    equivalent, and the serialisation detour only obscured failures.
    """
    config = PalindromePartitioningConfig(seed=42)
    dataset = PalindromePartitioningDataset(config)
    # General case: every palindrome partitioning of "afternoon",
    # in the order the helper produces them.
    expected = [
        ["a", "f", "t", "e", "r", "n", "o", "o", "n"],
        ["a", "f", "t", "e", "r", "n", "oo", "n"],
        ["a", "f", "t", "e", "r", "noon"],
    ]
    assert dataset._palindrome_partitioning("afternoon") == expected
    # Single letter: the only partitioning is the letter itself.
    assert dataset._palindrome_partitioning("a") == [["a"]]
    # Empty string: no partitionings at all.
    assert dataset._palindrome_partitioning("") == []
def test_palindrome_partitioning_score_answer():
    """Scoring is order-insensitive and penalises wrong/absent answers."""
    config = PalindromePartitioningConfig(seed=42)
    dataset = PalindromePartitioningDataset(config)
    item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    # Same partitionings in a different order still score full marks.
    permuted = json.dumps([["n", "o", "o", "n"], ["no", "on"], ["noon"]])
    assert dataset.score_answer(permuted, item) == 1
    # Missing a partitioning drops the score to 0.01.
    incomplete = json.dumps([["n", "o", "o", "n"], ["no", "on"]])
    assert dataset.score_answer(incomplete, item) == 0.01
    # No answer at all scores 0.
    assert dataset.score_answer(None, item) == 0
    # Malformed JSON also scores 0.
    malformed = '["n", "o", "o", "n"], ["no", "on"], ["noon"]'
    assert dataset.score_answer(malformed, item) == 0

View file

@ -1,3 +1,5 @@
import string
import pytest
import sympy as sp
@ -17,7 +19,7 @@ def test_polynomial_config_validation():
PolynomialMultiplicationConfig(min_value=0).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
@ -28,6 +30,17 @@ def test_polynomial_config_validation():
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(variables="").validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(
allow_cross_variable_product=False, allow_multivariate_polynomials=True
).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
def test_polynomial_multiplication_dataset_basic():
"""Test dataset creation and length"""
@ -41,7 +54,9 @@ def test_polynomial_multiplication_dataset_basic():
max_degree=2,
min_polynomials=2,
max_polynomials=3,
single_variable=True,
variables=tuple(string.ascii_lowercase),
allow_cross_variable_product=False,
allow_multivariate_polynomials=False,
seed=42,
size=dataset_size,
)
@ -63,7 +78,9 @@ def test_polynomial_equations_dataset_items():
max_degree=2,
min_polynomials=2,
max_polynomials=5,
single_variable=False,
variables=tuple("xyz"),
allow_cross_variable_product=False,
allow_multivariate_polynomials=False,
size=3,
seed=100,
)
@ -75,7 +92,113 @@ def test_polynomial_equations_dataset_items():
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["single_variable"], bool)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
poly_str = item["metadata"]["polynomial_expr"]
# Ensure it can parse with sympy
sp.sympify(poly_str)
def test_cross_polynomial_equations_dataset_items():
"""Test that generated items have correct structure"""
ds = create_dataset(
"polynomial_multiplication",
min_terms=2,
max_terms=3,
min_value=1,
max_value=5,
min_degree=1,
max_degree=2,
min_polynomials=2,
max_polynomials=5,
variables=tuple("xyz"),
allow_cross_variable_product=True,
allow_multivariate_polynomials=False,
size=3,
seed=100,
)
for item in ds:
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
poly_str = item["metadata"]["polynomial_expr"]
# Ensure it can parse with sympy
sp.sympify(poly_str)
def test_cross_polynomial_equations_dataset_items():
    """Test that generated items have correct structure"""
    # NOTE(review): this function is a byte-for-byte duplicate of the
    # identically named test defined immediately above; at import time the
    # second definition shadows the first, so the test only runs once.
    # Delete one copy (or rename this one if a distinct case was intended).
    ds = create_dataset(
        "polynomial_multiplication",
        min_terms=2,
        max_terms=3,
        min_value=1,
        max_value=5,
        min_degree=1,
        max_degree=2,
        min_polynomials=2,
        max_polynomials=5,
        variables=tuple("xyz"),
        allow_cross_variable_product=True,
        allow_multivariate_polynomials=False,
        size=3,
        seed=100,
    )
    for item in ds:
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item
        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)
        # Check polynomial_expr existence
        poly_str = item["metadata"]["polynomial_expr"]
        # Ensure it can parse with sympy
        sp.sympify(poly_str)
def test_multivariate_polynomial_equations_dataset_items():
"""Test that generated items have correct structure"""
ds = create_dataset(
"polynomial_multiplication",
min_terms=2,
max_terms=3,
min_value=1,
max_value=5,
min_degree=1,
max_degree=2,
min_polynomials=2,
max_polynomials=5,
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=3,
seed=100,
)
for item in ds:
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
poly_str = item["metadata"]["polynomial_expr"]
@ -105,7 +228,9 @@ def test_polynomial_solutions_evaluation():
max_degree=3,
min_polynomials=2,
max_polynomials=5,
single_variable=False,
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=5,
seed=42,
)
@ -125,42 +250,27 @@ def test_score_function():
ds = create_dataset(
"polynomial_multiplication",
min_terms=2,
max_terms=4,
max_terms=3,
min_value=1,
max_value=10,
max_value=3,
min_degree=1,
max_degree=3,
min_polynomials=2,
max_polynomials=5,
single_variable=True,
size=1,
min_polynomials=3,
max_polynomials=3,
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=3,
seed=42,
)
assert ds.score_answer(None, ds[0]) == 0.00
assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]) == 1
assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
assert ds.score_answer("x**4", ds[0]) == 0.05
for item in ds:
poly_str = item["metadata"]["polynomial_expr"]
assert ds.score_answer(poly_str, item) == 0.05
poly_expr = str(sp.expand(poly_str))
assert ds.score_answer(poly_expr, item) == 1.0
def test_multivariate_score_function():
"""Test that solution satisfy the polynomial multiplication."""
ds = create_dataset(
"polynomial_multiplication",
min_terms=2,
max_terms=4,
min_value=1,
max_value=10,
min_degree=1,
max_degree=3,
min_polynomials=2,
max_polynomials=5,
single_variable=False,
size=1,
seed=42,
)
assert ds.score_answer(None, ds[0]) == 0.00
assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]) == 1
assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
assert ds.score_answer("x**4", ds[0]) == 0.05
assert ds.score_answer(None, item) == 0.00
assert ds.score_answer("Not a polynomial", item) == 0.01
assert ds.score_answer("x**4", item) == 0.05

View file

@ -37,6 +37,7 @@ def test_getitem(dataset, config):
assert "metadata" in item
assert item["metadata"]["word_count"] >= config.min_words_in_sentence
assert item["metadata"]["word_count"] <= config.max_words_in_sentence
assert len(item["answer"].split()) == item["metadata"]["word_count"]
def test_key_error_in_getitem(dataset):

View file

@ -15,6 +15,10 @@ def test_spiral_matrix_config_validation():
config = SpiralMatrixConfig(max_n=0) # Zero not allowed
config.validate()
with pytest.raises(AssertionError):
config = SpiralMatrixConfig(max_n=1) # One not allowed
config.validate()
def test_spiral_matrix_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
@ -69,18 +73,26 @@ def test_spiral_matrix_answer():
config = SpiralMatrixConfig(seed=42)
dataset = SpiralMatrixDataset(config)
# One element
matrix = [[0]]
assert dataset._get_spiral(matrix) == [0]
# One row
matrix = [[0, 1, 2]]
assert dataset._get_spiral(matrix) == [0, 1, 2]
# One column
matrix = [[0], [1], [2]]
assert dataset._get_spiral(matrix) == [0, 1, 2]
# 2D grid
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5]
# Answer is identical (up to trimming)
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "\n\n1 2 3 6 9 8 7 4 5\n"
assert dataset.score_answer(answer, entry) == 1.0
# Score answer in list format (partially correct)
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "[1, 2, 3, 6, 9, 8, 7, 4, 5]"
assert dataset.score_answer(answer, entry) == 0.5
# Answer is incorrect
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "1 2 3"
assert dataset.score_answer(answer, entry) == 0.01
# Answer is none
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = None
assert dataset.score_answer(answer, entry) == 0.0

View file

@ -92,3 +92,13 @@ def test_string_insertion_answer():
# No reuse of newly inserted characters
assert dataset._get_answer("ABCDBCD") == "ABCDABCD"
# Test score_answer with correct answer
answer = "AABCDAEEEEEEEBCDEBAAAAA"
entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"}
assert dataset.score_answer(answer, entry) == 1.0
# Test score_answer with correct answer as python list of characters (partial correct)
answer = "['A', 'A', 'B', 'C', 'D', 'A', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'B', 'C', 'D', 'E', 'B', 'A', 'A', 'A', 'A', 'A']"
entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"}
assert dataset.score_answer(answer, entry) == 0.5

View file

@ -116,3 +116,35 @@ def test_word_sorting_dataset_iteration():
# Test multiple iterations yield same items
assert items == list(dataset)
def test_word_sorting_scoring():
    """Scoring tiers: exact order, unordered words, garbage, no answer."""
    config = WordSortingConfig(size=1, seed=42)
    dataset = WordSortingDataset(config)
    item = {
        "metadata": {
            "sorted_words": ["apple", "banana", "cherry"],
        }
    }
    # Perfectly ordered answer scores full marks.
    assert dataset.score_answer("apple, banana, cherry", item) == 1.0
    # Spacing around the commas is irrelevant.
    assert dataset.score_answer("apple,banana, cherry", item) == 1.0
    # Right words, wrong order: partial credit.
    assert dataset.score_answer("banana, cherry, apple", item) == 0.2
    # Unrelated text gets the minimal non-zero score.
    assert dataset.score_answer("gibberish", item) == 0.01
    # No answer scores zero.
    assert dataset.score_answer(None, item) == 0.0