From 80ac138455c9d20d801d94c68c8c19c68c9341b3 Mon Sep 17 00:00:00 2001 From: joesharratt1229 <118444587+joesharratt1229@users.noreply.github.com> Date: Fri, 14 Mar 2025 15:08:49 +0000 Subject: [PATCH] added family relatinship curriculum (#359) --- reasoning_gym/graphs/__init__.py | 3 +- reasoning_gym/graphs/family_relationships.py | 32 +++++++++++-- tests/test_family_relationships.py | 50 +++++++++++++++++++- 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/reasoning_gym/graphs/__init__.py b/reasoning_gym/graphs/__init__.py index 43cab6a7..d8e2b825 100644 --- a/reasoning_gym/graphs/__init__.py +++ b/reasoning_gym/graphs/__init__.py @@ -1,5 +1,5 @@ from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset -from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset +from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsCurriculum, FamilyRelationshipsDataset from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset from .quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset @@ -7,6 +7,7 @@ from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestP __all__ = [ "FamilyRelationshipsConfig", "FamilyRelationshipsDataset", + "FamilyRelationshipsCurriculum", "QuantumLockConfig", "QuantumLockDataset", "QuantumLockCurriculum", diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 51c0055c..f22cf7a3 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -4,6 +4,7 @@ from enum import StrEnum from itertools import count from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -183,9 +184,9 @@ class FamilyRelationshipsDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: rng = random.Random(self.seed + idx) - + family_size = rng.randint(self.config.min_family_size, self.config.max_family_size) # Generate family tree - family = self._generate_family(rng) + family = self._generate_family(rng, family_size) # Select two people and their relationship person1, person2, relationship = self._get_relationship_question(rng, family) @@ -204,12 +205,14 @@ class FamilyRelationshipsDataset(ProceduralDataset): "person2": person2.name, "relationship": relationship.value, "family_size": len(family), + "difficulty": { + "family_size": len(family), + }, }, } - def _generate_family(self, rng: random.Random) -> set[Person]: + def _generate_family(self, rng: random.Random, family_size: int) -> set[Person]: """Generate a random family tree""" - family_size = rng.randint(self.config.min_family_size, self.config.max_family_size) family = set() used_names = set() @@ -369,4 +372,23 @@ class FamilyRelationshipsDataset(ProceduralDataset): return reward -register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig) +class FamilyRelationshipsCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(FamilyRelationshipsCurriculum.__name__, FamilyRelationshipsConfig) + self._define_attributes( + RangeAttributeDefinition( + name="family_size", + description="The size of the family", + min_value=3, + attr_type=AttributeType.APPEND, + default_level=0, + levels=list(range(3, 12)), + lower_field_name="min_family_size", + upper_field_name="max_family_size", + ) + ) + + +register_dataset( + "family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig, FamilyRelationshipsCurriculum +) diff --git a/tests/test_family_relationships.py b/tests/test_family_relationships.py index e4954113..b3793e99 100644 --- a/tests/test_family_relationships.py +++ b/tests/test_family_relationships.py @@ -1,7 +1,12 @@ import pytest from reasoning_gym import create_dataset -from reasoning_gym.graphs.family_relationships import FamilyRelationshipsDataset, Relationship +from reasoning_gym.graphs.family_relationships import ( + FamilyRelationshipsConfig, + FamilyRelationshipsCurriculum, + FamilyRelationshipsDataset, + Relationship, +) def test_family_relationships_generation(): @@ -83,3 +88,46 @@ def test_relationship_consistency(): Relationship.FATHER_IN_LAW.value, ]: assert "married to" in item["question"] or "child" in item["question"] + + +def test_family_relationships_curriculum(): + """Test the family relationships curriculum functionality""" + curriculum = FamilyRelationshipsCurriculum() + + base_value = {"size": 50, "seed": 42} + + # Test default configuration + base_cfg: FamilyRelationshipsConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 42 + assert base_cfg.size == 50 + assert base_cfg.min_family_size == 3 and base_cfg.max_family_size == 3 # Default level 0 + + # Test incrementing family_size attribute + curriculum.increment_attr_level("family_size") + first_level_cfg = curriculum.generate_configuration(base_value) + assert first_level_cfg.min_family_size == 3 and first_level_cfg.max_family_size == 4 # Level 1 + + # Test incrementing family_size attribute again + curriculum.increment_attr_level("family_size") + second_level_cfg = curriculum.generate_configuration(base_value) + assert second_level_cfg.min_family_size == 3 and second_level_cfg.max_family_size == 5 # Level 2 + + # Test decrementing family_size attribute + curriculum.decrement_attr_level("family_size") + back_to_first_cfg = curriculum.generate_configuration(base_value) + assert back_to_first_cfg.min_family_size == 3 and back_to_first_cfg.max_family_size == 4 # Back to level 1 + + # Test global level setting + curriculum.set_global_level(5) # Set to level 5 + level_five_cfg = curriculum.generate_configuration(base_value) + assert level_five_cfg.min_family_size == 3 and level_five_cfg.max_family_size == 8 # Level 5 + + # Test increment global level + curriculum.increment_global_level() # Should go to level 6 + level_six_cfg = curriculum.generate_configuration(base_value) + assert level_six_cfg.min_family_size == 3 and level_six_cfg.max_family_size == 9 # Level 6 + + # Test decrement global level + curriculum.decrement_global_level() # Should go back to level 5 + back_to_five_cfg = curriculum.generate_configuration(base_value) + assert back_to_five_cfg.min_family_size == 3 and back_to_five_cfg.max_family_size == 8 # Back to level 5