diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 828f538e..155242a6 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -272,15 +272,18 @@ class FamilyRelationshipsDataset(ProceduralDataset): family.update([father, mother, uncle, aunt_by_marriage, aunt, uncle_by_marriage]) - # Add children + # Add children, randomly assigned to couples + couples = [(father, mother), (uncle, aunt_by_marriage), (aunt, uncle_by_marriage)] while len(family) < family_size: gender = rng.choice([Gender.MALE, Gender.FEMALE]) name = get_name(gender) if not name: break child = Person(name, gender, next(id_counter)) - father.add_child(child) - mother.add_child(child) + # Randomly choose parents for this child + parents = rng.choice(couples) + parents[0].add_child(child) # Add to father/uncle/aunt + parents[1].add_child(child) # Add to mother/aunt_by_marriage/uncle_by_marriage family.add(child) return family