diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 81e81583..9c686771 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -233,18 +233,29 @@ class FamilyRelationshipsDataset(ProceduralDataset): # Create ID counter id_counter = count() - # Create grandparents generation - grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter)) - grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter)) - grandfather.add_spouse(grandmother) - family.update([grandfather, grandmother]) + # Create paternal grandparents generation + grandfather_of_father = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter)) + grandmother_of_father = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter)) + grandfather_of_father.add_spouse(grandmother_of_father) + family.update([grandfather_of_father, grandmother_of_father]) + + # Create maternal grandparents generation + grandfather_of_mother = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter)) + grandmother_of_mother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter)) + grandfather_of_mother.add_spouse(grandmother_of_mother) + family.update([grandfather_of_mother, grandmother_of_mother]) # Create parents father = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter)) mother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter)) father.add_spouse(mother) - grandfather.add_child(father) - grandmother.add_child(father) + + # Link parents to their respective parents + grandfather_of_father.add_child(father) + grandmother_of_father.add_child(father) + grandfather_of_mother.add_child(mother) + grandmother_of_mother.add_child(mother) + family.update([father, mother]) # Add children