diff --git a/README.md b/README.md index e551dffc..80b00849 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,9 @@ The goal is to generate virtually infinite data with adjustable complexity. #### Logic Tasks - `PropositionalLogicDataset`: Generate propositional logic reasoning problems +#### Relationship Tasks +- `FamilyRelationshipsDataset`: Generate family relationship reasoning tasks with family trees + #### Game Tasks - `SudokuDataset`: Generate 9x9 Sudoku puzzles with configurable number of empty cells - `MiniSudokuDataset`: Generate 4x4 Mini Sudoku puzzles with configurable difficulty diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index 35065d18..0d3f0e9e 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -2,7 +2,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasoning models """ -from . import algorithmic, algebra, arithmetic, cognition, data, games, logic +from . import algorithmic, algebra, arithmetic, cognition, data, games, logic, relationships __version__ = "0.1.1" -__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "logic"] +__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "logic", "relationships"] diff --git a/reasoning_gym/relationships/__init__.py b/reasoning_gym/relationships/__init__.py new file mode 100644 index 00000000..bf643127 --- /dev/null +++ b/reasoning_gym/relationships/__init__.py @@ -0,0 +1,11 @@ +from .family_relationships import ( + FamilyRelationshipsDataset, + FamilyRelationshipsConfig, + family_relationships_dataset +) + +__all__ = [ + "FamilyRelationshipsDataset", + "FamilyRelationshipsConfig", + "family_relationships_dataset" +] diff --git a/reasoning_gym/relationships/family_relationships.py b/reasoning_gym/relationships/family_relationships.py new file mode 100644 index 00000000..a19d500d --- /dev/null +++ b/reasoning_gym/relationships/family_relationships.py @@ -0,0 +1,242 @@ +import random 
import random
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple

from ..dataset import ProceduralDataset


class Gender(Enum):
    """Gender of a family member; selects the name pool and the gendered answer."""

    MALE = "male"
    FEMALE = "female"


class Relationship(Enum):
    """Relationships that can be the answer to a question.

    The enum value is the exact answer string returned by the dataset.
    """

    MOTHER = "Mother"
    FATHER = "Father"
    SISTER = "Sister"
    BROTHER = "Brother"
    DAUGHTER = "Daughter"
    SON = "Son"
    WIFE = "Wife"
    HUSBAND = "Husband"
    GRANDMOTHER = "Grandmother"
    GRANDFATHER = "Grandfather"


# eq=False keeps object identity for __eq__/__hash__.  A plain @dataclass
# generates a field-based __eq__ and sets __hash__ to None, which makes
# Person unusable as a set member (the family tree is stored in a set).
@dataclass(eq=False)
class Person:
    """A member of a family tree, linked to spouse, parents and children."""

    name: str
    gender: Gender
    spouse: Optional["Person"] = None
    parents: Optional[List["Person"]] = None
    children: Optional[List["Person"]] = None

    def __post_init__(self):
        # Give every instance its own lists; a shared default list would
        # leak relatives between unrelated people.
        if self.parents is None:
            self.parents = []
        if self.children is None:
            self.children = []

    def add_child(self, child: "Person") -> None:
        """Link ``child`` to this person in both directions (idempotent)."""
        if child not in self.children:
            self.children.append(child)
        if self not in child.parents:
            child.parents.append(self)

    def add_spouse(self, spouse: "Person") -> None:
        """Marry this person and ``spouse``; both ``spouse`` fields are set."""
        self.spouse = spouse
        spouse.spouse = self


@dataclass
class FamilyRelationshipsConfig:
    """Configuration for family relationship task generation"""

    min_family_size: int = 4
    max_family_size: int = 8
    male_names: Optional[List[str]] = None
    female_names: Optional[List[str]] = None
    seed: Optional[int] = None
    size: int = 500

    def __post_init__(self):
        # Default name lists if none provided (an empty list also falls
        # back to the defaults).
        self.male_names = self.male_names or [
            "James", "John", "Robert", "Michael", "William", "David", "Richard",
            "Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew",
        ]
        self.female_names = self.female_names or [
            "Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
            "Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna",
        ]

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if a parameter is out of range.  Assertions are
                kept so the exception type of the existing checks does not
                change.
        """
        assert self.min_family_size >= 3, "min_family_size must be at least 3"
        assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size"
        # Every tree contains a grandfather and a father, and a grandmother
        # and a mother, so two distinct names per gender are the minimum.
        assert len(set(self.male_names)) >= 2, "must provide at least 2 distinct male names"
        assert len(set(self.female_names)) >= 2, "must provide at least 2 distinct female names"


class FamilyRelationshipsDataset(ProceduralDataset):
    """Generates family relationship reasoning tasks"""

    def __init__(self, config: FamilyRelationshipsConfig):
        self.config = config
        self.config.validate()
        self._templates = [
            "What is {person1} to {person2}?",
            "How is {person1} related to {person2}?",
            "What relation is {person1} to {person2}?",
        ]
        super().__init__(seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Build the task for index ``idx``.

        Returns:
            A dict with ``question`` (story followed by the question),
            ``answer`` (the relationship of person1 to person2) and
            ``metadata``.
        """
        # NOTE(review): relies on the base class exposing an integer
        # ``self.seed`` -- confirm against ProceduralDataset.
        rng = random.Random(self.seed + idx)

        # Generate family tree
        family = self._generate_family(rng)

        # Select two people and their relationship
        person1, person2, relationship = self._get_relationship_question(rng, family)

        # Generate story describing the family relationships
        story = self._generate_story(family)

        # Format question
        question = rng.choice(self._templates).format(
            person1=person1.name,
            person2=person2.name,
        )

        return {
            "question": f"{story}\n\n{question}",
            "answer": relationship.value,
            "metadata": {
                "person1": person1.name,
                "person2": person2.name,
                "relationship": relationship.value,
                "family_size": len(family),
            },
        }

    def _generate_family(self, rng: random.Random) -> Set[Person]:
        """Generate a random three-generation family tree.

        The tree always contains a grandparent couple, their son and his
        wife; children of the younger couple are added until the drawn
        family size is reached or the name pool of the drawn gender is
        exhausted.

        Raises:
            ValueError: if no unused name is left for one of the four
                founding members.
        """
        family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
        family: Set[Person] = set()
        used_names: Set[str] = set()

        def get_name(gender: Gender) -> Optional[str]:
            names = self.config.male_names if gender == Gender.MALE else self.config.female_names
            available = [n for n in names if n not in used_names]
            if not available:
                return None
            name = rng.choice(available)
            used_names.add(name)
            return name

        def new_founder(gender: Gender) -> Person:
            # Founders are mandatory, so running out of names is an error
            # rather than a reason to create a nameless person.
            name = get_name(gender)
            if name is None:
                raise ValueError(f"not enough unused {gender.value} names to build a family")
            return Person(name, gender)

        # Create grandparents generation
        grandfather = new_founder(Gender.MALE)
        grandmother = new_founder(Gender.FEMALE)
        grandfather.add_spouse(grandmother)
        family.update([grandfather, grandmother])

        # Create parents
        father = new_founder(Gender.MALE)
        mother = new_founder(Gender.FEMALE)
        father.add_spouse(mother)
        grandfather.add_child(father)
        grandmother.add_child(father)
        family.update([father, mother])

        # Add children
        while len(family) < family_size:
            gender = rng.choice([Gender.MALE, Gender.FEMALE])
            name = get_name(gender)
            if not name:
                break
            child = Person(name, gender)
            father.add_child(child)
            mother.add_child(child)
            family.add(child)

        return family

    @staticmethod
    def _relationship_between(person1: Person, person2: Person) -> Optional[Relationship]:
        """Return what ``person1`` is to ``person2``, or None if not covered."""
        female = person1.gender == Gender.FEMALE

        if person1 in person2.parents:
            return Relationship.MOTHER if female else Relationship.FATHER
        if person2 in person1.parents:
            return Relationship.DAUGHTER if female else Relationship.SON
        if person1.spouse is person2:
            return Relationship.WIFE if female else Relationship.HUSBAND
        if person1.parents and person2.parents and set(person1.parents) == set(person2.parents):
            return Relationship.SISTER if female else Relationship.BROTHER

        grandparents = [gp for parent in person2.parents for gp in parent.parents]
        if person1 in grandparents:
            return Relationship.GRANDMOTHER if female else Relationship.GRANDFATHER
        return None

    def _get_relationship_question(
        self, rng: random.Random, family: Set[Person]
    ) -> Tuple[Person, Person, Relationship]:
        """Select two family members and determine their relationship.

        Every ordered pair with a supported relationship is equally likely.

        Raises:
            ValueError: if no pair in ``family`` has a supported relationship.
        """
        # Sets of identity-hashed objects have no reproducible order; sort by
        # name so the same rng state always yields the same pair.
        members = sorted(family, key=lambda person: person.name)

        candidates = []
        for person1 in members:
            for person2 in members:
                if person1 is person2:
                    continue
                relationship = self._relationship_between(person1, person2)
                if relationship is not None:
                    candidates.append((person1, person2, relationship))

        if not candidates:
            raise ValueError("family contains no pair with a describable relationship")
        return rng.choice(candidates)

    @staticmethod
    def _children_sentence(subject: str, children: List[Person]) -> str:
        """Return e.g. 'They have children called A, B and C.' for ``children``."""
        names = [child.name for child in children]
        if len(names) == 1:
            return f"{subject} a child called {names[0]}."
        *first, last = names
        return f"{subject} children called {', '.join(first)} and {last}."

    def _generate_story(self, family: Set[Person]) -> str:
        """Generate a story describing the family relationships.

        Each couple is introduced once, immediately followed by the sentence
        listing their children, so that "They" refers to that couple.
        """
        story_parts = []
        described: Set[Person] = set()

        # Name order keeps the story reproducible for a given family.
        for person in sorted(family, key=lambda member: member.name):
            if person in described:
                continue
            described.add(person)

            partner = person.spouse
            if partner is None:
                # Not produced by _generate_family, but keep the information
                # if a caller passes such a tree.
                if person.children:
                    story_parts.append(self._children_sentence(f"{person.name} has", person.children))
                continue

            described.add(partner)
            story_parts.append(f"{person.name} is married to {partner.name}.")

            # Children of either partner, listed once and in insertion order.
            children = list(person.children)
            children.extend(c for c in partner.children if c not in children)
            if children:
                story_parts.append(self._children_sentence("They have", children))

        return " ".join(story_parts)


def family_relationships_dataset(
    min_family_size: int = 4,
    max_family_size: int = 8,
    male_names: Optional[List[str]] = None,
    female_names: Optional[List[str]] = None,
    seed: Optional[int] = None,
    size: int = 500,
) -> FamilyRelationshipsDataset:
    """Create a FamilyRelationshipsDataset with the given configuration"""
    config = FamilyRelationshipsConfig(
        min_family_size=min_family_size,
        max_family_size=max_family_size,
        male_names=male_names,
        female_names=female_names,
        seed=seed,
        size=size,
    )
    return FamilyRelationshipsDataset(config)