formatting, cleanup

This commit is contained in:
Andreas Koepf 2025-01-24 17:12:42 +01:00
parent b767e58e48
commit 3dc80be7d2
12 changed files with 189 additions and 376 deletions

View file

@ -2,7 +2,7 @@
Reasoning Gym - A library of procedural dataset generators for training reasoning models Reasoning Gym - A library of procedural dataset generators for training reasoning models
""" """
from . import algorithmic, algebra, arithmetic, cognition, data, games, graphs, logic from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic
__version__ = "0.1.1" __version__ = "0.1.1"
__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"] __all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"]

View file

@ -1,3 +1,3 @@
from .simple_equations import SimpleEquationsDataset, SimpleEquationsConfig, simple_equations_dataset from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset, simple_equations_dataset
__all__ = ["SimpleEquationsDataset", "SimpleEquationsConfig", "simple_equations_dataset"] __all__ = ["SimpleEquationsDataset", "SimpleEquationsConfig", "simple_equations_dataset"]

View file

@ -1,10 +1,10 @@
import random import random
import string
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
import string
import sympy import sympy
from sympy import Symbol, solve, Eq from sympy import Eq, Symbol, solve
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
@ -12,11 +12,12 @@ from ..dataset import ProceduralDataset
@dataclass @dataclass
class SimpleEquationsConfig: class SimpleEquationsConfig:
"""Configuration for simple equation task generation""" """Configuration for simple equation task generation"""
min_terms: int = 2 # Minimum number of terms in expression min_terms: int = 2 # Minimum number of terms in expression
max_terms: int = 4 # Maximum number of terms max_terms: int = 4 # Maximum number of terms
min_value: int = 1 # Minimum value for constants min_value: int = 1 # Minimum value for constants
max_value: int = 100 # Maximum value for constants max_value: int = 100 # Maximum value for constants
operators: tuple = ('+', '-', '*') # Allowed operators operators: tuple = ("+", "-", "*") # Allowed operators
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
@ -44,7 +45,7 @@ class SimpleEquationsDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single equation task """Generate a single equation task
Returns: Returns:
dict with keys: dict with keys:
- question: str, the equation to solve (e.g. "3 * x = 12") - question: str, the equation to solve (e.g. "3 * x = 12")
@ -52,18 +53,18 @@ class SimpleEquationsDataset(ProceduralDataset):
- metadata: dict with generation parameters - metadata: dict with generation parameters
""" """
rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
# Get variable and generate equation # Get variable and generate equation
variable = self._get_variable(rng) variable = self._get_variable(rng)
equation, solution = self._generate_equation(rng, variable) equation, solution = self._generate_equation(rng, variable)
return { return {
"question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation), "question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation),
"answer": str(solution), "answer": str(solution),
"metadata": { "metadata": {
"equation": equation, "equation": equation,
"variable": variable, "variable": variable,
} },
} }
def _get_variable(self, rng: random.Random) -> str: def _get_variable(self, rng: random.Random) -> str:
@ -72,60 +73,60 @@ class SimpleEquationsDataset(ProceduralDataset):
def _generate_equation(self, rng: random.Random, variable: str) -> Tuple[str, int]: def _generate_equation(self, rng: random.Random, variable: str) -> Tuple[str, int]:
"""Generate an equation and its solution """Generate an equation and its solution
Args: Args:
rng: Random number generator rng: Random number generator
variable: Variable symbol to use in equation variable: Variable symbol to use in equation
Returns: Returns:
Tuple of (equation string, solution integer) Tuple of (equation string, solution integer)
""" """
x = Symbol(variable) x = Symbol(variable)
# Generate terms for left side # Generate terms for left side
num_terms = rng.randint(self.config.min_terms, self.config.max_terms) num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
terms = [] terms = []
# Generate all constant terms first # Generate all constant terms first
for _ in range(num_terms): for _ in range(num_terms):
value = rng.randint(self.config.min_value, self.config.max_value) value = rng.randint(self.config.min_value, self.config.max_value)
terms.append(value) terms.append(value)
# Replace one random term with the variable term # Replace one random term with the variable term
var_pos = rng.randint(0, num_terms - 1) var_pos = rng.randint(0, num_terms - 1)
coef = rng.randint(self.config.min_value, self.config.max_value) coef = rng.randint(self.config.min_value, self.config.max_value)
terms[var_pos] = coef * x terms[var_pos] = coef * x
# Apply operators between terms # Apply operators between terms
expr = terms[0] expr = terms[0]
for i in range(1, num_terms): for i in range(1, num_terms):
op = rng.choice(self.config.operators) op = rng.choice(self.config.operators)
if op == '+': if op == "+":
expr = expr + terms[i] expr = expr + terms[i]
elif op == '-': elif op == "-":
expr = expr - terms[i] expr = expr - terms[i]
else: # '*' else: # '*'
expr = expr * terms[i] expr = expr * terms[i]
left_side = expr left_side = expr
# Generate right side # Generate right side
right_side = rng.randint(self.config.min_value, self.config.max_value) right_side = rng.randint(self.config.min_value, self.config.max_value)
# Create equation # Create equation
equation = Eq(left_side, right_side) equation = Eq(left_side, right_side)
solutions = solve(equation, x) solutions = solve(equation, x)
# Check if we found any solutions # Check if we found any solutions
if not solutions: if not solutions:
return self._generate_equation(rng, variable) return self._generate_equation(rng, variable)
solution = solutions[0] solution = solutions[0]
# Only return if solution is a positive integer # Only return if solution is a positive integer
if not (isinstance(solution, sympy.Integer) and solution > 0): if not (isinstance(solution, sympy.Integer) and solution > 0):
return self._generate_equation(rng, variable) return self._generate_equation(rng, variable)
return f"{left_side} = {right_side}", int(solution) return f"{left_side} = {right_side}", int(solution)
@ -134,7 +135,7 @@ def simple_equations_dataset(
max_terms: int = 5, max_terms: int = 5,
min_value: int = 1, min_value: int = 1,
max_value: int = 100, max_value: int = 100,
operators: tuple = ('+', '-', '*'), operators: tuple = ("+", "-", "*"),
seed: Optional[int] = None, seed: Optional[int] = None,
size: int = 500, size: int = 500,
) -> SimpleEquationsDataset: ) -> SimpleEquationsDataset:

View file

@ -1,50 +1,44 @@
import pytest import pytest
from .simple_equations import simple_equations_dataset from .simple_equations import simple_equations_dataset
def test_simple_equations_generation(): def test_simple_equations_generation():
dataset = simple_equations_dataset(seed=42, size=10) dataset = simple_equations_dataset(seed=42, size=10)
for item in dataset: for item in dataset:
# Check required keys exist # Check required keys exist
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Validate answer is a string of digits # Validate answer is a string of digits
assert item["answer"].isdigit() assert item["answer"].isdigit()
# Validate equation format # Validate equation format
equation = item["metadata"]["equation"] equation = item["metadata"]["equation"]
variable = item["metadata"]["variable"] variable = item["metadata"]["variable"]
assert "=" in equation assert "=" in equation
assert variable in equation assert variable in equation
# Validate question format # Validate question format
question = item["question"] question = item["question"]
assert variable in question assert variable in question
assert equation in question assert equation in question
assert any( assert any(prompt in question for prompt in ["Find the value of", "Solve for", "Determine the value of"])
prompt in question
for prompt in [
"Find the value of",
"Solve for",
"Determine the value of"
]
)
def test_simple_equations_config(): def test_simple_equations_config():
# Test invalid config raises assertion # Test invalid config raises assertion
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = simple_equations_dataset(min_terms=0) dataset = simple_equations_dataset(min_terms=0)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = simple_equations_dataset(max_terms=1, min_terms=2) dataset = simple_equations_dataset(max_terms=1, min_terms=2)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = simple_equations_dataset(min_value=0) dataset = simple_equations_dataset(min_value=0)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = simple_equations_dataset(operators=()) dataset = simple_equations_dataset(operators=())
@ -52,7 +46,7 @@ def test_simple_equations_config():
def test_deterministic_generation(): def test_deterministic_generation():
dataset1 = simple_equations_dataset(seed=42, size=5) dataset1 = simple_equations_dataset(seed=42, size=5)
dataset2 = simple_equations_dataset(seed=42, size=5) dataset2 = simple_equations_dataset(seed=42, size=5)
for i in range(5): for i in range(5):
assert dataset1[i]["question"] == dataset2[i]["question"] assert dataset1[i]["question"] == dataset2[i]["question"]
assert dataset1[i]["answer"] == dataset2[i]["answer"] assert dataset1[i]["answer"] == dataset2[i]["answer"]

View file

@ -5,10 +5,10 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional from typing import List, Optional
from ..dataset import ProceduralDataset
from reasoning_gym.data import read_data_file from reasoning_gym.data import read_data_file
from ..dataset import ProceduralDataset
@dataclass @dataclass
class LetterCountingConfig: class LetterCountingConfig:

View file

@ -39,7 +39,7 @@ class GCDDataset(ProceduralDataset):
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
"""Generate a list of random positive integers and their GCD. """Generate a list of random positive integers and their GCD.
Will try up to 3 times to find numbers with GCD > 1.""" Will try up to 3 times to find numbers with GCD > 1."""
# Try up to 3 times to get GCD > 1 # Try up to 3 times to get GCD > 1
for _ in range(3): for _ in range(3):
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
@ -47,7 +47,7 @@ class GCDDataset(ProceduralDataset):
result = reduce(gcd, numbers) result = reduce(gcd, numbers)
if result > 1: if result > 1:
break break
# Return the last generated numbers, whether they met the criteria or not # Return the last generated numbers, whether they met the criteria or not
return numbers, result return numbers, result

View file

@ -50,7 +50,7 @@ class LCMDataset(ProceduralDataset):
result = reduce(lcm, numbers) result = reduce(lcm, numbers)
if result < calculate_product(numbers): if result < calculate_product(numbers):
break break
# Return the last generated numbers, whether they met the criteria or not # Return the last generated numbers, whether they met the criteria or not
return numbers, result return numbers, result

View file

@ -1,8 +1,8 @@
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Dict, List, Set, Tuple
from enum import Enum from enum import Enum
from itertools import count from itertools import count
from typing import Dict, List, Optional, Set, Tuple
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
@ -30,9 +30,9 @@ class Person:
name: str name: str
gender: Gender gender: Gender
id: int id: int
spouse: Optional['Person'] = None spouse: Optional["Person"] = None
parents: List['Person'] = None parents: List["Person"] = None
children: List['Person'] = None children: List["Person"] = None
def __post_init__(self): def __post_init__(self):
self.parents = self.parents or [] self.parents = self.parents or []
@ -46,13 +46,13 @@ class Person:
return False return False
return self.id == other.id return self.id == other.id
def add_child(self, child: 'Person'): def add_child(self, child: "Person"):
if child not in self.children: if child not in self.children:
self.children.append(child) self.children.append(child)
if self not in child.parents: if self not in child.parents:
child.parents.append(self) child.parents.append(self)
def add_spouse(self, spouse: 'Person'): def add_spouse(self, spouse: "Person"):
self.spouse = spouse self.spouse = spouse
spouse.spouse = self spouse.spouse = self
@ -60,6 +60,7 @@ class Person:
@dataclass @dataclass
class FamilyRelationshipsConfig: class FamilyRelationshipsConfig:
"""Configuration for family relationship task generation""" """Configuration for family relationship task generation"""
min_family_size: int = 4 min_family_size: int = 4
max_family_size: int = 8 max_family_size: int = 8
male_names: List[str] = None male_names: List[str] = None
@ -70,22 +71,96 @@ class FamilyRelationshipsConfig:
def __post_init__(self): def __post_init__(self):
# Default name lists if none provided # Default name lists if none provided
default_male_names = [ default_male_names = [
"James", "John", "Robert", "Michael", "William", "David", "Richard", "James",
"Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew", "John",
"Christopher", "Andrew", "George", "Edward", "Benjamin", "Henry", "Robert",
"Samuel", "Alexander", "Oliver", "Jack", "Harry", "Jacob", "Michael",
"Noah", "Ethan", "Lucas", "Mason", "Logan", "Sebastian", "Theodore", "Owen", "William",
"Liam", "Aiden", "Kai", "Jayden", "Zion", "Phoenix", "Atlas", "Axel", "Ryder", "Finn" "David",
"Richard",
"Joseph",
"Thomas",
"Charles",
"Peter",
"Daniel",
"Matthew",
"Christopher",
"Andrew",
"George",
"Edward",
"Benjamin",
"Henry",
"Samuel",
"Alexander",
"Oliver",
"Jack",
"Harry",
"Jacob",
"Noah",
"Ethan",
"Lucas",
"Mason",
"Logan",
"Sebastian",
"Theodore",
"Owen",
"Liam",
"Aiden",
"Kai",
"Jayden",
"Zion",
"Phoenix",
"Atlas",
"Axel",
"Ryder",
"Finn",
] ]
default_female_names = [ default_female_names = [
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Mary",
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", "Patricia",
"Margaret", "Victoria", "Charlotte", "Sophia", "Isabella", "Olivia", "Jennifer",
"Ava", "Mia", "Emily", "Abigail", "Amelia", "Eleanor", "Grace", "Linda",
"Alice", "Lucy", "Chloe", "Sophie", "Lily", "Hannah", "Zoe", "Elizabeth",
"Luna", "Nova", "Aria", "Willow", "Aurora", "Sage", "River", "Winter", "Sky", "Rain" "Barbara",
"Susan",
"Jessica",
"Sarah",
"Karen",
"Emma",
"Lisa",
"Anna",
"Margaret",
"Victoria",
"Charlotte",
"Sophia",
"Isabella",
"Olivia",
"Ava",
"Mia",
"Emily",
"Abigail",
"Amelia",
"Eleanor",
"Grace",
"Alice",
"Lucy",
"Chloe",
"Sophie",
"Lily",
"Hannah",
"Zoe",
"Luna",
"Nova",
"Aria",
"Willow",
"Aurora",
"Sage",
"River",
"Winter",
"Sky",
"Rain",
] ]
if self.male_names is None: if self.male_names is None:
self.male_names = default_male_names self.male_names = default_male_names
if self.female_names is None: if self.female_names is None:
@ -114,22 +189,19 @@ class FamilyRelationshipsDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
# Generate family tree # Generate family tree
family = self._generate_family(rng) family = self._generate_family(rng)
# Select two people and their relationship # Select two people and their relationship
person1, person2, relationship = self._get_relationship_question(rng, family) person1, person2, relationship = self._get_relationship_question(rng, family)
# Generate story describing the family relationships # Generate story describing the family relationships
story = self._generate_story(family) story = self._generate_story(family)
# Format question # Format question
question = rng.choice(self._templates).format( question = rng.choice(self._templates).format(person1=person1.name, person2=person2.name)
person1=person1.name,
person2=person2.name
)
return { return {
"question": f"{story}\n\n{question}", "question": f"{story}\n\n{question}",
"answer": relationship.value, "answer": relationship.value,
@ -137,8 +209,8 @@ class FamilyRelationshipsDataset(ProceduralDataset):
"person1": person1.name, "person1": person1.name,
"person2": person2.name, "person2": person2.name,
"relationship": relationship.value, "relationship": relationship.value,
"family_size": len(family) "family_size": len(family),
} },
} }
def _generate_family(self, rng: random.Random) -> Set[Person]: def _generate_family(self, rng: random.Random) -> Set[Person]:
@ -148,8 +220,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
used_names = set() used_names = set()
def get_name(gender: Gender) -> str: def get_name(gender: Gender) -> str:
names = (self.config.male_names if gender == Gender.MALE names = self.config.male_names if gender == Gender.MALE else self.config.female_names
else self.config.female_names)
available = [n for n in names if n not in used_names] available = [n for n in names if n not in used_names]
if not available: if not available:
return None return None
@ -159,7 +230,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
# Create ID counter # Create ID counter
id_counter = count() id_counter = count()
# Create grandparents generation # Create grandparents generation
grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter)) grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter))
grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter)) grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter))
@ -192,62 +263,52 @@ class FamilyRelationshipsDataset(ProceduralDataset):
) -> Tuple[Person, Person, Relationship]: ) -> Tuple[Person, Person, Relationship]:
"""Select two family members and determine their relationship""" """Select two family members and determine their relationship"""
person1, person2 = rng.sample(list(family), 2) person1, person2 = rng.sample(list(family), 2)
# Determine relationship # Determine relationship
if person1 in person2.parents: if person1 in person2.parents:
relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE relationship = Relationship.MOTHER if person1.gender == Gender.FEMALE else Relationship.FATHER
else Relationship.FATHER)
elif person2 in person1.parents: elif person2 in person1.parents:
relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE relationship = Relationship.DAUGHTER if person1.gender == Gender.FEMALE else Relationship.SON
else Relationship.SON)
elif person1.spouse == person2: elif person1.spouse == person2:
relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE relationship = Relationship.WIFE if person1.gender == Gender.FEMALE else Relationship.HUSBAND
else Relationship.HUSBAND) elif person1.parents and person2.parents and set(person1.parents) == set(person2.parents):
elif (person1.parents and person2.parents and relationship = Relationship.SISTER if person1.gender == Gender.FEMALE else Relationship.BROTHER
set(person1.parents) == set(person2.parents)): elif person1 in [p for parent in person2.parents for p in parent.parents]:
relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE relationship = Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE else Relationship.GRANDFATHER
else Relationship.BROTHER)
elif (person1 in [p for parent in person2.parents for p in parent.parents]):
relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE
else Relationship.GRANDFATHER)
else: else:
# Try again with different people # Try again with different people
return self._get_relationship_question(rng, family) return self._get_relationship_question(rng, family)
return person1, person2, relationship return person1, person2, relationship
def _generate_story(self, family: Set[Person]) -> str: def _generate_story(self, family: Set[Person]) -> str:
"""Generate a story describing the family relationships""" """Generate a story describing the family relationships"""
story_parts = [] story_parts = []
# Find married couples # Find married couples
couples = set() couples = set()
for person in family: for person in family:
if person.spouse and (person.spouse, person) not in couples: if person.spouse and (person.spouse, person) not in couples:
couples.add((person, person.spouse)) couples.add((person, person.spouse))
# Describe marriages and children for each couple # Describe marriages and children for each couple
described_children = set() # Track which children have been described described_children = set() # Track which children have been described
for person1, person2 in couples: for person1, person2 in couples:
story_parts.append(f"{person1.name} is married to {person2.name}.") story_parts.append(f"{person1.name} is married to {person2.name}.")
# Only describe children once per couple # Only describe children once per couple
children = [c for c in person1.children if c not in described_children] children = [c for c in person1.children if c not in described_children]
if children: if children:
children_names = [c.name for c in children] children_names = [c.name for c in children]
described_children.update(children) # Mark these children as described described_children.update(children) # Mark these children as described
if len(children_names) == 1: if len(children_names) == 1:
story_parts.append( story_parts.append(f"They have a child called {children_names[0]}.")
f"They have a child called {children_names[0]}."
)
else: else:
*first, last = children_names *first, last = children_names
children_str = ", ".join(first) + f" and {last}" children_str = ", ".join(first) + f" and {last}"
story_parts.append( story_parts.append(f"They have children called {children_str}.")
f"They have children called {children_str}."
)
return " ".join(story_parts) return " ".join(story_parts)

View file

@ -1,242 +0,0 @@
import random
from dataclasses import dataclass
from typing import Optional, Dict, List, Set, Tuple
from enum import Enum
from ..dataset import ProceduralDataset
class Gender(Enum):
MALE = "male"
FEMALE = "female"
class Relationship(Enum):
MOTHER = "Mother"
FATHER = "Father"
SISTER = "Sister"
BROTHER = "Brother"
DAUGHTER = "Daughter"
SON = "Son"
WIFE = "Wife"
HUSBAND = "Husband"
GRANDMOTHER = "Grandmother"
GRANDFATHER = "Grandfather"
@dataclass
class Person:
name: str
gender: Gender
spouse: Optional['Person'] = None
parents: List['Person'] = None
children: List['Person'] = None
def __post_init__(self):
self.parents = self.parents or []
self.children = self.children or []
def add_child(self, child: 'Person'):
if child not in self.children:
self.children.append(child)
if self not in child.parents:
child.parents.append(self)
def add_spouse(self, spouse: 'Person'):
self.spouse = spouse
spouse.spouse = self
@dataclass
class FamilyRelationshipsConfig:
"""Configuration for family relationship task generation"""
min_family_size: int = 4
max_family_size: int = 8
male_names: List[str] = None
female_names: List[str] = None
seed: Optional[int] = None
size: int = 500
def __post_init__(self):
# Default name lists if none provided
self.male_names = self.male_names or [
"James", "John", "Robert", "Michael", "William", "David", "Richard",
"Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew"
]
self.female_names = self.female_names or [
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna"
]
def validate(self):
"""Validate configuration parameters"""
assert self.min_family_size >= 3, "min_family_size must be at least 3"
assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size"
assert len(self.male_names) > 0, "must provide male names"
assert len(self.female_names) > 0, "must provide female names"
class FamilyRelationshipsDataset(ProceduralDataset):
"""Generates family relationship reasoning tasks"""
def __init__(self, config: FamilyRelationshipsConfig):
self.config = config
self.config.validate()
self._templates = [
"What is {person1} to {person2}?",
"How is {person1} related to {person2}?",
"What relation is {person1} to {person2}?",
]
super().__init__(seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
rng = random.Random(self.seed + idx)
# Generate family tree
family = self._generate_family(rng)
# Select two people and their relationship
person1, person2, relationship = self._get_relationship_question(rng, family)
# Generate story describing the family relationships
story = self._generate_story(family)
# Format question
question = rng.choice(self._templates).format(
person1=person1.name,
person2=person2.name
)
return {
"question": f"{story}\n\n{question}",
"answer": relationship.value,
"metadata": {
"person1": person1.name,
"person2": person2.name,
"relationship": relationship.value,
"family_size": len(family)
}
}
def _generate_family(self, rng: random.Random) -> Set[Person]:
"""Generate a random family tree"""
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
family = set()
used_names = set()
def get_name(gender: Gender) -> str:
names = (self.config.male_names if gender == Gender.MALE
else self.config.female_names)
available = [n for n in names if n not in used_names]
if not available:
return None
name = rng.choice(available)
used_names.add(name)
return name
# Create grandparents generation
grandfather = Person(get_name(Gender.MALE), Gender.MALE)
grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE)
grandfather.add_spouse(grandmother)
family.update([grandfather, grandmother])
# Create parents
father = Person(get_name(Gender.MALE), Gender.MALE)
mother = Person(get_name(Gender.FEMALE), Gender.FEMALE)
father.add_spouse(mother)
grandfather.add_child(father)
grandmother.add_child(father)
family.update([father, mother])
# Add children
while len(family) < family_size:
gender = rng.choice([Gender.MALE, Gender.FEMALE])
name = get_name(gender)
if not name:
break
child = Person(name, gender)
father.add_child(child)
mother.add_child(child)
family.add(child)
return family
def _get_relationship_question(
self, rng: random.Random, family: Set[Person]
) -> Tuple[Person, Person, Relationship]:
"""Select two family members and determine their relationship"""
person1, person2 = rng.sample(list(family), 2)
# Determine relationship
if person1 in person2.parents:
relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE
else Relationship.FATHER)
elif person2 in person1.parents:
relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE
else Relationship.SON)
elif person1.spouse == person2:
relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE
else Relationship.HUSBAND)
elif (person1.parents and person2.parents and
set(person1.parents) == set(person2.parents)):
relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE
else Relationship.BROTHER)
elif (person1 in [p for parent in person2.parents for p in parent.parents]):
relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE
else Relationship.GRANDFATHER)
else:
# Try again with different people
return self._get_relationship_question(rng, family)
return person1, person2, relationship
def _generate_story(self, family: Set[Person]) -> str:
"""Generate a story describing the family relationships"""
story_parts = []
# Find married couples
couples = set()
for person in family:
if person.spouse and (person.spouse, person) not in couples:
couples.add((person, person.spouse))
# Describe marriages
for person1, person2 in couples:
story_parts.append(f"{person1.name} is married to {person2.name}.")
# Describe parent-child relationships
for person in family:
if person.children:
children_names = [c.name for c in person.children]
if len(children_names) == 1:
story_parts.append(
f"They have a child called {children_names[0]}."
)
else:
*first, last = children_names
children_str = ", ".join(first) + f" and {last}"
story_parts.append(
f"They have children called {children_str}."
)
return " ".join(story_parts)
def family_relationships_dataset(
min_family_size: int = 4,
max_family_size: int = 8,
male_names: List[str] = None,
female_names: List[str] = None,
seed: Optional[int] = None,
size: int = 500,
) -> FamilyRelationshipsDataset:
"""Create a FamilyRelationshipsDataset with the given configuration"""
config = FamilyRelationshipsConfig(
min_family_size=min_family_size,
max_family_size=max_family_size,
male_names=male_names,
female_names=female_names,
seed=seed,
size=size,
)
return FamilyRelationshipsDataset(config)

View file

@ -1,35 +1,26 @@
from reasoning_gym.graphs.family_relationships import ( import pytest
family_relationships_dataset,
Gender, from reasoning_gym.graphs.family_relationships import Gender, Relationship, family_relationships_dataset
Relationship,
)
def test_family_relationships_generation(): def test_family_relationships_generation():
dataset = family_relationships_dataset(seed=42, size=10) dataset = family_relationships_dataset(seed=42, size=10)
for item in dataset: for item in dataset:
# Check required keys exist # Check required keys exist
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Validate story and question format # Validate story and question format
story_and_question = item["question"] story_and_question = item["question"]
assert "is married to" in story_and_question assert "is married to" in story_and_question
assert "have" in story_and_question assert "have" in story_and_question
assert any( assert any(prompt in story_and_question for prompt in ["What is", "How is", "What relation is"])
prompt in story_and_question
for prompt in [
"What is",
"How is",
"What relation is"
]
)
# Validate answer is a valid relationship # Validate answer is a valid relationship
assert item["answer"] in [r.value for r in Relationship] assert item["answer"] in [r.value for r in Relationship]
# Validate metadata # Validate metadata
assert "person1" in item["metadata"] assert "person1" in item["metadata"]
assert "person2" in item["metadata"] assert "person2" in item["metadata"]
@ -42,13 +33,13 @@ def test_family_relationships_config():
# Test invalid config raises assertion # Test invalid config raises assertion
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = family_relationships_dataset(min_family_size=2) dataset = family_relationships_dataset(min_family_size=2)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = family_relationships_dataset(max_family_size=3, min_family_size=4) dataset = family_relationships_dataset(max_family_size=3, min_family_size=4)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = family_relationships_dataset(male_names=[]) dataset = family_relationships_dataset(male_names=[])
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = family_relationships_dataset(female_names=[]) dataset = family_relationships_dataset(female_names=[])
@ -56,7 +47,7 @@ def test_family_relationships_config():
def test_deterministic_generation(): def test_deterministic_generation():
dataset1 = family_relationships_dataset(seed=42, size=5) dataset1 = family_relationships_dataset(seed=42, size=5)
dataset2 = family_relationships_dataset(seed=42, size=5) dataset2 = family_relationships_dataset(seed=42, size=5)
for i in range(5): for i in range(5):
assert dataset1[i]["question"] == dataset2[i]["question"] assert dataset1[i]["question"] == dataset2[i]["question"]
assert dataset1[i]["answer"] == dataset2[i]["answer"] assert dataset1[i]["answer"] == dataset2[i]["answer"]
@ -64,15 +55,23 @@ def test_deterministic_generation():
def test_relationship_consistency(): def test_relationship_consistency():
dataset = family_relationships_dataset(seed=42, size=10) dataset = family_relationships_dataset(seed=42, size=10)
for item in dataset: for item in dataset:
# Check that relationship matches the gender # Check that relationship matches the gender
relationship = item["metadata"]["relationship"] relationship = item["metadata"]["relationship"]
if relationship in [Relationship.MOTHER.value, Relationship.GRANDMOTHER.value, if relationship in [
Relationship.WIFE.value, Relationship.SISTER.value, Relationship.MOTHER.value,
Relationship.DAUGHTER.value]: Relationship.GRANDMOTHER.value,
Relationship.WIFE.value,
Relationship.SISTER.value,
Relationship.DAUGHTER.value,
]:
assert "married to" in item["question"] or "child" in item["question"] assert "married to" in item["question"] or "child" in item["question"]
elif relationship in [Relationship.FATHER.value, Relationship.GRANDFATHER.value, elif relationship in [
Relationship.HUSBAND.value, Relationship.BROTHER.value, Relationship.FATHER.value,
Relationship.SON.value]: Relationship.GRANDFATHER.value,
Relationship.HUSBAND.value,
Relationship.BROTHER.value,
Relationship.SON.value,
]:
assert "married to" in item["question"] or "child" in item["question"] assert "married to" in item["question"] or "child" in item["question"]

View file

@ -1,6 +1,6 @@
import pytest import pytest
from reasoning_gym.cognition.number_sequences import Operation, PatternRule, NumberSequenceConfig, NumberSequenceDataset from reasoning_gym.cognition.number_sequences import NumberSequenceConfig, NumberSequenceDataset, Operation, PatternRule
def test_sequence_config_validation(): def test_sequence_config_validation():