add ProductsDataset (multiplication tasks)

This commit is contained in:
Andreas Koepf 2025-02-13 17:59:02 +01:00
parent 17485fad67
commit 1996ffa6d8
10 changed files with 56 additions and 56 deletions

View file

@@ -26,7 +26,7 @@ class ProductsConfig:
assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"
class Products(ProceduralDataset):
class ProductsDataset(ProceduralDataset):
"""Generates multiplication tasks with configurable number of terms"""
def __init__(self, config: ProductsConfig):
@@ -45,16 +45,16 @@ class Products(ProceduralDataset):
- metadata: dict with generation parameters
"""
# Create deterministic RNG from base seed and idx
item_rng = random.Random(self.seed + idx)
rng = random.Random(self.seed + idx)
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
num_digits = rng.randint(self.config.min_digits, self.config.max_digits)
# Calculate value ranges based on number of digits
min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit
max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits
expression, result = self._generate_task(item_rng, num_terms, min_value, max_value)
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
return {
"question": f"{expression} =",
@@ -127,4 +127,4 @@ class ProductsCurriculum(BaseCurriculum):
# Register the dataset
register_dataset("products", Products, ProductsConfig)
register_dataset("products", ProductsDataset, ProductsConfig)