mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
fix: Remove operators configuration from simple equations dataset
This commit is contained in:
parent
940b810a43
commit
cf2434b3aa
2 changed files with 3 additions and 35 deletions
|
|
@ -17,7 +17,6 @@ class SimpleEquationsConfig:
|
||||||
max_terms: int = 4 # Maximum number of terms
|
max_terms: int = 4 # Maximum number of terms
|
||||||
min_value: int = 1 # Minimum value for constants
|
min_value: int = 1 # Minimum value for constants
|
||||||
max_value: int = 100 # Maximum value for constants
|
max_value: int = 100 # Maximum value for constants
|
||||||
operators: tuple = ("+", "-", "*") # Allowed operators
|
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
||||||
|
|
@ -27,7 +26,6 @@ class SimpleEquationsConfig:
|
||||||
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
|
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
|
||||||
assert self.min_value > 0, "min_value must be positive"
|
assert self.min_value > 0, "min_value must be positive"
|
||||||
assert self.max_value >= self.min_value, "max_value must be >= min_value"
|
assert self.max_value >= self.min_value, "max_value must be >= min_value"
|
||||||
assert len(self.operators) > 0, "must specify at least one operator"
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleEquationsDataset(ProceduralDataset):
|
class SimpleEquationsDataset(ProceduralDataset):
|
||||||
|
|
@ -97,16 +95,13 @@ class SimpleEquationsDataset(ProceduralDataset):
|
||||||
|
|
||||||
# Replace one random term with the variable term
|
# Replace one random term with the variable term
|
||||||
var_pos = rng.randint(0, num_terms - 1)
|
var_pos = rng.randint(0, num_terms - 1)
|
||||||
if "*" in self.config.operators:
|
|
||||||
coef = rng.randint(self.config.min_value, self.config.max_value)
|
coef = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
terms[var_pos] = coef * x
|
terms[var_pos] = coef * x
|
||||||
else:
|
|
||||||
terms[var_pos] = x
|
|
||||||
|
|
||||||
# Apply operators between terms
|
# Apply operators between terms
|
||||||
expr = terms[0]
|
expr = terms[0]
|
||||||
for i in range(1, num_terms):
|
for i in range(1, num_terms):
|
||||||
op = rng.choice(self.config.operators)
|
op = rng.choice(("+", "-", "*"))
|
||||||
if op == "+":
|
if op == "+":
|
||||||
expr = expr + terms[i]
|
expr = expr + terms[i]
|
||||||
elif op == "-":
|
elif op == "-":
|
||||||
|
|
@ -135,7 +130,6 @@ def simple_equations_dataset(
|
||||||
max_terms: int = 5,
|
max_terms: int = 5,
|
||||||
min_value: int = 1,
|
min_value: int = 1,
|
||||||
max_value: int = 100,
|
max_value: int = 100,
|
||||||
operators: tuple = ("+", "-", "*"),
|
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
size: int = 500,
|
size: int = 500,
|
||||||
) -> SimpleEquationsDataset:
|
) -> SimpleEquationsDataset:
|
||||||
|
|
@ -145,7 +139,6 @@ def simple_equations_dataset(
|
||||||
max_terms=max_terms,
|
max_terms=max_terms,
|
||||||
min_value=min_value,
|
min_value=min_value,
|
||||||
max_value=max_value,
|
max_value=max_value,
|
||||||
operators=operators,
|
|
||||||
seed=seed,
|
seed=seed,
|
||||||
size=size,
|
size=size,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,6 @@ def test_simple_equations_config_validation():
|
||||||
config = SimpleEquationsConfig(min_value=100, max_value=50) # max < min value
|
config = SimpleEquationsConfig(min_value=100, max_value=50) # max < min value
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = SimpleEquationsConfig(operators=()) # Empty operators
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
|
|
||||||
def test_simple_equations_dataset_deterministic():
|
def test_simple_equations_dataset_deterministic():
|
||||||
|
|
@ -112,25 +109,3 @@ def test_simple_equations_solution_verification():
|
||||||
assert evaluated == right_side
|
assert evaluated == right_side
|
||||||
|
|
||||||
|
|
||||||
def test_simple_equations_operators():
|
|
||||||
"""Test equation generation with different operator combinations"""
|
|
||||||
for operators in [
|
|
||||||
("+",),
|
|
||||||
("+", "-"),
|
|
||||||
("*",),
|
|
||||||
("+", "*"),
|
|
||||||
("+", "-", "*"),
|
|
||||||
]:
|
|
||||||
config = SimpleEquationsConfig(
|
|
||||||
operators=operators,
|
|
||||||
size=5,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = SimpleEquationsDataset(config)
|
|
||||||
|
|
||||||
for item in dataset:
|
|
||||||
equation = item["metadata"]["equation"]
|
|
||||||
# Verify only allowed operators are used
|
|
||||||
for op in "+-*":
|
|
||||||
if op in equation:
|
|
||||||
assert op in operators, str(equation)
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue