diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index f6c27256..ba0c4c01 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -17,7 +17,6 @@ class SimpleEquationsConfig: max_terms: int = 4 # Maximum number of terms min_value: int = 1 # Minimum value for constants max_value: int = 100 # Maximum value for constants - operators: tuple = ("+", "-", "*") # Allowed operators seed: Optional[int] = None size: int = 500 @@ -27,7 +26,6 @@ class SimpleEquationsConfig: assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" assert self.min_value > 0, "min_value must be positive" assert self.max_value >= self.min_value, "max_value must be >= min_value" - assert len(self.operators) > 0, "must specify at least one operator" class SimpleEquationsDataset(ProceduralDataset): @@ -97,16 +95,13 @@ class SimpleEquationsDataset(ProceduralDataset): # Replace one random term with the variable term var_pos = rng.randint(0, num_terms - 1) - if "*" in self.config.operators: - coef = rng.randint(self.config.min_value, self.config.max_value) - terms[var_pos] = coef * x - else: - terms[var_pos] = x + coef = rng.randint(self.config.min_value, self.config.max_value) + terms[var_pos] = coef * x # Apply operators between terms expr = terms[0] for i in range(1, num_terms): - op = rng.choice(self.config.operators) + op = rng.choice(("+", "-", "*")) if op == "+": expr = expr + terms[i] elif op == "-": @@ -135,7 +130,6 @@ def simple_equations_dataset( max_terms: int = 5, min_value: int = 1, max_value: int = 100, - operators: tuple = ("+", "-", "*"), seed: Optional[int] = None, size: int = 500, ) -> SimpleEquationsDataset: @@ -145,7 +139,6 @@ def simple_equations_dataset( max_terms=max_terms, min_value=min_value, max_value=max_value, - operators=operators, seed=seed, size=size, ) diff --git a/tests/test_simple_equations.py b/tests/test_simple_equations.py index d4ba8740..f48b5511 100644 --- a/tests/test_simple_equations.py +++ b/tests/test_simple_equations.py @@ -23,9 +23,6 @@ def test_simple_equations_config_validation(): config = SimpleEquationsConfig(min_value=100, max_value=50) # max < min value config.validate() - with pytest.raises(AssertionError): - config = SimpleEquationsConfig(operators=()) # Empty operators - config.validate() def test_simple_equations_dataset_deterministic(): @@ -112,25 +109,3 @@ def test_simple_equations_solution_verification(): assert evaluated == right_side -def test_simple_equations_operators(): - """Test equation generation with different operator combinations""" - for operators in [ - ("+",), - ("+", "-"), - ("*",), - ("+", "*"), - ("+", "-", "*"), - ]: - config = SimpleEquationsConfig( - operators=operators, - size=5, - seed=42 - ) - dataset = SimpleEquationsDataset(config) - - for item in dataset: - equation = item["metadata"]["equation"] - # Verify only allowed operators are used - for op in "+-*": - if op in equation: - assert op in operators, str(equation)