pass config to ProceduralDataset base

This commit is contained in:
Andreas Koepf 2025-01-25 00:23:05 +01:00
parent df2b8d2809
commit e9549f2a63
20 changed files with 45 additions and 80 deletions

View file

@ -1,9 +1,8 @@
import random
import string
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import sympy
from sympy import Eq, Symbol, expand, solve
from ..dataset import ProceduralDataset
@ -28,7 +27,7 @@ class PolynomialEquationsConfig:
seed: Optional[int] = None
size: int = 500
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters."""
assert self.min_terms > 0, "min_terms must be positive."
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms."
@ -53,15 +52,13 @@ class PolynomialEquationsDataset(ProceduralDataset):
"""
def __init__(self, config: PolynomialEquationsConfig):
config.validate()
self.config = config
self._prompt_templates = [
"Find the real value(s) of {variable} in the equation: {polynomial_expanded} = 0",
"Solve for real {variable}: {polynomial_expanded} = 0",
"Determine the real value(s) of {variable} tha satisfies: {polynomial_expanded} = 0",
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
]
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""

View file

@ -21,7 +21,7 @@ class SimpleEquationsConfig:
seed: Optional[int] = None
size: int = 500
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_terms > 0, "min_terms must be positive"
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
@ -35,14 +35,12 @@ class SimpleEquationsDataset(ProceduralDataset):
"""Generates simple equations with one variable to solve"""
def __init__(self, config: SimpleEquationsConfig):
self.config = config
self.config.validate()
self._prompt_templates = [
"Find the value of {variable} in the equation: {equation}",
"Solve for {variable}: {equation}",
"Determine the value of {variable} that satisfies: {equation}",
]
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single equation task

View file

@ -18,7 +18,7 @@ class BaseConversionConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert 2 <= self.min_base <= 36, "min_base must be between 2 and 36"
assert self.min_base <= self.max_base <= 36, "max_base must be between min_base and 36"
@ -30,9 +30,7 @@ class BaseConversionDataset(ProceduralDataset):
"""Generates base conversion tasks"""
def __init__(self, config: BaseConversionConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def _format_base_name(self, base: int) -> str:
"""Get human-readable name for common bases"""

View file

@ -19,7 +19,7 @@ class LetterCountingConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_words > 0, "min_words must be positive"
assert self.max_words >= self.min_words, "max_words must be >= min_words"
@ -29,9 +29,7 @@ class LetterCountingDataset(ProceduralDataset):
"""Generates letter counting tasks from text spans"""
def __init__(self, config: LetterCountingConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
# Load and preprocess text
text = read_data_file("in_the_year_2889.txt")

View file

@ -20,7 +20,7 @@ class NumberFilteringConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_numbers > 0, "min_numbers must be positive"
assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
@ -33,9 +33,7 @@ class NumberFilteringDataset(ProceduralDataset):
"""Generates number filtering tasks"""
def __init__(self, config: NumberFilteringConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def _format_number(self, num: float, decimals: int) -> str:
"""Format a number with specified decimal places"""

View file

@ -20,7 +20,7 @@ class NumberSortingConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_numbers > 0, "min_numbers must be positive"
assert self.min_numbers <= self.max_numbers, "max_numbers must be >= min_numbers"
@ -33,9 +33,7 @@ class NumberSortingDataset(ProceduralDataset):
"""Generates number sorting tasks"""
def __init__(self, config: NumberSortingConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def _format_number(self, num: float, decimals: int) -> str:
"""Format number with specified decimal places"""

View file

@ -18,7 +18,7 @@ class WordReversalConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_words > 0, "min_words must be positive"
assert self.max_words >= self.min_words, "max_words must be >= min_words"
@ -28,9 +28,7 @@ class WordReversalDataset(ProceduralDataset):
"""Generates word reversal tasks from text spans"""
def __init__(self, config: WordReversalConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
# Load and preprocess text
text = read_data_file("in_the_year_2889.txt")

View file

@ -21,7 +21,7 @@ class BasicArithmeticDatasetConfig:
format_style: Literal["simple", "natural"] = "simple"
whitespace: Literal["no_space", "single", "random"] = "single" # Whitespace style between terms
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_terms > 0, "min_terms must be positive"
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
@ -63,9 +63,7 @@ class BasicArithmeticDataset(ProceduralDataset):
"""Dataset that generates basic arithmetic tasks with configurable complexity"""
def __init__(self, config: BasicArithmeticDatasetConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict[str, Any]:
"""Generate a single arithmetic task

View file

@ -18,7 +18,7 @@ class ChainSumConfig:
seed: Optional[int] = None
size: int = 500
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_terms > 0, "min_terms must be positive"
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
@ -34,9 +34,7 @@ class ChainSum(ProceduralDataset):
"""Generates simple arithmetic tasks using only + and - operators"""
def __init__(self, config: ChainSumConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single chain sum task
@ -145,5 +143,6 @@ def chain_sum_dataset(
)
return ChainSum(config)
# Register the dataset
register_dataset("chain_sum", ChainSum, ChainSumConfig)

View file

@ -20,7 +20,7 @@ class FractionSimplificationConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_value > 0, "min_value must be positive"
assert self.max_value > self.min_value, "max_value must be > min_value"
@ -37,9 +37,7 @@ class FractionSimplificationDataset(ProceduralDataset):
"""Generates fraction simplification tasks"""
def __init__(self, config: FractionSimplificationConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]:
"""Generate a random fraction and its simplified form.

View file

@ -20,7 +20,7 @@ class GCDConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_numbers >= 2, "min_numbers must be at least 2"
assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
@ -32,9 +32,7 @@ class GCDDataset(ProceduralDataset):
"""Generates Greatest Common Divisor (GCD) tasks"""
def __init__(self, config: GCDConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
"""Generate a list of random positive integers and their GCD.

View file

@ -20,7 +20,7 @@ class LCMConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_numbers >= 2, "min_numbers must be at least 2"
assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
@ -32,9 +32,7 @@ class LCMDataset(ProceduralDataset):
"""Generates Least Common Multiple (LCM) tasks"""
def __init__(self, config: LCMConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
"""Generate a list of random positive integers and their LCM.

View file

@ -65,7 +65,7 @@ class LegCountingConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_animals > 0, "min_animals must be positive"
assert self.max_animals >= self.min_animals, "max_animals must be >= min_animals"
@ -76,9 +76,7 @@ class LegCountingDataset(ProceduralDataset):
"""Generates leg counting arithmetic tasks"""
def __init__(self, config: LegCountingConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_animals(self, rng: Random) -> Dict[str, int]:
"""Generate a random set of animals and their counts"""

View file

@ -16,7 +16,7 @@ class PrimeFactorizationConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_value >= 2, "min_value must be >= 2"
assert self.max_value >= self.min_value, "max_value must be >= min_value"
@ -26,9 +26,7 @@ class PrimeFactorizationDataset(ProceduralDataset):
"""Generates prime factorization tasks"""
def __init__(self, config: PrimeFactorizationConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def _prime_factors(self, n: int) -> List[int]:
"""Compute prime factors of a number"""

View file

@ -95,7 +95,7 @@ class ColorCubeRotationConfig:
seed: Optional[int] = None
size: int = 500
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_rotations > 0, "min_rotations must be positive"
assert self.max_rotations >= self.min_rotations, "max_rotations must be >= min_rotations"
@ -105,9 +105,7 @@ class ColorCubeRotationDataset(ProceduralDataset):
"""Generates color cube rotation reasoning tasks"""
def __init__(self, config: ColorCubeRotationConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
rng = random.Random(self.seed + idx)

View file

@ -31,7 +31,7 @@ class NumberSequenceConfig:
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_terms >= 4, "need at least 4 terms to establish pattern"
assert self.max_terms >= self.min_terms
@ -155,9 +155,7 @@ class NumberSequenceDataset(ProceduralDataset):
"""Generates number sequence completion tasks with dynamic pattern generation"""
def __init__(self, config: NumberSequenceConfig):
self.config = config
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a sequence task with a newly generated pattern"""

View file

@ -11,6 +11,9 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
def __init__(self, config: Any, seed: Optional[int] = None, size: int = 500):
"""Initialize the dataset with config, optional seed and size"""
if hasattr(config, "validate") and callable(config.validate):
config.validate()
self.config = config
self.size = size
self.seed = seed if seed is not None else Random().randint(0, 2**32)

View file

@ -26,7 +26,7 @@ class MazeConfig:
seed: Optional[int] = None
size: int = 50
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters."""
assert self.min_dist >= 1, "min_dist must be >= 1"
assert self.max_dist >= self.min_dist, "max_dist must be >= min_dist"
@ -46,8 +46,7 @@ class MazeDataset(ProceduralDataset):
prob_path=0.7,
num_retries=1000,
):
config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
self.config = config
# Probability that a cell is a path instead of a wall
self.prob_path = prob_path

View file

@ -166,7 +166,7 @@ class FamilyRelationshipsConfig:
if self.female_names is None:
self.female_names = default_female_names
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_family_size >= 3, "min_family_size must be at least 3"
assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size"
@ -178,14 +178,12 @@ class FamilyRelationshipsDataset(ProceduralDataset):
"""Generates family relationship reasoning tasks"""
def __init__(self, config: FamilyRelationshipsConfig):
self.config = config
self.config.validate()
self._templates = [
"What is {person1} to {person2}?",
"How is {person1} related to {person2}?",
"What relation is {person1} to {person2}?",
]
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
rng = random.Random(self.seed + idx)

View file

@ -45,7 +45,7 @@ class SyllogismConfig:
seed: Optional[int] = None
size: int = 500
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert any(
[self.allow_all, self.allow_no, self.allow_some, self.allow_some_not]
@ -100,11 +100,8 @@ class SyllogismDataset(ProceduralDataset):
]
def __init__(self, config: SyllogismConfig):
self.config = config
if self.config.terms is None:
self.config.terms = self.DEFAULT_TERMS
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
self.terms = self.DEFAULT_TERMS if config.terms is None else config.terms
def _get_allowed_quantifiers(self) -> List[Quantifier]:
"""Get list of allowed quantifiers based on config"""
@ -212,7 +209,7 @@ class SyllogismDataset(ProceduralDataset):
def _generate_syllogism(self, rng: Random) -> dict:
"""Generate a single syllogism problem"""
# Select three different terms
terms = rng.sample(self.config.terms, 3)
terms = rng.sample(self.terms, 3)
quantifiers = self._get_allowed_quantifiers()
# Generate premises and conclusion