fix: Prevent infinite loop in arithmetic dataset space generation

This commit is contained in:
Andreas Koepf (aider) 2025-01-23 11:34:18 +01:00
parent 40596262e1
commit fbba398c91
2 changed files with 9 additions and 5 deletions

View file

@ -14,7 +14,7 @@ class ArithmeticDatasetConfig:
allow_parentheses: bool = True allow_parentheses: bool = True
allow_negation: bool = True allow_negation: bool = True
seed: Optional[int] = None seed: Optional[int] = None
size: int = 10000 # Virtual dataset size size: int = 500 # Virtual dataset size
format_style: Literal["simple", "natural"] = "simple" format_style: Literal["simple", "natural"] = "simple"
def validate(self): def validate(self):
@ -102,14 +102,14 @@ class ArithmeticDataset:
add_terms(num_terms) add_terms(num_terms)
# Add random spaces # Add at most one random space between parts
space_parts = [] space_parts = []
for p in parts: for p in parts:
while rng.random() < 0.15: if rng.random() < 0.15:
space_parts.append(" ") space_parts.append(" ")
space_parts.append(p) space_parts.append(p)
expression = " ".join(space_parts) expression = " ".join(space_parts).strip()
result = eval(expression) # Note: eval is safe here as we control the input result = eval(expression) # Note: eval is safe here as we control the input
return expression, result return expression, result

View file

@ -58,7 +58,11 @@ def test_arithmetic_dataset_format_styles():
config = ArithmeticDatasetConfig( config = ArithmeticDatasetConfig(
size=10, size=10,
seed=42, seed=42,
format_style="simple" format_style="simple",
min_terms=2,
max_terms=3, # Keep expressions simple for testing
min_digits=1,
max_digits=2
) )
dataset = ArithmeticDataset(config) dataset = ArithmeticDataset(config)
assert all(item["question"].endswith("=") for item in dataset) assert all(item["question"].endswith("=") for item in dataset)