added family relatinship curriculum (#359)

This commit is contained in:
joesharratt1229 2025-03-14 15:08:49 +00:00 committed by GitHub
parent fa7d8e66b3
commit 80ac138455
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 78 additions and 7 deletions

View file

@ -1,5 +1,5 @@
from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsCurriculum, FamilyRelationshipsDataset
from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
from .quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset
from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset
@ -7,6 +7,7 @@ from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestP
__all__ = [
"FamilyRelationshipsConfig",
"FamilyRelationshipsDataset",
"FamilyRelationshipsCurriculum",
"QuantumLockConfig",
"QuantumLockDataset",
"QuantumLockCurriculum",

View file

@ -4,6 +4,7 @@ from enum import StrEnum
from itertools import count
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -183,9 +184,9 @@ class FamilyRelationshipsDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict:
rng = random.Random(self.seed + idx)
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
# Generate family tree
family = self._generate_family(rng)
family = self._generate_family(rng, family_size)
# Select two people and their relationship
person1, person2, relationship = self._get_relationship_question(rng, family)
@ -204,12 +205,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
"person2": person2.name,
"relationship": relationship.value,
"family_size": len(family),
"difficulty": {
"family_size": len(family),
},
},
}
def _generate_family(self, rng: random.Random) -> set[Person]:
def _generate_family(self, rng: random.Random, family_size: int) -> set[Person]:
"""Generate a random family tree"""
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
family = set()
used_names = set()
@ -369,4 +372,23 @@ class FamilyRelationshipsDataset(ProceduralDataset):
return reward
register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig)
class FamilyRelationshipsCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(FamilyRelationshipsCurriculum.__name__, FamilyRelationshipsConfig)
self._define_attributes(
RangeAttributeDefinition(
name="family_size",
description="The size of the family",
min_value=3,
attr_type=AttributeType.APPEND,
default_level=0,
levels=list(range(3, 12)),
lower_field_name="min_family_size",
upper_field_name="max_family_size",
)
)
register_dataset(
"family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig, FamilyRelationshipsCurriculum
)