mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-01 17:45:24 +00:00
added family relatinship curriculum (#359)
This commit is contained in:
parent
fa7d8e66b3
commit
80ac138455
3 changed files with 78 additions and 7 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
|
from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
|
||||||
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
|
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsCurriculum, FamilyRelationshipsDataset
|
||||||
from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
|
from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
|
||||||
from .quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset
|
from .quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset
|
||||||
from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset
|
from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset
|
||||||
|
|
@ -7,6 +7,7 @@ from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestP
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FamilyRelationshipsConfig",
|
"FamilyRelationshipsConfig",
|
||||||
"FamilyRelationshipsDataset",
|
"FamilyRelationshipsDataset",
|
||||||
|
"FamilyRelationshipsCurriculum",
|
||||||
"QuantumLockConfig",
|
"QuantumLockConfig",
|
||||||
"QuantumLockDataset",
|
"QuantumLockDataset",
|
||||||
"QuantumLockCurriculum",
|
"QuantumLockCurriculum",
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from enum import StrEnum
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -183,9 +184,9 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
|
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
|
||||||
# Generate family tree
|
# Generate family tree
|
||||||
family = self._generate_family(rng)
|
family = self._generate_family(rng, family_size)
|
||||||
|
|
||||||
# Select two people and their relationship
|
# Select two people and their relationship
|
||||||
person1, person2, relationship = self._get_relationship_question(rng, family)
|
person1, person2, relationship = self._get_relationship_question(rng, family)
|
||||||
|
|
@ -204,12 +205,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
"person2": person2.name,
|
"person2": person2.name,
|
||||||
"relationship": relationship.value,
|
"relationship": relationship.value,
|
||||||
"family_size": len(family),
|
"family_size": len(family),
|
||||||
|
"difficulty": {
|
||||||
|
"family_size": len(family),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _generate_family(self, rng: random.Random) -> set[Person]:
|
def _generate_family(self, rng: random.Random, family_size: int) -> set[Person]:
|
||||||
"""Generate a random family tree"""
|
"""Generate a random family tree"""
|
||||||
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
|
|
||||||
family = set()
|
family = set()
|
||||||
used_names = set()
|
used_names = set()
|
||||||
|
|
||||||
|
|
@ -369,4 +372,23 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
|
|
||||||
register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig)
|
class FamilyRelationshipsCurriculum(BaseCurriculum):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(FamilyRelationshipsCurriculum.__name__, FamilyRelationshipsConfig)
|
||||||
|
self._define_attributes(
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="family_size",
|
||||||
|
description="The size of the family",
|
||||||
|
min_value=3,
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
default_level=0,
|
||||||
|
levels=list(range(3, 12)),
|
||||||
|
lower_field_name="min_family_size",
|
||||||
|
upper_field_name="max_family_size",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
"family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig, FamilyRelationshipsCurriculum
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,12 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym import create_dataset
|
from reasoning_gym import create_dataset
|
||||||
from reasoning_gym.graphs.family_relationships import FamilyRelationshipsDataset, Relationship
|
from reasoning_gym.graphs.family_relationships import (
|
||||||
|
FamilyRelationshipsConfig,
|
||||||
|
FamilyRelationshipsCurriculum,
|
||||||
|
FamilyRelationshipsDataset,
|
||||||
|
Relationship,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_family_relationships_generation():
|
def test_family_relationships_generation():
|
||||||
|
|
@ -83,3 +88,46 @@ def test_relationship_consistency():
|
||||||
Relationship.FATHER_IN_LAW.value,
|
Relationship.FATHER_IN_LAW.value,
|
||||||
]:
|
]:
|
||||||
assert "married to" in item["question"] or "child" in item["question"]
|
assert "married to" in item["question"] or "child" in item["question"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_family_relationships_curriculum():
|
||||||
|
"""Test the family relationships curriculum functionality"""
|
||||||
|
curriculum = FamilyRelationshipsCurriculum()
|
||||||
|
|
||||||
|
base_value = {"size": 50, "seed": 42}
|
||||||
|
|
||||||
|
# Test default configuration
|
||||||
|
base_cfg: FamilyRelationshipsConfig = curriculum.generate_configuration(base_value)
|
||||||
|
assert base_cfg.seed == 42
|
||||||
|
assert base_cfg.size == 50
|
||||||
|
assert base_cfg.min_family_size == 3 and base_cfg.max_family_size == 3 # Default level 0
|
||||||
|
|
||||||
|
# Test incrementing family_size attribute
|
||||||
|
curriculum.increment_attr_level("family_size")
|
||||||
|
first_level_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert first_level_cfg.min_family_size == 3 and first_level_cfg.max_family_size == 4 # Level 1
|
||||||
|
|
||||||
|
# Test incrementing family_size attribute again
|
||||||
|
curriculum.increment_attr_level("family_size")
|
||||||
|
second_level_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert second_level_cfg.min_family_size == 3 and second_level_cfg.max_family_size == 5 # Level 2
|
||||||
|
|
||||||
|
# Test decrementing family_size attribute
|
||||||
|
curriculum.decrement_attr_level("family_size")
|
||||||
|
back_to_first_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert back_to_first_cfg.min_family_size == 3 and back_to_first_cfg.max_family_size == 4 # Back to level 1
|
||||||
|
|
||||||
|
# Test global level setting
|
||||||
|
curriculum.set_global_level(5) # Set to level 5
|
||||||
|
level_five_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert level_five_cfg.min_family_size == 3 and level_five_cfg.max_family_size == 8 # Level 5
|
||||||
|
|
||||||
|
# Test increment global level
|
||||||
|
curriculum.increment_global_level() # Should go to level 6
|
||||||
|
level_six_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert level_six_cfg.min_family_size == 3 and level_six_cfg.max_family_size == 9 # Level 6
|
||||||
|
|
||||||
|
# Test decrement global level
|
||||||
|
curriculum.decrement_global_level() # Should go back to level 5
|
||||||
|
back_to_five_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert back_to_five_cfg.min_family_size == 3 and back_to_five_cfg.max_family_size == 8 # Back to level 5
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue