mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
pass config to ProceduralDataset base
This commit is contained in:
parent
df2b8d2809
commit
e9549f2a63
20 changed files with 45 additions and 80 deletions
|
|
@ -45,7 +45,7 @@ class SyllogismConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert any(
|
||||
[self.allow_all, self.allow_no, self.allow_some, self.allow_some_not]
|
||||
|
|
@ -100,11 +100,8 @@ class SyllogismDataset(ProceduralDataset):
|
|||
]
|
||||
|
||||
def __init__(self, config: SyllogismConfig):
|
||||
self.config = config
|
||||
if self.config.terms is None:
|
||||
self.config.terms = self.DEFAULT_TERMS
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.terms = self.DEFAULT_TERMS if config.terms is None else config.terms
|
||||
|
||||
def _get_allowed_quantifiers(self) -> List[Quantifier]:
|
||||
"""Get list of allowed quantifiers based on config"""
|
||||
|
|
@ -212,7 +209,7 @@ class SyllogismDataset(ProceduralDataset):
|
|||
def _generate_syllogism(self, rng: Random) -> dict:
|
||||
"""Generate a single syllogism problem"""
|
||||
# Select three different terms
|
||||
terms = rng.sample(self.config.terms, 3)
|
||||
terms = rng.sample(self.terms, 3)
|
||||
quantifiers = self._get_allowed_quantifiers()
|
||||
|
||||
# Generate premises and conclusion
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue