feat: Add family relationships graph module to reasoning_gym

This commit is contained in:
Andreas Koepf 2025-01-24 16:55:52 +01:00 committed by Andreas Koepf (aider)
parent a33b9744ea
commit e2e4e633be

View file

@ -0,0 +1,242 @@
import random
from dataclasses import dataclass
from typing import Optional, Dict, List, Set, Tuple
from enum import Enum
from ..dataset import ProceduralDataset
class Gender(Enum):
    """Gender of a family member; the value is the lowercase English word."""

    MALE = "male"
    FEMALE = "female"
class Relationship(Enum):
    """Relationship labels; each value is the capitalised answer string."""

    # Parents
    MOTHER = "Mother"
    FATHER = "Father"

    # Siblings
    SISTER = "Sister"
    BROTHER = "Brother"

    # Children
    DAUGHTER = "Daughter"
    SON = "Son"

    # Spouses
    WIFE = "Wife"
    HUSBAND = "Husband"

    # Grandparents
    GRANDMOTHER = "Grandmother"
    GRANDFATHER = "Grandfather"
# eq=False keeps the default identity-based __eq__/__hash__. A plain
# @dataclass would generate a field-wise __eq__ (which walks the cyclic
# spouse/parents/children links) and set __hash__ to None, making Person
# unusable as a set member or dict key.
@dataclass(eq=False)
class Person:
    """A member of a family graph.

    Two Person objects are equal only if they are the same object, so
    instances are hashable and can be stored in sets.
    """

    name: str
    gender: "Gender"
    spouse: Optional["Person"] = None
    # None means "not supplied"; __post_init__ replaces it with a fresh list
    # so that instances never share a mutable default.
    parents: List["Person"] = None
    children: List["Person"] = None

    def __post_init__(self):
        """Replace missing parent/child lists with new empty lists."""
        if self.parents is None:
            self.parents = []
        if self.children is None:
            self.children = []

    def add_child(self, child: "Person"):
        """Link ``child`` to this person in both directions, without duplicates."""
        if child not in self.children:
            self.children.append(child)
        if self not in child.parents:
            child.parents.append(self)

    def add_spouse(self, spouse: "Person"):
        """Marry this person to ``spouse``, keeping the link symmetric.

        If either partner was previously linked to somebody else, that former
        partner's back-reference is cleared so nobody keeps pointing at a
        person who is now married to someone else.
        """
        for partner in (self, spouse):
            former = partner.spouse
            if (
                former is not None
                and former is not self
                and former is not spouse
                and former.spouse is partner
            ):
                former.spouse = None
        self.spouse = spouse
        spouse.spouse = self
@dataclass
class FamilyRelationshipsConfig:
    """Configuration for family relationship task generation"""

    # Inclusive bounds for the randomly drawn number of family members.
    min_family_size: int = 4
    max_family_size: int = 8
    # Name pools; None (or an empty list) selects the built-in defaults.
    male_names: List[str] = None
    female_names: List[str] = None
    # Passed through to the dataset base class together with ``size``.
    seed: Optional[int] = None
    size: int = 500

    def __post_init__(self):
        """Fill in the default name lists when none were provided."""
        self.male_names = self.male_names or [
            "James", "John", "Robert", "Michael", "William", "David", "Richard",
            "Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew",
        ]
        self.female_names = self.female_names or [
            "Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
            "Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna",
        ]

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if a size bound is inconsistent or a name pool is
                too small to name the family founders.
        """
        assert self.min_family_size >= 3, "min_family_size must be at least 3"
        assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size"
        # Every generated family contains a grandfather and a father as well
        # as a grandmother and a mother, all with distinct names, so a pool
        # with a single name would leave one of them unnamed.
        assert len(self.male_names) >= 2, "must provide at least two male names"
        assert len(self.female_names) >= 2, "must provide at least two female names"
class FamilyRelationshipsDataset(ProceduralDataset):
    """Generates family relationship reasoning tasks.

    Each item is a short story describing a randomly generated family
    (grandparents, parents and their children) followed by a question asking
    how one family member is related to another.
    """

    def __init__(self, config: FamilyRelationshipsConfig):
        """Validate ``config`` and set up the question templates."""
        self.config = config
        self.config.validate()
        self._templates = [
            "What is {person1} to {person2}?",
            "How is {person1} related to {person2}?",
            "What relation is {person1} to {person2}?",
        ]
        super().__init__(seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Build the task at position ``idx``.

        Returns:
            A dict with the keys ``question`` (story plus question),
            ``answer`` (the relationship label) and ``metadata``.
        """
        # One generator per item, derived from the dataset seed and the index.
        # NOTE(review): self.seed is presumably set by ProceduralDataset - confirm.
        rng = random.Random(self.seed + idx)

        family = self._generate_family(rng)
        person1, person2, relationship = self._get_relationship_question(rng, family)
        story = self._generate_story(family)

        question = rng.choice(self._templates).format(
            person1=person1.name,
            person2=person2.name,
        )

        return {
            "question": f"{story}\n\n{question}",
            "answer": relationship.value,
            "metadata": {
                "person1": person1.name,
                "person2": person2.name,
                "relationship": relationship.value,
                "family_size": len(family),
            },
        }

    def _generate_family(self, rng: random.Random) -> List[Person]:
        """Generate a random three-generation family.

        The result is a list in creation order (grandfather, grandmother,
        father, mother, then the children). A list is used instead of a set
        so that the order, and therefore everything derived from it with
        ``rng``, does not depend on object hashes.

        Raises:
            ValueError: if the configured name lists cannot name the two
                founding couples.
        """
        family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
        family: List[Person] = []
        used_names: Set[str] = set()

        def get_name(gender: Gender) -> Optional[str]:
            """Draw an unused name for ``gender``; None when the pool is exhausted."""
            names = self.config.male_names if gender == Gender.MALE else self.config.female_names
            available = [n for n in names if n not in used_names]
            if not available:
                return None
            name = rng.choice(available)
            used_names.add(name)
            return name

        def new_founder(gender: Gender) -> Person:
            """Create a grandparent or parent, who must always receive a name."""
            name = get_name(gender)
            if name is None:
                raise ValueError(
                    f"not enough {gender.value} names to create the grandparents and parents"
                )
            return Person(name, gender)

        # Grandparents generation
        grandfather = new_founder(Gender.MALE)
        grandmother = new_founder(Gender.FEMALE)
        grandfather.add_spouse(grandmother)
        family.extend([grandfather, grandmother])

        # Parents: the father is the grandparents' child, the mother marries in
        father = new_founder(Gender.MALE)
        mother = new_founder(Gender.FEMALE)
        father.add_spouse(mother)
        grandfather.add_child(father)
        grandmother.add_child(father)
        family.extend([father, mother])

        # Children fill the family up to the drawn size; generation stops
        # early if the name pool of the drawn gender is exhausted.
        while len(family) < family_size:
            gender = rng.choice([Gender.MALE, Gender.FEMALE])
            name = get_name(gender)
            if name is None:
                break
            child = Person(name, gender)
            father.add_child(child)
            mother.add_child(child)
            family.append(child)

        return family

    @staticmethod
    def _relationship_between(person1: Person, person2: Person) -> Optional[Relationship]:
        """Return what ``person1`` is to ``person2``, or None if not covered."""
        female = person1.gender == Gender.FEMALE

        if any(p is person1 for p in person2.parents):
            return Relationship.MOTHER if female else Relationship.FATHER
        if any(p is person2 for p in person1.parents):
            return Relationship.DAUGHTER if female else Relationship.SON
        if person1.spouse is person2:
            return Relationship.WIFE if female else Relationship.HUSBAND
        # Siblings share exactly the same (non-empty) set of parents.
        # Identity is compared so that no hashing of Person is required.
        if (
            person1.parents
            and person2.parents
            and {id(p) for p in person1.parents} == {id(p) for p in person2.parents}
        ):
            return Relationship.SISTER if female else Relationship.BROTHER
        if any(gp is person1 for parent in person2.parents for gp in parent.parents):
            return Relationship.GRANDMOTHER if female else Relationship.GRANDFATHER
        return None

    def _get_relationship_question(
        self, rng: random.Random, family: Set[Person]
    ) -> Tuple[Person, Person, Relationship]:
        """Select two family members and determine their relationship.

        All ordered pairs with a supported relationship are collected and one
        is drawn uniformly. This yields the same distribution as re-drawing
        random pairs until a supported one turns up, but always terminates.

        Raises:
            ValueError: if no pair in ``family`` has a supported relationship.
        """
        members = list(family)
        candidates = []
        for person1 in members:
            for person2 in members:
                if person1 is person2:
                    continue
                relationship = self._relationship_between(person1, person2)
                if relationship is not None:
                    candidates.append((person1, person2, relationship))

        if not candidates:
            raise ValueError("family contains no pair with a describable relationship")
        return rng.choice(candidates)

    @staticmethod
    def _join_names(names: List[str]) -> str:
        """Format names as 'A', 'A and B' or 'A, B and C'."""
        if len(names) == 1:
            return names[0]
        *first, last = names
        return ", ".join(first) + f" and {last}"

    def _generate_story(self, family: Set[Person]) -> str:
        """Generate a story describing the family relationships.

        Each couple is introduced once and its children are listed directly
        afterwards, so that "They" always refers to the couple just named.
        """
        story_parts = []
        described = set()  # ids of people whose marriage has been told already

        for person in family:
            if id(person) in described:
                continue
            spouse = person.spouse

            if spouse is None:
                # A parent without a partner is named explicitly.
                if person.children:
                    names = [c.name for c in person.children]
                    noun = "a child" if len(names) == 1 else "children"
                    story_parts.append(
                        f"{person.name} has {noun} called {self._join_names(names)}."
                    )
                continue

            described.update((id(person), id(spouse)))
            story_parts.append(f"{person.name} is married to {spouse.name}.")

            # Children of either partner, each listed once.
            children = list(person.children)
            children.extend(
                c for c in spouse.children if all(c is not known for known in person.children)
            )
            if not children:
                continue

            names = [c.name for c in children]
            if len(names) == 1:
                story_parts.append(f"They have a child called {names[0]}.")
            else:
                story_parts.append(f"They have children called {self._join_names(names)}.")

        return " ".join(story_parts)
def family_relationships_dataset(
    min_family_size: int = 4,
    max_family_size: int = 8,
    male_names: List[str] = None,
    female_names: List[str] = None,
    seed: Optional[int] = None,
    size: int = 500,
) -> FamilyRelationshipsDataset:
    """Build a FamilyRelationshipsDataset from individual configuration values.

    All arguments are forwarded unchanged to FamilyRelationshipsConfig, which
    is then handed to the dataset constructor.
    """
    settings = {
        "min_family_size": min_family_size,
        "max_family_size": max_family_size,
        "male_names": male_names,
        "female_names": female_names,
        "seed": seed,
        "size": size,
    }
    return FamilyRelationshipsDataset(FamilyRelationshipsConfig(**settings))