mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
feat: Add family relationships dataset with configurable family trees
This commit is contained in:
parent
51cf6bdd45
commit
6e1845320b
4 changed files with 258 additions and 2 deletions
|
|
@ -31,6 +31,9 @@ The goal is to generate virtually infinite data with adjustable complexity.
|
||||||
#### Logic Tasks
|
#### Logic Tasks
|
||||||
- `PropositionalLogicDataset`: Generate propositional logic reasoning problems
|
- `PropositionalLogicDataset`: Generate propositional logic reasoning problems
|
||||||
|
|
||||||
|
#### Relationship Tasks
|
||||||
|
- `FamilyRelationshipsDataset`: Generate family relationship reasoning tasks with family trees
|
||||||
|
|
||||||
#### Game Tasks
|
#### Game Tasks
|
||||||
- `SudokuDataset`: Generate 9x9 Sudoku puzzles with configurable number of empty cells
|
- `SudokuDataset`: Generate 9x9 Sudoku puzzles with configurable number of empty cells
|
||||||
- `MiniSudokuDataset`: Generate 4x4 Mini Sudoku puzzles with configurable difficulty
|
- `MiniSudokuDataset`: Generate 4x4 Mini Sudoku puzzles with configurable difficulty
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
Reasoning Gym - A library of procedural dataset generators for training reasoning models
|
Reasoning Gym - A library of procedural dataset generators for training reasoning models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from . import algorithmic, algebra, arithmetic, cognition, data, games, logic
|
from . import algorithmic, algebra, arithmetic, cognition, data, games, logic, relationships
|
||||||
|
|
||||||
__version__ = "0.1.1"
|
__version__ = "0.1.1"
|
||||||
__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "logic"]
|
__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "logic", "relationships"]
|
||||||
|
|
|
||||||
11
reasoning_gym/relationships/__init__.py
Normal file
11
reasoning_gym/relationships/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
from .family_relationships import (
|
||||||
|
FamilyRelationshipsDataset,
|
||||||
|
FamilyRelationshipsConfig,
|
||||||
|
family_relationships_dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FamilyRelationshipsDataset",
|
||||||
|
"FamilyRelationshipsConfig",
|
||||||
|
"family_relationships_dataset"
|
||||||
|
]
|
||||||
242
reasoning_gym/relationships/family_relationships.py
Normal file
242
reasoning_gym/relationships/family_relationships.py
Normal file
|
|
@ -0,0 +1,242 @@
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Dict, List, Set, Tuple
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
|
class Gender(Enum):
|
||||||
|
MALE = "male"
|
||||||
|
FEMALE = "female"
|
||||||
|
|
||||||
|
|
||||||
|
class Relationship(Enum):
|
||||||
|
MOTHER = "Mother"
|
||||||
|
FATHER = "Father"
|
||||||
|
SISTER = "Sister"
|
||||||
|
BROTHER = "Brother"
|
||||||
|
DAUGHTER = "Daughter"
|
||||||
|
SON = "Son"
|
||||||
|
WIFE = "Wife"
|
||||||
|
HUSBAND = "Husband"
|
||||||
|
GRANDMOTHER = "Grandmother"
|
||||||
|
GRANDFATHER = "Grandfather"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Person:
|
||||||
|
name: str
|
||||||
|
gender: Gender
|
||||||
|
spouse: Optional['Person'] = None
|
||||||
|
parents: List['Person'] = None
|
||||||
|
children: List['Person'] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.parents = self.parents or []
|
||||||
|
self.children = self.children or []
|
||||||
|
|
||||||
|
def add_child(self, child: 'Person'):
|
||||||
|
if child not in self.children:
|
||||||
|
self.children.append(child)
|
||||||
|
if self not in child.parents:
|
||||||
|
child.parents.append(self)
|
||||||
|
|
||||||
|
def add_spouse(self, spouse: 'Person'):
|
||||||
|
self.spouse = spouse
|
||||||
|
spouse.spouse = self
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FamilyRelationshipsConfig:
|
||||||
|
"""Configuration for family relationship task generation"""
|
||||||
|
min_family_size: int = 4
|
||||||
|
max_family_size: int = 8
|
||||||
|
male_names: List[str] = None
|
||||||
|
female_names: List[str] = None
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Default name lists if none provided
|
||||||
|
self.male_names = self.male_names or [
|
||||||
|
"James", "John", "Robert", "Michael", "William", "David", "Richard",
|
||||||
|
"Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew"
|
||||||
|
]
|
||||||
|
self.female_names = self.female_names or [
|
||||||
|
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
|
||||||
|
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna"
|
||||||
|
]
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert self.min_family_size >= 3, "min_family_size must be at least 3"
|
||||||
|
assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size"
|
||||||
|
assert len(self.male_names) > 0, "must provide male names"
|
||||||
|
assert len(self.female_names) > 0, "must provide female names"
|
||||||
|
|
||||||
|
|
||||||
|
class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
|
"""Generates family relationship reasoning tasks"""
|
||||||
|
|
||||||
|
def __init__(self, config: FamilyRelationshipsConfig):
|
||||||
|
self.config = config
|
||||||
|
self.config.validate()
|
||||||
|
self._templates = [
|
||||||
|
"What is {person1} to {person2}?",
|
||||||
|
"How is {person1} related to {person2}?",
|
||||||
|
"What relation is {person1} to {person2}?",
|
||||||
|
]
|
||||||
|
super().__init__(seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
|
# Generate family tree
|
||||||
|
family = self._generate_family(rng)
|
||||||
|
|
||||||
|
# Select two people and their relationship
|
||||||
|
person1, person2, relationship = self._get_relationship_question(rng, family)
|
||||||
|
|
||||||
|
# Generate story describing the family relationships
|
||||||
|
story = self._generate_story(family)
|
||||||
|
|
||||||
|
# Format question
|
||||||
|
question = rng.choice(self._templates).format(
|
||||||
|
person1=person1.name,
|
||||||
|
person2=person2.name
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": f"{story}\n\n{question}",
|
||||||
|
"answer": relationship.value,
|
||||||
|
"metadata": {
|
||||||
|
"person1": person1.name,
|
||||||
|
"person2": person2.name,
|
||||||
|
"relationship": relationship.value,
|
||||||
|
"family_size": len(family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_family(self, rng: random.Random) -> Set[Person]:
|
||||||
|
"""Generate a random family tree"""
|
||||||
|
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
|
||||||
|
family = set()
|
||||||
|
used_names = set()
|
||||||
|
|
||||||
|
def get_name(gender: Gender) -> str:
|
||||||
|
names = (self.config.male_names if gender == Gender.MALE
|
||||||
|
else self.config.female_names)
|
||||||
|
available = [n for n in names if n not in used_names]
|
||||||
|
if not available:
|
||||||
|
return None
|
||||||
|
name = rng.choice(available)
|
||||||
|
used_names.add(name)
|
||||||
|
return name
|
||||||
|
|
||||||
|
# Create grandparents generation
|
||||||
|
grandfather = Person(get_name(Gender.MALE), Gender.MALE)
|
||||||
|
grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE)
|
||||||
|
grandfather.add_spouse(grandmother)
|
||||||
|
family.update([grandfather, grandmother])
|
||||||
|
|
||||||
|
# Create parents
|
||||||
|
father = Person(get_name(Gender.MALE), Gender.MALE)
|
||||||
|
mother = Person(get_name(Gender.FEMALE), Gender.FEMALE)
|
||||||
|
father.add_spouse(mother)
|
||||||
|
grandfather.add_child(father)
|
||||||
|
grandmother.add_child(father)
|
||||||
|
family.update([father, mother])
|
||||||
|
|
||||||
|
# Add children
|
||||||
|
while len(family) < family_size:
|
||||||
|
gender = rng.choice([Gender.MALE, Gender.FEMALE])
|
||||||
|
name = get_name(gender)
|
||||||
|
if not name:
|
||||||
|
break
|
||||||
|
child = Person(name, gender)
|
||||||
|
father.add_child(child)
|
||||||
|
mother.add_child(child)
|
||||||
|
family.add(child)
|
||||||
|
|
||||||
|
return family
|
||||||
|
|
||||||
|
def _get_relationship_question(
|
||||||
|
self, rng: random.Random, family: Set[Person]
|
||||||
|
) -> Tuple[Person, Person, Relationship]:
|
||||||
|
"""Select two family members and determine their relationship"""
|
||||||
|
person1, person2 = rng.sample(list(family), 2)
|
||||||
|
|
||||||
|
# Determine relationship
|
||||||
|
if person1 in person2.parents:
|
||||||
|
relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE
|
||||||
|
else Relationship.FATHER)
|
||||||
|
elif person2 in person1.parents:
|
||||||
|
relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE
|
||||||
|
else Relationship.SON)
|
||||||
|
elif person1.spouse == person2:
|
||||||
|
relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE
|
||||||
|
else Relationship.HUSBAND)
|
||||||
|
elif (person1.parents and person2.parents and
|
||||||
|
set(person1.parents) == set(person2.parents)):
|
||||||
|
relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE
|
||||||
|
else Relationship.BROTHER)
|
||||||
|
elif (person1 in [p for parent in person2.parents for p in parent.parents]):
|
||||||
|
relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE
|
||||||
|
else Relationship.GRANDFATHER)
|
||||||
|
else:
|
||||||
|
# Try again with different people
|
||||||
|
return self._get_relationship_question(rng, family)
|
||||||
|
|
||||||
|
return person1, person2, relationship
|
||||||
|
|
||||||
|
def _generate_story(self, family: Set[Person]) -> str:
|
||||||
|
"""Generate a story describing the family relationships"""
|
||||||
|
story_parts = []
|
||||||
|
|
||||||
|
# Find married couples
|
||||||
|
couples = set()
|
||||||
|
for person in family:
|
||||||
|
if person.spouse and (person.spouse, person) not in couples:
|
||||||
|
couples.add((person, person.spouse))
|
||||||
|
|
||||||
|
# Describe marriages
|
||||||
|
for person1, person2 in couples:
|
||||||
|
story_parts.append(f"{person1.name} is married to {person2.name}.")
|
||||||
|
|
||||||
|
# Describe parent-child relationships
|
||||||
|
for person in family:
|
||||||
|
if person.children:
|
||||||
|
children_names = [c.name for c in person.children]
|
||||||
|
if len(children_names) == 1:
|
||||||
|
story_parts.append(
|
||||||
|
f"They have a child called {children_names[0]}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
*first, last = children_names
|
||||||
|
children_str = ", ".join(first) + f" and {last}"
|
||||||
|
story_parts.append(
|
||||||
|
f"They have children called {children_str}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return " ".join(story_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def family_relationships_dataset(
|
||||||
|
min_family_size: int = 4,
|
||||||
|
max_family_size: int = 8,
|
||||||
|
male_names: List[str] = None,
|
||||||
|
female_names: List[str] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
size: int = 500,
|
||||||
|
) -> FamilyRelationshipsDataset:
|
||||||
|
"""Create a FamilyRelationshipsDataset with the given configuration"""
|
||||||
|
config = FamilyRelationshipsConfig(
|
||||||
|
min_family_size=min_family_size,
|
||||||
|
max_family_size=max_family_size,
|
||||||
|
male_names=male_names,
|
||||||
|
female_names=female_names,
|
||||||
|
seed=seed,
|
||||||
|
size=size,
|
||||||
|
)
|
||||||
|
return FamilyRelationshipsDataset(config)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue