Rename ArithmeticDataset to BasicArithmeticDataset

This commit is contained in:
Andreas Koepf 2025-01-24 10:31:26 +01:00
parent 44fd0d4a25
commit 98988c8481
7 changed files with 27 additions and 27 deletions

View file

@ -6,7 +6,7 @@ Arithmetic tasks for training reasoning capabilities:
- Leg counting
"""
from .basic_arithmetic import BasicArithmeticDataset, ArithmeticDatasetConfig, basic_arithmetic_dataset
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset
from .gcd import GCDConfig, GCDDataset, gcd_dataset
@ -16,7 +16,7 @@ from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDat
__all__ = [
"BasicArithmeticDataset",
"ArithmeticDatasetConfig",
"BasicArithmeticDatasetConfig",
"basic_arithmetic_dataset",
"ChainSum",
"ChainSumConfig",

View file

@ -5,7 +5,7 @@ from ..dataset import ProceduralDataset
@dataclass
class ArithmeticDatasetConfig:
class BasicArithmeticDatasetConfig:
"""Configuration for arithmetic dataset generation"""
min_terms: int = 2
@ -34,7 +34,7 @@ class ArithmeticDatasetConfig:
class BasicArithmeticDataset(ProceduralDataset):
"""Dataset that generates basic arithmetic tasks with configurable complexity"""
def __init__(self, config: ArithmeticDatasetConfig):
def __init__(self, config: BasicArithmeticDatasetConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
@ -183,9 +183,9 @@ def basic_arithmetic_dataset(
format_style: Style of question formatting ("simple" or "natural")
Returns:
ArithmeticDataset: Configured dataset instance
BasicArithmeticDataset: Configured dataset instance
"""
config = ArithmeticDatasetConfig(
config = BasicArithmeticDatasetConfig(
min_terms=min_terms,
max_terms=max_terms,
min_digits=min_digits,

View file

@ -19,7 +19,7 @@ class FractionSimplificationConfig:
def validate(self):
"""Validate configuration parameters"""
assert self.min_value >= 0, "min_value must be positive"
assert self.min_value > 0, "min_value must be positive"
assert self.max_value > self.min_value, "max_value must be > min_value"
assert self.min_factor >= 1, "min_factor must be at least 1"
assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"