diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 18519c93..e9a02d66 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -18,6 +18,9 @@ class ChainSumConfig: size: int = 500 def validate(self) -> None: + """Validate configuration parameters""" + assert self.difficulty > 0, "difficulty must be positive" + assert self.size > 0, "size must be positive" """Validate configuration parameters""" assert self.min_terms > 0, "min_terms must be positive" assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" diff --git a/tests/test_quantum_lock.py b/tests/test_quantum_lock.py index 9c3774d5..a004afc9 100644 --- a/tests/test_quantum_lock.py +++ b/tests/test_quantum_lock.py @@ -1,13 +1,31 @@ -from magiccube.cube import Cube +import pytest from reasoning_gym.graphs.quantum_lock import QuantumLockConfig, QuantumLockDataset +def test_quantumlock_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = QuantumLockConfig(difficulty=-1) + config.validate() + + with pytest.raises(AssertionError): + config = QuantumLockConfig(size=0) + config.validate() + + +def test_quantumlock_deterministic(): + """Test that dataset generates same items with same seed""" + config = QuantumLockConfig(seed=42, size=10) + dataset1 = QuantumLockDataset(config) + dataset2 = QuantumLockDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + def test_quantumlock_items(): """Test basic properties and solution of generated items""" - config = QuantumLockConfig( - difficulty=10, - size=25, - ) + config = QuantumLockConfig(difficulty=10, size=25) dataset = QuantumLockDataset(config) for item in dataset: @@ -19,6 +37,82 @@ def test_quantumlock_items(): # Check metadata contains required fields assert "solution_path" in item["metadata"] assert "difficulty" in item["metadata"] + assert "buttons" in item["metadata"] + assert "initial_state" in item["metadata"] + assert "target_value" in item["metadata"] + # Verify solution works assert dataset.score_answer(answer=item["metadata"]["solution_path"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 + + +def test_quantumlock_button_states(): + """Test button state transitions and validity""" + config = QuantumLockConfig(difficulty=5, size=10) + dataset = QuantumLockDataset(config) + + for item in dataset: + buttons = item["metadata"]["buttons"] + + # Check button properties + for btn in buttons: + assert "name" in btn + assert "type" in btn + assert "value" in btn + assert "active_state" in btn + + # Verify button name format + assert btn["name"] in ["A", "B", "C"] + + # Verify operation type + assert btn["type"] in ["add", "subtract", "multiply"] + + # Verify state constraints + assert btn["active_state"] in ["red", "green", "any"] + + +def test_quantumlock_solution_validation(): + """Test solution validation and simulation""" + config = QuantumLockConfig(difficulty=5, size=10) + dataset = QuantumLockDataset(config) + + for item in dataset: + solution = item["metadata"]["solution_path"] + target = item["metadata"]["target_value"] + + # Test solution simulation + final_value = dataset.simulate_sequence( + item["metadata"], + solution + ) + assert final_value == target + + # Test invalid button sequences + assert dataset.simulate_sequence( + item["metadata"], + ["X", "Y", "Z"] # Invalid buttons + ) == item["metadata"]["initial_value"] + + +def test_quantumlock_scoring(): + """Test score calculation for various answers""" + config = QuantumLockConfig(difficulty=5, size=10) + dataset = QuantumLockDataset(config) + + for item in dataset: + solution = item["metadata"]["solution_path"] + + # Test correct solution + assert dataset.score_answer(solution, item) == 1.0 + + # Test empty/None answers + assert dataset.score_answer(None, item) == 0.0 + assert dataset.score_answer("", item) == 0.1 + + # Test invalid buttons + assert dataset.score_answer("XYZ", item) == 0.1 + + # Test case insensitivity + if solution: + lower_solution = "".join(solution).lower() + assert dataset.score_answer(lower_solution, item) == 1.0