mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-29 17:35:16 +00:00
feat: Add special case handling for min_digits=1 in ChainSum generation
This commit is contained in:
parent
4777e6b435
commit
516d4d20d4
2 changed files with 35 additions and 1 deletions
|
|
@ -53,6 +53,7 @@ def test_chain_sum_items():
|
|||
|
||||
def test_chain_sum_number_ranges():
|
||||
"""Test that generated numbers respect digit constraints"""
|
||||
# Test 3-digit numbers
|
||||
config = ChainSumConfig(
|
||||
min_terms=2,
|
||||
max_terms=2, # Fix to 2 terms for easier testing
|
||||
|
|
@ -63,6 +64,27 @@ def test_chain_sum_number_ranges():
|
|||
)
|
||||
dataset = ChainSum(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
expression = item["metadata"]["expression"]
|
||||
numbers = [int(n) for n in expression.split() if n.isdigit()]
|
||||
for num in numbers:
|
||||
if config.allow_negation:
|
||||
assert -999 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
||||
else:
|
||||
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
||||
|
||||
# Test 1-digit numbers
|
||||
config = ChainSumConfig(
|
||||
min_terms=2,
|
||||
max_terms=2,
|
||||
min_digits=1,
|
||||
max_digits=1,
|
||||
size=50,
|
||||
seed=42
|
||||
)
|
||||
dataset = ChainSum(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
expression = item["metadata"]["expression"]
|
||||
|
|
@ -76,6 +98,18 @@ def test_chain_sum_number_ranges():
|
|||
assert -999 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
||||
else:
|
||||
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
||||
|
||||
# Test 1-digit numbers
|
||||
dataset = ChainSum(config)
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
expression = item["metadata"]["expression"]
|
||||
numbers = [int(n) for n in expression.split() if n.isdigit()]
|
||||
for num in numbers:
|
||||
if config.allow_negation:
|
||||
assert -9 <= num <= 9, f"Number {num} outside valid range for 1 digit"
|
||||
else:
|
||||
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
|
||||
|
||||
def test_chain_sum_negation():
|
||||
"""Test that allow_negation controls number ranges"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue