add reasoning_gym.create_dataset({name}, ...) global factory function

This commit is contained in:
Andreas Koepf 2025-01-25 00:58:34 +01:00
parent 0d2d8ba6a0
commit 519e411fa5
35 changed files with 133 additions and 598 deletions

View file

@ -5,6 +5,8 @@ from enum import StrEnum
from random import Random
from typing import Any, List, Optional, Set
from ..factory import ProceduralDataset, register_dataset
class Operator(StrEnum):
"""Basic logical operators"""
@ -70,13 +72,11 @@ class Expression:
return f"({self.left} {self.operator.value} {self.right})"
class PropositionalLogicDataset:
class PropositionalLogicDataset(ProceduralDataset):
"""Generates propositional logic reasoning tasks"""
def __init__(self, config: PropositionalLogicConfig):
self.config = config
self.config.validate()
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
super().__init__(config=config, seed=config.seed, size=config.size)
def __len__(self) -> int:
return self.config.size
@ -199,23 +199,4 @@ class PropositionalLogicDataset:
return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
def propositional_logic_dataset(
min_vars: int = 2,
max_vars: int = 4,
min_statements: int = 2,
max_statements: int = 4,
max_complexity: int = 3,
seed: Optional[int] = None,
size: int = 500,
) -> PropositionalLogicDataset:
"""Create a PropositionalLogicDataset with the given configuration."""
config = PropositionalLogicConfig(
min_vars=min_vars,
max_vars=max_vars,
min_statements=min_statements,
max_statements=max_statements,
max_complexity=max_complexity,
seed=seed,
size=size,
)
return PropositionalLogicDataset(config)
register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig)