add reasoning_gym.create_dataset({name}, ...) global factory function

This commit is contained in:
Andreas Koepf 2025-01-25 00:58:34 +01:00
parent e9549f2a63
commit 0dcff77b37
35 changed files with 133 additions and 598 deletions

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
ANIMALS = {
# Animals with 0 legs
@ -115,19 +115,4 @@ class LegCountingDataset(ProceduralDataset):
}
def leg_counting_dataset(
min_animals: int = 2,
max_animals: int = 5,
max_instances: int = 3,
seed: Optional[int] = None,
size: int = 500,
) -> LegCountingDataset:
"""Create a LegCountingDataset with the given configuration."""
config = LegCountingConfig(
min_animals=min_animals,
max_animals=max_animals,
max_instances=max_instances,
seed=seed,
size=size,
)
return LegCountingDataset(config)
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)