add reasoning_gym.create_dataset({name}, ...) global factory function

This commit is contained in:
Andreas Koepf 2025-01-25 00:58:34 +01:00
parent 0d2d8ba6a0
commit 519e411fa5
35 changed files with 133 additions and 598 deletions

View file

@ -6,14 +6,12 @@ Cognition tasks for training reasoning capabilities:
- Working memory
"""
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset, color_cube_rotation_dataset
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset, number_sequence_dataset
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
# Names this package exports; `__all__` is the list a wildcard import
# (`from <package> import *`) exposes.
__all__ = [
"NumberSequenceConfig",
"NumberSequenceDataset",
"number_sequence_dataset",
"ColorCubeRotationConfig",
"ColorCubeRotationDataset",
"color_cube_rotation_dataset",
]

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import StrEnum
from typing import Dict, List, Optional, Tuple
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
class Color(StrEnum):
@ -189,17 +189,4 @@ class ColorCubeRotationDataset(ProceduralDataset):
return "\n".join(story_parts)
def color_cube_rotation_dataset(
    min_rotations: int = 1,
    max_rotations: int = 3,
    seed: Optional[int] = None,
    size: int = 500,
) -> ColorCubeRotationDataset:
    """Build a ColorCubeRotationDataset from individual settings.

    Every argument is forwarded unchanged to the ``ColorCubeRotationConfig``
    field of the same name, and the resulting config is handed to the
    dataset constructor.

    Args:
        min_rotations: Passed to ``ColorCubeRotationConfig.min_rotations``.
        max_rotations: Passed to ``ColorCubeRotationConfig.max_rotations``.
        seed: Passed to ``ColorCubeRotationConfig.seed``; may be ``None``.
        size: Passed to ``ColorCubeRotationConfig.size``.

    Returns:
        A ``ColorCubeRotationDataset`` constructed from that config.
    """
    # Gather the settings once so the config call is a single expansion.
    settings = {
        "min_rotations": min_rotations,
        "max_rotations": max_rotations,
        "seed": seed,
        "size": size,
    }
    return ColorCubeRotationDataset(ColorCubeRotationConfig(**settings))
# Import-time side effect: hand the dataset class and its config class to
# register_dataset under the name "color_cube_rotation".
# NOTE(review): presumably this is what lets the name-based factory find the
# dataset -- confirm against the factory module.
register_dataset("color_cube_rotation", ColorCubeRotationDataset, ColorCubeRotationConfig)

View file

@ -3,7 +3,7 @@ from enum import StrEnum
from random import Random
from typing import List, Optional
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
class Operation(StrEnum):
@ -198,23 +198,4 @@ class NumberSequenceDataset(ProceduralDataset):
}
def number_sequence_dataset(
    min_terms: int = 4,
    max_terms: int = 8,
    min_value: int = -100,
    max_value: int = 100,
    max_complexity: int = 3,
    seed: Optional[int] = None,
    size: int = 500,
) -> NumberSequenceDataset:
    """Build a NumberSequenceDataset from individual settings.

    Every argument is forwarded unchanged to the ``NumberSequenceConfig``
    field of the same name, and the resulting config is handed to the
    dataset constructor.

    Args:
        min_terms: Passed to ``NumberSequenceConfig.min_terms``.
        max_terms: Passed to ``NumberSequenceConfig.max_terms``.
        min_value: Passed to ``NumberSequenceConfig.min_value``.
        max_value: Passed to ``NumberSequenceConfig.max_value``.
        max_complexity: Passed to ``NumberSequenceConfig.max_complexity``.
        seed: Passed to ``NumberSequenceConfig.seed``; may be ``None``.
        size: Passed to ``NumberSequenceConfig.size``.

    Returns:
        A ``NumberSequenceDataset`` constructed from that config.
    """
    # Gather the settings once so the config call is a single expansion.
    settings = {
        "min_terms": min_terms,
        "max_terms": max_terms,
        "min_value": min_value,
        "max_value": max_value,
        "max_complexity": max_complexity,
        "seed": seed,
        "size": size,
    }
    return NumberSequenceDataset(NumberSequenceConfig(**settings))
# Import-time side effect: hand the dataset class and its config class to
# register_dataset under the name "number_sequence".
# NOTE(review): presumably this is what lets the name-based factory find the
# dataset -- confirm against the factory module.
register_dataset("number_sequence", NumberSequenceDataset, NumberSequenceConfig)