added family relatinship curriculum (#359)

This commit is contained in:
joesharratt1229 2025-03-14 15:08:49 +00:00 committed by GitHub
parent fa7d8e66b3
commit 80ac138455
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 78 additions and 7 deletions

View file

@ -1,5 +1,5 @@
from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsCurriculum, FamilyRelationshipsDataset
from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
from .quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset from .quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset
from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset
@ -7,6 +7,7 @@ from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestP
__all__ = [ __all__ = [
"FamilyRelationshipsConfig", "FamilyRelationshipsConfig",
"FamilyRelationshipsDataset", "FamilyRelationshipsDataset",
"FamilyRelationshipsCurriculum",
"QuantumLockConfig", "QuantumLockConfig",
"QuantumLockDataset", "QuantumLockDataset",
"QuantumLockCurriculum", "QuantumLockCurriculum",

View file

@ -4,6 +4,7 @@ from enum import StrEnum
from itertools import count from itertools import count
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -183,9 +184,9 @@ class FamilyRelationshipsDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
# Generate family tree # Generate family tree
family = self._generate_family(rng) family = self._generate_family(rng, family_size)
# Select two people and their relationship # Select two people and their relationship
person1, person2, relationship = self._get_relationship_question(rng, family) person1, person2, relationship = self._get_relationship_question(rng, family)
@ -204,12 +205,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
"person2": person2.name, "person2": person2.name,
"relationship": relationship.value, "relationship": relationship.value,
"family_size": len(family), "family_size": len(family),
"difficulty": {
"family_size": len(family),
},
}, },
} }
def _generate_family(self, rng: random.Random) -> set[Person]: def _generate_family(self, rng: random.Random, family_size: int) -> set[Person]:
"""Generate a random family tree""" """Generate a random family tree"""
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
family = set() family = set()
used_names = set() used_names = set()
@ -369,4 +372,23 @@ class FamilyRelationshipsDataset(ProceduralDataset):
return reward return reward
register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig) class FamilyRelationshipsCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(FamilyRelationshipsCurriculum.__name__, FamilyRelationshipsConfig)
self._define_attributes(
RangeAttributeDefinition(
name="family_size",
description="The size of the family",
min_value=3,
attr_type=AttributeType.APPEND,
default_level=0,
levels=list(range(3, 12)),
lower_field_name="min_family_size",
upper_field_name="max_family_size",
)
)
register_dataset(
"family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig, FamilyRelationshipsCurriculum
)

View file

@ -1,7 +1,12 @@
import pytest import pytest
from reasoning_gym import create_dataset from reasoning_gym import create_dataset
from reasoning_gym.graphs.family_relationships import FamilyRelationshipsDataset, Relationship from reasoning_gym.graphs.family_relationships import (
FamilyRelationshipsConfig,
FamilyRelationshipsCurriculum,
FamilyRelationshipsDataset,
Relationship,
)
def test_family_relationships_generation(): def test_family_relationships_generation():
@ -83,3 +88,46 @@ def test_relationship_consistency():
Relationship.FATHER_IN_LAW.value, Relationship.FATHER_IN_LAW.value,
]: ]:
assert "married to" in item["question"] or "child" in item["question"] assert "married to" in item["question"] or "child" in item["question"]
def test_family_relationships_curriculum():
"""Test the family relationships curriculum functionality"""
curriculum = FamilyRelationshipsCurriculum()
base_value = {"size": 50, "seed": 42}
# Test default configuration
base_cfg: FamilyRelationshipsConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 42
assert base_cfg.size == 50
assert base_cfg.min_family_size == 3 and base_cfg.max_family_size == 3 # Default level 0
# Test incrementing family_size attribute
curriculum.increment_attr_level("family_size")
first_level_cfg = curriculum.generate_configuration(base_value)
assert first_level_cfg.min_family_size == 3 and first_level_cfg.max_family_size == 4 # Level 1
# Test incrementing family_size attribute again
curriculum.increment_attr_level("family_size")
second_level_cfg = curriculum.generate_configuration(base_value)
assert second_level_cfg.min_family_size == 3 and second_level_cfg.max_family_size == 5 # Level 2
# Test decrementing family_size attribute
curriculum.decrement_attr_level("family_size")
back_to_first_cfg = curriculum.generate_configuration(base_value)
assert back_to_first_cfg.min_family_size == 3 and back_to_first_cfg.max_family_size == 4 # Back to level 1
# Test global level setting
curriculum.set_global_level(5) # Set to level 5
level_five_cfg = curriculum.generate_configuration(base_value)
assert level_five_cfg.min_family_size == 3 and level_five_cfg.max_family_size == 8 # Level 5
# Test increment global level
curriculum.increment_global_level() # Should go to level 6
level_six_cfg = curriculum.generate_configuration(base_value)
assert level_six_cfg.min_family_size == 3 and level_six_cfg.max_family_size == 9 # Level 6
# Test decrement global level
curriculum.decrement_global_level() # Should go back to level 5
back_to_five_cfg = curriculum.generate_configuration(base_value)
assert back_to_five_cfg.min_family_size == 3 and back_to_five_cfg.max_family_size == 8 # Back to level 5