pass config to ProceduralDataset base

This commit is contained in:
Andreas Koepf 2025-01-25 00:23:05 +01:00
parent df2b8d2809
commit e9549f2a63
20 changed files with 45 additions and 80 deletions

View file

@ -45,7 +45,7 @@ class SyllogismConfig:
seed: Optional[int] = None
size: int = 500
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert any(
[self.allow_all, self.allow_no, self.allow_some, self.allow_some_not]
@ -100,11 +100,8 @@ class SyllogismDataset(ProceduralDataset):
]
def __init__(self, config: SyllogismConfig):
self.config = config
if self.config.terms is None:
self.config.terms = self.DEFAULT_TERMS
self.config.validate()
super().__init__(seed=config.seed, size=config.size)
super().__init__(config=config, seed=config.seed, size=config.size)
self.terms = self.DEFAULT_TERMS if config.terms is None else config.terms
def _get_allowed_quantifiers(self) -> List[Quantifier]:
"""Get list of allowed quantifiers based on config"""
@ -212,7 +209,7 @@ class SyllogismDataset(ProceduralDataset):
def _generate_syllogism(self, rng: Random) -> dict:
"""Generate a single syllogism problem"""
# Select three different terms
terms = rng.sample(self.config.terms, 3)
terms = rng.sample(self.terms, 3)
quantifiers = self._get_allowed_quantifiers()
# Generate premises and conclusion