refactor: Inherit ChainSum from ProceduralDataset base class

This commit is contained in:
Andreas Koepf (aider) 2025-01-24 09:57:26 +01:00
parent 0dc19b831c
commit d191e78a28

View file

@ -1,6 +1,7 @@
import random
from dataclasses import dataclass
from typing import Optional
from ..dataset import ProceduralDataset
@dataclass
@ -27,17 +28,13 @@ class ChainSumConfig:
assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range"
class ChainSum:
class ChainSum(ProceduralDataset):
"""Generates simple arithmetic tasks using only + and - operators"""
def __init__(self, config: ChainSumConfig):
self.config = config
self.config.validate()
# Generate base seed if none provided
self.seed = config.seed if config.seed is not None else random.randint(0, 2**32)
def __len__(self) -> int:
return self.config.size
super().__init__(seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single chain sum task
@ -73,18 +70,6 @@ class ChainSum:
},
}
def __iter__(self):
"""Make the dataset iterable"""
self._current_idx = 0
return self
def __next__(self):
"""Get next item in iteration"""
if self._current_idx >= self.config.size:
raise StopIteration
item = self[self._current_idx]
self._current_idx += 1
return item
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
"""Generate a chain sum task