mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
fix: Prevent infinite loop in arithmetic dataset space generation
This commit is contained in:
parent
40596262e1
commit
fbba398c91
2 changed files with 9 additions and 5 deletions
|
|
@ -14,7 +14,7 @@ class ArithmeticDatasetConfig:
|
||||||
allow_parentheses: bool = True
|
allow_parentheses: bool = True
|
||||||
allow_negation: bool = True
|
allow_negation: bool = True
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 10000 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
format_style: Literal["simple", "natural"] = "simple"
|
format_style: Literal["simple", "natural"] = "simple"
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
|
|
@ -102,14 +102,14 @@ class ArithmeticDataset:
|
||||||
|
|
||||||
add_terms(num_terms)
|
add_terms(num_terms)
|
||||||
|
|
||||||
# Add random spaces
|
# Add at most one random space between parts
|
||||||
space_parts = []
|
space_parts = []
|
||||||
for p in parts:
|
for p in parts:
|
||||||
while rng.random() < 0.15:
|
if rng.random() < 0.15:
|
||||||
space_parts.append(" ")
|
space_parts.append(" ")
|
||||||
space_parts.append(p)
|
space_parts.append(p)
|
||||||
|
|
||||||
expression = " ".join(space_parts)
|
expression = " ".join(space_parts).strip()
|
||||||
result = eval(expression) # Note: eval is safe here as we control the input
|
result = eval(expression) # Note: eval is safe here as we control the input
|
||||||
|
|
||||||
return expression, result
|
return expression, result
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,11 @@ def test_arithmetic_dataset_format_styles():
|
||||||
config = ArithmeticDatasetConfig(
|
config = ArithmeticDatasetConfig(
|
||||||
size=10,
|
size=10,
|
||||||
seed=42,
|
seed=42,
|
||||||
format_style="simple"
|
format_style="simple",
|
||||||
|
min_terms=2,
|
||||||
|
max_terms=3, # Keep expressions simple for testing
|
||||||
|
min_digits=1,
|
||||||
|
max_digits=2
|
||||||
)
|
)
|
||||||
dataset = ArithmeticDataset(config)
|
dataset = ArithmeticDataset(config)
|
||||||
assert all(item["question"].endswith("=") for item in dataset)
|
assert all(item["question"].endswith("=") for item in dataset)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue