diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index 73172318..e8394434 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -17,6 +17,7 @@ class SimpleEquationsConfig: max_terms: int = 4 # Maximum number of terms min_value: int = 1 # Minimum value for constants max_value: int = 100 # Maximum value for constants + operators: tuple = ("+", "-", "*") # Allowed operators seed: Optional[int] = None size: int = 500 @@ -26,6 +27,8 @@ class SimpleEquationsConfig: assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" assert self.min_value > 0, "min_value must be positive" assert self.max_value >= self.min_value, "max_value must be >= min_value" + assert len(self.operators) > 0, "must specify at least one operator" + assert all(op in ("+", "-", "*") for op in self.operators), "invalid operator specified" class SimpleEquationsDataset(ProceduralDataset): @@ -100,7 +103,7 @@ class SimpleEquationsDataset(ProceduralDataset): # Apply operators between terms expr = terms[0] for i in range(1, num_terms): - op = rng.choice(("+", "-", "*")) + op = rng.choice(self.config.operators) if op == "+": expr = expr + terms[i] elif op == "-": @@ -119,6 +122,7 @@ def simple_equations_dataset( max_terms: int = 5, min_value: int = 1, max_value: int = 100, + operators: tuple = ("+", "-", "*"), seed: Optional[int] = None, size: int = 500, ) -> SimpleEquationsDataset: @@ -128,6 +132,7 @@ def simple_equations_dataset( max_terms=max_terms, min_value=min_value, max_value=max_value, + operators=operators, seed=seed, size=size, ) diff --git a/tests/test_simple_equations.py b/tests/test_simple_equations.py index 4785d2f4..efb32133 100644 --- a/tests/test_simple_equations.py +++ b/tests/test_simple_equations.py @@ -23,6 +23,14 @@ def test_simple_equations_config_validation(): config = SimpleEquationsConfig(min_value=100, max_value=50) # max < min value config.validate() + with pytest.raises(AssertionError): + config = SimpleEquationsConfig(operators=()) # Empty operators + config.validate() + + with pytest.raises(AssertionError): + config = SimpleEquationsConfig(operators=("+", "^")) # Invalid operator + config.validate() + def test_simple_equations_dataset_deterministic(): """Test that dataset generates same items with same seed""" @@ -76,7 +84,13 @@ def test_simple_equations_dataset_iteration(): def test_simple_equations_solution_verification(): """Test that generated equations have correct solutions""" config = SimpleEquationsConfig( - min_terms=2, max_terms=3, min_value=1, max_value=10, size=10, seed=42 # Small values for predictable results + min_terms=2, + max_terms=3, + min_value=1, + max_value=10, # Small values for predictable results + operators=("+", "-"), # Simple operators for easy verification + size=10, + seed=42, ) dataset = SimpleEquationsDataset(config) @@ -94,3 +108,25 @@ def test_simple_equations_solution_verification(): # Replace variable with solution evaluated = eval(left_side.replace(variable, str(solution))) assert evaluated == right_side +def test_simple_equations_operators(): + """Test equation generation with different operator combinations""" + for operators in [ + ("+",), + ("+", "-"), + ("*",), + ("+", "*"), + ("+", "-", "*"), + ]: + config = SimpleEquationsConfig( + operators=operators, + size=5, + seed=42 + ) + dataset = SimpleEquationsDataset(config) + + for item in dataset: + equation = item["metadata"]["equation"] + # Verify only allowed operators are used + for op in "+-*": + if op in equation: + assert op in operators, str(equation)