add reasoning_gym.create_dataset({name}, ...) global factory function

This commit is contained in:
Andreas Koepf 2025-01-25 00:58:34 +01:00
parent 0d2d8ba6a0
commit 519e411fa5
35 changed files with 133 additions and 598 deletions

View file

@ -1,3 +1,3 @@
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset, family_relationships_dataset
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
__all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig", "family_relationships_dataset"]
__all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig"]

View file

@ -4,7 +4,7 @@ from enum import StrEnum
from itertools import count
from typing import Dict, List, Optional, Set, Tuple
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
class Gender(StrEnum):
@ -310,21 +310,4 @@ class FamilyRelationshipsDataset(ProceduralDataset):
return " ".join(story_parts)
def family_relationships_dataset(
    min_family_size: int = 4,
    max_family_size: int = 8,
    male_names: Optional[List[str]] = None,
    female_names: Optional[List[str]] = None,
    seed: Optional[int] = None,
    size: int = 500,
) -> FamilyRelationshipsDataset:
    """Create a FamilyRelationshipsDataset with the given configuration.

    Every argument is forwarded unchanged, by keyword, to
    ``FamilyRelationshipsConfig``; the resulting config is then used to
    construct the dataset.

    Args:
        min_family_size: Passed to the config as ``min_family_size``.
        max_family_size: Passed to the config as ``max_family_size``.
        male_names: Passed to the config as ``male_names``. ``None`` is
            forwarded as-is (presumably the config then falls back to its
            own name list -- confirm against FamilyRelationshipsConfig).
        female_names: Passed to the config as ``female_names``; ``None`` is
            handled the same way as for ``male_names``.
        seed: Passed to the config as ``seed``; may be ``None``.
        size: Passed to the config as ``size``.

    Returns:
        A ``FamilyRelationshipsDataset`` built from the assembled config.
    """
    config = FamilyRelationshipsConfig(
        min_family_size=min_family_size,
        max_family_size=max_family_size,
        male_names=male_names,
        female_names=female_names,
        seed=seed,
        size=size,
    )
    return FamilyRelationshipsDataset(config)
# Register the dataset class together with its config class under the name
# "family_relationships". This runs at import time as a module-level call.
# NOTE(review): per the commit title this registration is presumably what lets
# reasoning_gym.create_dataset("family_relationships", ...) locate the dataset
# -- confirm against the factory module.
register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig)