feat: Add special case handling for min_digits=1 in ChainSum generation

This commit is contained in:
Andreas Koepf (aider) 2025-01-23 12:07:56 +01:00
parent 4777e6b435
commit 516d4d20d4
2 changed files with 35 additions and 1 deletions

View file

@ -53,6 +53,7 @@ def test_chain_sum_items():
def test_chain_sum_number_ranges():
"""Test that generated numbers respect digit constraints"""
# Test 3-digit numbers
config = ChainSumConfig(
min_terms=2,
max_terms=2, # Fix to 2 terms for easier testing
@ -63,6 +64,27 @@ def test_chain_sum_number_ranges():
)
dataset = ChainSum(config)
for i in range(len(dataset)):
item = dataset[i]
expression = item["metadata"]["expression"]
numbers = [int(n) for n in expression.split() if n.isdigit()]
for num in numbers:
if config.allow_negation:
assert -999 <= num <= 999, f"Number {num} outside valid range for 3 digits"
else:
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
# Test 1-digit numbers
config = ChainSumConfig(
min_terms=2,
max_terms=2,
min_digits=1,
max_digits=1,
size=50,
seed=42
)
dataset = ChainSum(config)
for i in range(len(dataset)):
item = dataset[i]
expression = item["metadata"]["expression"]
@ -76,6 +98,18 @@ def test_chain_sum_number_ranges():
assert -999 <= num <= 999, f"Number {num} outside valid range for 3 digits"
else:
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
# Test 1-digit numbers
dataset = ChainSum(config)
for i in range(len(dataset)):
item = dataset[i]
expression = item["metadata"]["expression"]
numbers = [int(n) for n in expression.split() if n.isdigit()]
for num in numbers:
if config.allow_negation:
assert -9 <= num <= 9, f"Number {num} outside valid range for 1 digit"
else:
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
def test_chain_sum_negation():
"""Test that allow_negation controls number ranges"""