diff --git a/README.md b/README.md index 4b91500c..8b5ef4f6 100644 --- a/README.md +++ b/README.md @@ -37,9 +37,9 @@ The goal is to generate virtually infinite data with adjustable complexity. #### Basic Arithmetic Generates arithmetic problems with configurable complexity: ```python -from reasoning_gym.arithmetic import ArithmeticDataset, ArithmeticDatasetConfig +from reasoning_gym.arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig -config = ArithmeticDatasetConfig( +config = BasicArithmeticDatasetConfig( min_terms=2, # Minimum number of terms in expression max_terms=4, # Maximum number of terms min_digits=1, # Minimum digits per number @@ -49,7 +49,7 @@ config = ArithmeticDatasetConfig( seed=42 # For reproducibility ) -dataset = ArithmeticDataset(config) +dataset = BasicArithmeticDataset(config) for item in dataset: print(item) ``` diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 83b9ac0b..5264a9b0 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -6,7 +6,7 @@ Algorithmic tasks for training reasoning capabilities: - Pattern matching """ -from reasoning_gym.arithmetic.basic_arithmetic import arithmetic_dataset +from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset @@ -15,7 +15,7 @@ from .number_sorting import NumberSortingConfig, NumberSortingDataset, number_so from .word_reversal import WordReversalConfig, WordReversalDataset, word_reversal_dataset __all__ = [ - "arithmetic_dataset", + "basic_arithmetic_dataset", "BaseConversionConfig", "BaseConversionDataset", "base_conversion_dataset", diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index c7953db8..cf93144a 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,7 +6,7 @@ Arithmetic tasks for training reasoning capabilities: - Leg counting """ -from .basic_arithmetic import BasicArithmeticDataset, ArithmeticDatasetConfig, basic_arithmetic_dataset +from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset from .gcd import GCDConfig, GCDDataset, gcd_dataset @@ -16,7 +16,7 @@ from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDat __all__ = [ "BasicArithmeticDataset", - "ArithmeticDatasetConfig", + "BasicArithmeticDatasetConfig", "basic_arithmetic_dataset", "ChainSum", "ChainSumConfig", diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 80a45edb..ae9750f6 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -5,7 +5,7 @@ from ..dataset import ProceduralDataset @dataclass -class ArithmeticDatasetConfig: +class BasicArithmeticDatasetConfig: """Configuration for arithmetic dataset generation""" min_terms: int = 2 @@ -34,7 +34,7 @@ class ArithmeticDatasetConfig: class BasicArithmeticDataset(ProceduralDataset): """Dataset that generates basic arithmetic tasks with configurable complexity""" - def __init__(self, config: ArithmeticDatasetConfig): + def __init__(self, config: BasicArithmeticDatasetConfig): self.config = config self.config.validate() super().__init__(seed=config.seed, size=config.size) @@ -183,9 +183,9 @@ def basic_arithmetic_dataset( format_style: Style of question formatting ("simple" or "natural") Returns: - ArithmeticDataset: Configured dataset instance + BasicArithmeticDataset: Configured dataset instance """ - config = ArithmeticDatasetConfig( + config = BasicArithmeticDatasetConfig( min_terms=min_terms, max_terms=max_terms, min_digits=min_digits, diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index 87541e8b..2aa7bd19 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -19,7 +19,7 @@ class FractionSimplificationConfig: def validate(self): """Validate configuration parameters""" - assert self.min_value >= 0, "min_value must be positive" + assert self.min_value > 0, "min_value must be positive" assert self.max_value > self.min_value, "max_value must be > min_value" assert self.min_factor >= 1, "min_factor must be at least 1" assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor" diff --git a/tests/test_arithmetic.py b/tests/test_arithmetic.py index d95e88db..87d5f1f5 100644 --- a/tests/test_arithmetic.py +++ b/tests/test_arithmetic.py @@ -1,28 +1,28 @@ import pytest from random import Random -from reasoning_gym.arithmetic.basic_arithmetic import ArithmeticDataset, ArithmeticDatasetConfig +from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig def test_arithmetic_dataset_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = ArithmeticDatasetConfig(min_terms=0) + config = BasicArithmeticDatasetConfig(min_terms=0) config.validate() with pytest.raises(AssertionError): - config = ArithmeticDatasetConfig(min_terms=3, max_terms=2) + config = BasicArithmeticDatasetConfig(min_terms=3, max_terms=2) config.validate() with pytest.raises(AssertionError): - config = ArithmeticDatasetConfig(operators=["^"]) # Invalid operator + config = BasicArithmeticDatasetConfig(operators=["^"]) # Invalid operator config.validate() def test_arithmetic_dataset_deterministic(): """Test that dataset generates same items with same seed""" - config = ArithmeticDatasetConfig(seed=42, size=10) - dataset1 = ArithmeticDataset(config) - dataset2 = ArithmeticDataset(config) + config = BasicArithmeticDatasetConfig(seed=42, size=10) + dataset1 = BasicArithmeticDataset(config) + dataset2 = BasicArithmeticDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -30,7 +30,7 @@ def test_arithmetic_dataset_deterministic(): def test_arithmetic_dataset_items(): """Test basic properties of generated items""" - config = ArithmeticDatasetConfig( + config = BasicArithmeticDatasetConfig( min_terms=2, max_terms=4, min_digits=1, @@ -38,7 +38,7 @@ def test_arithmetic_dataset_items(): size=100, seed=42 ) - dataset = ArithmeticDataset(config) + dataset = BasicArithmeticDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -55,7 +55,7 @@ def test_arithmetic_dataset_items(): def test_arithmetic_dataset_format_styles(): """Test different question format styles""" - config = ArithmeticDatasetConfig( + config = BasicArithmeticDatasetConfig( size=10, seed=42, format_style="simple", @@ -64,23 +64,23 @@ def test_arithmetic_dataset_format_styles(): min_digits=1, max_digits=2 ) - dataset = ArithmeticDataset(config) + dataset = BasicArithmeticDataset(config) assert all(item["question"].endswith("=") for item in dataset) config.format_style = "natural" - dataset = ArithmeticDataset(config) + dataset = BasicArithmeticDataset(config) assert all("=" not in item["question"] for item in dataset) def test_arithmetic_dataset_iteration(): """Test that iteration respects dataset size""" - config = ArithmeticDatasetConfig( + config = BasicArithmeticDatasetConfig( min_terms=2, max_terms=2, size=5, # Small size for testing seed=42 ) - dataset = ArithmeticDataset(config) + dataset = BasicArithmeticDataset(config) # Test manual iteration items = [] diff --git a/tests/test_fraction_simplification.py b/tests/test_fraction_simplification.py index ee77e74e..ae674329 100644 --- a/tests/test_fraction_simplification.py +++ b/tests/test_fraction_simplification.py @@ -14,7 +14,7 @@ def test_fraction_config_validation(): config.validate() with pytest.raises(AssertionError): - config = FractionSimplificationConfig(min_factor=1) # Should be >= 2 + config = FractionSimplificationConfig(min_factor=0) # Should be >= 1 config.validate() with pytest.raises(AssertionError):