mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
added family relatinship curriculum (#359)
This commit is contained in:
parent
fa7d8e66b3
commit
80ac138455
3 changed files with 78 additions and 7 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
|
||||
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
|
||||
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsCurriculum, FamilyRelationshipsDataset
|
||||
from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
|
||||
from .quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset
|
||||
from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset
|
||||
|
|
@ -7,6 +7,7 @@ from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestP
|
|||
__all__ = [
|
||||
"FamilyRelationshipsConfig",
|
||||
"FamilyRelationshipsDataset",
|
||||
"FamilyRelationshipsCurriculum",
|
||||
"QuantumLockConfig",
|
||||
"QuantumLockDataset",
|
||||
"QuantumLockCurriculum",
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from enum import StrEnum
|
|||
from itertools import count
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -183,9 +184,9 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
||||
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
|
||||
# Generate family tree
|
||||
family = self._generate_family(rng)
|
||||
family = self._generate_family(rng, family_size)
|
||||
|
||||
# Select two people and their relationship
|
||||
person1, person2, relationship = self._get_relationship_question(rng, family)
|
||||
|
|
@ -204,12 +205,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
"person2": person2.name,
|
||||
"relationship": relationship.value,
|
||||
"family_size": len(family),
|
||||
"difficulty": {
|
||||
"family_size": len(family),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _generate_family(self, rng: random.Random) -> set[Person]:
|
||||
def _generate_family(self, rng: random.Random, family_size: int) -> set[Person]:
|
||||
"""Generate a random family tree"""
|
||||
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
|
||||
family = set()
|
||||
used_names = set()
|
||||
|
||||
|
|
@ -369,4 +372,23 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
return reward
|
||||
|
||||
|
||||
register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig)
|
||||
class FamilyRelationshipsCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(FamilyRelationshipsCurriculum.__name__, FamilyRelationshipsConfig)
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="family_size",
|
||||
description="The size of the family",
|
||||
min_value=3,
|
||||
attr_type=AttributeType.APPEND,
|
||||
default_level=0,
|
||||
levels=list(range(3, 12)),
|
||||
lower_field_name="min_family_size",
|
||||
upper_field_name="max_family_size",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig, FamilyRelationshipsCurriculum
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue