mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-27 17:23:19 +00:00
feat: Add ProductsDataset with configurable terms and digits
This commit is contained in:
parent
3ead141db5
commit
bdcaeff42a
3 changed files with 258 additions and 0 deletions
|
|
@ -14,6 +14,7 @@ from .lcm import LCMConfig, LCMDataset
|
||||||
from .leg_counting import LegCountingConfig, LegCountingDataset
|
from .leg_counting import LegCountingConfig, LegCountingDataset
|
||||||
from .power_function import PowerFunctionConfig, PowerFunctionDataset
|
from .power_function import PowerFunctionConfig, PowerFunctionDataset
|
||||||
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
||||||
|
from .products import Products, ProductsConfig
|
||||||
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
|
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -35,6 +36,8 @@ __all__ = [
|
||||||
"PowerFunctionDataset",
|
"PowerFunctionDataset",
|
||||||
"PrimeFactorizationConfig",
|
"PrimeFactorizationConfig",
|
||||||
"PrimeFactorizationDataset",
|
"PrimeFactorizationDataset",
|
||||||
|
"Products",
|
||||||
|
"ProductsConfig",
|
||||||
"GSMSymbolicDatasetConfig",
|
"GSMSymbolicDatasetConfig",
|
||||||
"GSMSymbolicDataset",
|
"GSMSymbolicDataset",
|
||||||
"TimeIntervalsConfig",
|
"TimeIntervalsConfig",
|
||||||
|
|
|
||||||
130
reasoning_gym/arithmetic/products.py
Normal file
130
reasoning_gym/arithmetic/products.py
Normal file
|
|
@ -0,0 +1,130 @@
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProductsConfig:
|
||||||
|
"""Configuration for products task generation"""
|
||||||
|
|
||||||
|
min_terms: int = 2
|
||||||
|
max_terms: int = 2
|
||||||
|
min_digits: int = 1
|
||||||
|
max_digits: int = 5
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert self.size > 0, "size must be positive"
|
||||||
|
assert self.min_terms > 0, "min_terms must be positive"
|
||||||
|
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
|
||||||
|
assert self.min_digits > 0, "min_digits must be positive"
|
||||||
|
assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"
|
||||||
|
|
||||||
|
|
||||||
|
class Products(ProceduralDataset):
|
||||||
|
"""Generates multiplication tasks with configurable number of terms"""
|
||||||
|
|
||||||
|
def __init__(self, config: ProductsConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""Generate a single multiplication task
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: Index of the item to generate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys:
|
||||||
|
- question: str, the formatted multiplication expression
|
||||||
|
- answer: str, the ground truth result
|
||||||
|
- metadata: dict with generation parameters
|
||||||
|
"""
|
||||||
|
# Create deterministic RNG from base seed and idx
|
||||||
|
item_rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
|
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
|
||||||
|
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
|
||||||
|
|
||||||
|
# Calculate value ranges based on number of digits
|
||||||
|
min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit
|
||||||
|
max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits
|
||||||
|
|
||||||
|
expression, result = self._generate_task(item_rng, num_terms, min_value, max_value)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": f"{expression} =",
|
||||||
|
"answer": str(result),
|
||||||
|
"metadata": {
|
||||||
|
"difficulty": {
|
||||||
|
"num_terms": num_terms,
|
||||||
|
"num_digits": num_digits,
|
||||||
|
},
|
||||||
|
"expression": expression,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
|
||||||
|
"""Generate a multiplication task
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rng: Random number generator
|
||||||
|
num_terms: Number of terms in the expression
|
||||||
|
min_value: Minimum value for generated numbers
|
||||||
|
max_value: Maximum value for generated numbers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (expression string, result integer)
|
||||||
|
"""
|
||||||
|
# Generate random numbers within the specified range
|
||||||
|
constants = [rng.randint(min_value, max_value) for _ in range(num_terms)]
|
||||||
|
|
||||||
|
# Build expression and compute result
|
||||||
|
expression_parts = []
|
||||||
|
result = constants[0]
|
||||||
|
|
||||||
|
expression_parts.append(str(constants[0]))
|
||||||
|
for i in range(1, len(constants)):
|
||||||
|
expression_parts.append("*")
|
||||||
|
expression_parts.append(str(constants[i]))
|
||||||
|
result *= constants[i]
|
||||||
|
|
||||||
|
expression = " ".join(expression_parts)
|
||||||
|
return expression, result
|
||||||
|
|
||||||
|
|
||||||
|
class ProductsCurriculum(BaseCurriculum):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(ProductsCurriculum.__name__, ProductsConfig)
|
||||||
|
|
||||||
|
# Define attributes
|
||||||
|
self._define_attributes(
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="num_terms",
|
||||||
|
levels=[2, 3, 4, 5],
|
||||||
|
default_level=0, # Start with 2 terms
|
||||||
|
description="Maximum number of terms in the expression",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=2, # Ensure at least 2 terms
|
||||||
|
lower_field_name="min_terms",
|
||||||
|
upper_field_name="max_terms",
|
||||||
|
),
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="num_digits",
|
||||||
|
levels=[1, 2, 3, 4],
|
||||||
|
default_level=0, # Start with 1-digit numbers
|
||||||
|
description="Number of digits in each operand",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=1, # Ensure numbers are at least 1 digit
|
||||||
|
lower_field_name="min_digits",
|
||||||
|
upper_field_name="max_digits",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the dataset
|
||||||
|
register_dataset("products", Products, ProductsConfig)
|
||||||
125
tests/test_products.py
Normal file
125
tests/test_products.py
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic import Products, ProductsConfig
|
||||||
|
from reasoning_gym.arithmetic.products import ProductsCurriculum
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_config_validation():
|
||||||
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = ProductsConfig(min_terms=0)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = ProductsConfig(min_terms=3, max_terms=2)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_deterministic():
|
||||||
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
config = ProductsConfig(seed=42, size=10)
|
||||||
|
dataset1 = Products(config)
|
||||||
|
dataset2 = Products(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = ProductsConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
|
||||||
|
dataset = Products(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Verify only * is used
|
||||||
|
expression = item["metadata"]["expression"]
|
||||||
|
assert all(op in ["*", " "] or op.isdigit() for op in expression)
|
||||||
|
|
||||||
|
# Verify the answer matches the expression
|
||||||
|
answer = eval(expression) # Safe here as we control the expression
|
||||||
|
assert str(answer) == item["answer"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_number_ranges():
|
||||||
|
"""Test that generated numbers respect digit constraints"""
|
||||||
|
# Test 3-digit numbers
|
||||||
|
config = ProductsConfig(
|
||||||
|
min_terms=2,
|
||||||
|
max_terms=2, # Fix to 2 terms for easier testing
|
||||||
|
min_digits=3, # Should generate numbers >= 100
|
||||||
|
max_digits=3, # Should generate numbers <= 999
|
||||||
|
size=50,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
dataset = Products(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
expression = item["metadata"]["expression"]
|
||||||
|
numbers = [int(n) for n in expression.split() if n.isdigit()]
|
||||||
|
for num in numbers:
|
||||||
|
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
||||||
|
|
||||||
|
# Test 1-digit numbers
|
||||||
|
config = ProductsConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
|
||||||
|
dataset = Products(config)
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
expression = item["metadata"]["expression"]
|
||||||
|
numbers = [int(n) for n in expression.split() if n.isdigit()]
|
||||||
|
for num in numbers:
|
||||||
|
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_iteration():
|
||||||
|
"""Test that iteration respects dataset size"""
|
||||||
|
config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
|
||||||
|
dataset = Products(config)
|
||||||
|
|
||||||
|
# Test manual iteration
|
||||||
|
items = []
|
||||||
|
for item in dataset:
|
||||||
|
items.append(item)
|
||||||
|
assert len(items) == config.size, "Iterator should yield exactly size items"
|
||||||
|
|
||||||
|
# Test list conversion
|
||||||
|
items = list(dataset)
|
||||||
|
assert len(items) == config.size, "Iterator should yield exactly size items"
|
||||||
|
|
||||||
|
# Test multiple iterations
|
||||||
|
first_items = list(dataset)
|
||||||
|
second_items = list(dataset)
|
||||||
|
assert first_items == second_items, "Multiple iterations should yield same items"
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_curriculum():
|
||||||
|
curriculum = ProductsCurriculum()
|
||||||
|
|
||||||
|
base_value = {"size": 150, "seed": 1}
|
||||||
|
|
||||||
|
base_cfg: ProductsConfig = curriculum.generate_configuration(base_value)
|
||||||
|
assert base_cfg.seed == 1
|
||||||
|
assert base_cfg.size == 150
|
||||||
|
assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1
|
||||||
|
assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2
|
||||||
|
|
||||||
|
# test incrementing attribute levels for num_terms & num_digits attributes
|
||||||
|
curriculum.increment_attr_level("num_terms")
|
||||||
|
curriculum.increment_attr_level("num_digits")
|
||||||
|
|
||||||
|
increased_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2
|
||||||
|
assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3
|
||||||
|
|
||||||
|
# test decrementing attribute level for num_digits again
|
||||||
|
curriculum.decrement_attr_level("num_digits")
|
||||||
|
|
||||||
|
partially_decreased_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1
|
||||||
|
assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3
|
||||||
Loading…
Add table
Add a link
Reference in a new issue