diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index 52741f8c..4bc56162 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -2,7 +2,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasoning models """ -from . import algorithmic, algebra, arithmetic, cognition, data, games, graphs, logic +from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic __version__ = "0.1.1" __all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"] diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py index 85574d60..251f6583 100644 --- a/reasoning_gym/algebra/__init__.py +++ b/reasoning_gym/algebra/__init__.py @@ -1,3 +1,3 @@ -from .simple_equations import SimpleEquationsDataset, SimpleEquationsConfig, simple_equations_dataset +from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset, simple_equations_dataset __all__ = ["SimpleEquationsDataset", "SimpleEquationsConfig", "simple_equations_dataset"] diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index feb65c64..8b90af72 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -1,10 +1,10 @@ import random +import string from dataclasses import dataclass from typing import Optional, Tuple -import string import sympy -from sympy import Symbol, solve, Eq +from sympy import Eq, Symbol, solve from ..dataset import ProceduralDataset @@ -12,11 +12,12 @@ from ..dataset import ProceduralDataset @dataclass class SimpleEquationsConfig: """Configuration for simple equation task generation""" + min_terms: int = 2 # Minimum number of terms in expression max_terms: int = 4 # Maximum number of terms min_value: int = 1 # Minimum value for constants max_value: int = 100 # Maximum value for constants - operators: tuple = ('+', '-', '*') # Allowed operators + operators: tuple = ("+", "-", "*") # Allowed operators seed: Optional[int] = None size: int = 500 @@ -44,7 +45,7 @@ class SimpleEquationsDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single equation task - + Returns: dict with keys: - question: str, the equation to solve (e.g. "3 * x = 12") @@ -52,18 +53,18 @@ class SimpleEquationsDataset(ProceduralDataset): - metadata: dict with generation parameters """ rng = random.Random(self.seed + idx) - + # Get variable and generate equation variable = self._get_variable(rng) equation, solution = self._generate_equation(rng, variable) - + return { "question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation), "answer": str(solution), "metadata": { "equation": equation, "variable": variable, - } + }, } def _get_variable(self, rng: random.Random) -> str: @@ -72,60 +73,60 @@ class SimpleEquationsDataset(ProceduralDataset): def _generate_equation(self, rng: random.Random, variable: str) -> Tuple[str, int]: """Generate an equation and its solution - + Args: rng: Random number generator variable: Variable symbol to use in equation - + Returns: Tuple of (equation string, solution integer) """ x = Symbol(variable) - + # Generate terms for left side num_terms = rng.randint(self.config.min_terms, self.config.max_terms) terms = [] - + # Generate all constant terms first for _ in range(num_terms): value = rng.randint(self.config.min_value, self.config.max_value) terms.append(value) - + # Replace one random term with the variable term var_pos = rng.randint(0, num_terms - 1) coef = rng.randint(self.config.min_value, self.config.max_value) terms[var_pos] = coef * x - + # Apply operators between terms expr = terms[0] for i in range(1, num_terms): op = rng.choice(self.config.operators) - if op == '+': + if op == "+": expr = expr + terms[i] - elif op == '-': + elif op == "-": expr = expr - terms[i] else: # '*' expr = expr * terms[i] - + left_side = expr - + # Generate right side right_side = rng.randint(self.config.min_value, self.config.max_value) - + # Create equation equation = Eq(left_side, right_side) solutions = solve(equation, x) - + # Check if we found any solutions if not solutions: return self._generate_equation(rng, variable) - + solution = solutions[0] - + # Only return if solution is a positive integer if not (isinstance(solution, sympy.Integer) and solution > 0): return self._generate_equation(rng, variable) - + return f"{left_side} = {right_side}", int(solution) @@ -134,7 +135,7 @@ def simple_equations_dataset( max_terms: int = 5, min_value: int = 1, max_value: int = 100, - operators: tuple = ('+', '-', '*'), + operators: tuple = ("+", "-", "*"), seed: Optional[int] = None, size: int = 500, ) -> SimpleEquationsDataset: diff --git a/reasoning_gym/algebra/test.py b/reasoning_gym/algebra/test.py index f038c37d..a21110b0 100644 --- a/reasoning_gym/algebra/test.py +++ b/reasoning_gym/algebra/test.py @@ -1,50 +1,44 @@ import pytest + from .simple_equations import simple_equations_dataset def test_simple_equations_generation(): dataset = simple_equations_dataset(seed=42, size=10) - + for item in dataset: # Check required keys exist assert "question" in item assert "answer" in item assert "metadata" in item - + # Validate answer is a string of digits assert item["answer"].isdigit() - + # Validate equation format equation = item["metadata"]["equation"] variable = item["metadata"]["variable"] assert "=" in equation assert variable in equation - + # Validate question format question = item["question"] assert variable in question assert equation in question - assert any( - prompt in question - for prompt in [ - "Find the value of", - "Solve for", - "Determine the value of" - ] - ) + assert any(prompt in question for prompt in ["Find the value of", "Solve for", "Determine the value of"]) def test_simple_equations_config(): # Test invalid config raises assertion with pytest.raises(AssertionError): dataset = simple_equations_dataset(min_terms=0) - + with pytest.raises(AssertionError): dataset = simple_equations_dataset(max_terms=1, min_terms=2) - + with pytest.raises(AssertionError): dataset = simple_equations_dataset(min_value=0) - + with pytest.raises(AssertionError): dataset = simple_equations_dataset(operators=()) @@ -52,7 +46,7 @@ def test_simple_equations_config(): def test_deterministic_generation(): dataset1 = simple_equations_dataset(seed=42, size=5) dataset2 = simple_equations_dataset(seed=42, size=5) - + for i in range(5): assert dataset1[i]["question"] == dataset2[i]["question"] assert dataset1[i]["answer"] == dataset2[i]["answer"] diff --git a/reasoning_gym/algorithmic/letter_counting.py b/reasoning_gym/algorithmic/letter_counting.py index fba00856..465b9b2c 100644 --- a/reasoning_gym/algorithmic/letter_counting.py +++ b/reasoning_gym/algorithmic/letter_counting.py @@ -5,10 +5,10 @@ from dataclasses import dataclass from random import Random from typing import List, Optional -from ..dataset import ProceduralDataset - from reasoning_gym.data import read_data_file +from ..dataset import ProceduralDataset + @dataclass class LetterCountingConfig: diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index 6e761d01..c67e9cc8 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -39,7 +39,7 @@ class GCDDataset(ProceduralDataset): def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: """Generate a list of random positive integers and their GCD. Will try up to 3 times to find numbers with GCD > 1.""" - + # Try up to 3 times to get GCD > 1 for _ in range(3): num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) @@ -47,7 +47,7 @@ class GCDDataset(ProceduralDataset): result = reduce(gcd, numbers) if result > 1: break - + # Return the last generated numbers, whether they met the criteria or not return numbers, result diff --git a/reasoning_gym/arithmetic/lcm.py b/reasoning_gym/arithmetic/lcm.py index 1f0a2a6d..ad0983d4 100644 --- a/reasoning_gym/arithmetic/lcm.py +++ b/reasoning_gym/arithmetic/lcm.py @@ -50,7 +50,7 @@ class LCMDataset(ProceduralDataset): result = reduce(lcm, numbers) if result < calculate_product(numbers): break - + # Return the last generated numbers, whether they met the criteria or not return numbers, result diff --git a/reasoning_gym/relationships/__init__.py b/reasoning_gym/graphs/__init__.py similarity index 100% rename from reasoning_gym/relationships/__init__.py rename to reasoning_gym/graphs/__init__.py diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index ade7ca03..614a2128 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -1,8 +1,8 @@ import random from dataclasses import dataclass -from typing import Optional, Dict, List, Set, Tuple from enum import Enum from itertools import count +from typing import Dict, List, Optional, Set, Tuple from ..dataset import ProceduralDataset @@ -30,9 +30,9 @@ class Person: name: str gender: Gender id: int - spouse: Optional['Person'] = None - parents: List['Person'] = None - children: List['Person'] = None + spouse: Optional["Person"] = None + parents: List["Person"] = None + children: List["Person"] = None def __post_init__(self): self.parents = self.parents or [] @@ -46,13 +46,13 @@ class Person: return False return self.id == other.id - def add_child(self, child: 'Person'): + def add_child(self, child: "Person"): if child not in self.children: self.children.append(child) if self not in child.parents: child.parents.append(self) - def add_spouse(self, spouse: 'Person'): + def add_spouse(self, spouse: "Person"): self.spouse = spouse spouse.spouse = self @@ -60,6 +60,7 @@ class Person: @dataclass class FamilyRelationshipsConfig: """Configuration for family relationship task generation""" + min_family_size: int = 4 max_family_size: int = 8 male_names: List[str] = None @@ -70,22 +71,96 @@ class FamilyRelationshipsConfig: def __post_init__(self): # Default name lists if none provided default_male_names = [ - "James", "John", "Robert", "Michael", "William", "David", "Richard", - "Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew", - "Christopher", "Andrew", "George", "Edward", "Benjamin", "Henry", - "Samuel", "Alexander", "Oliver", "Jack", "Harry", "Jacob", - "Noah", "Ethan", "Lucas", "Mason", "Logan", "Sebastian", "Theodore", "Owen", - "Liam", "Aiden", "Kai", "Jayden", "Zion", "Phoenix", "Atlas", "Axel", "Ryder", "Finn" + "James", + "John", + "Robert", + "Michael", + "William", + "David", + "Richard", + "Joseph", + "Thomas", + "Charles", + "Peter", + "Daniel", + "Matthew", + "Christopher", + "Andrew", + "George", + "Edward", + "Benjamin", + "Henry", + "Samuel", + "Alexander", + "Oliver", + "Jack", + "Harry", + "Jacob", + "Noah", + "Ethan", + "Lucas", + "Mason", + "Logan", + "Sebastian", + "Theodore", + "Owen", + "Liam", + "Aiden", + "Kai", + "Jayden", + "Zion", + "Phoenix", + "Atlas", + "Axel", + "Ryder", + "Finn", ] default_female_names = [ - "Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", - "Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", - "Margaret", "Victoria", "Charlotte", "Sophia", "Isabella", "Olivia", - "Ava", "Mia", "Emily", "Abigail", "Amelia", "Eleanor", "Grace", - "Alice", "Lucy", "Chloe", "Sophie", "Lily", "Hannah", "Zoe", - "Luna", "Nova", "Aria", "Willow", "Aurora", "Sage", "River", "Winter", "Sky", "Rain" + "Mary", + "Patricia", + "Jennifer", + "Linda", + "Elizabeth", + "Barbara", + "Susan", + "Jessica", + "Sarah", + "Karen", + "Emma", + "Lisa", + "Anna", + "Margaret", + "Victoria", + "Charlotte", + "Sophia", + "Isabella", + "Olivia", + "Ava", + "Mia", + "Emily", + "Abigail", + "Amelia", + "Eleanor", + "Grace", + "Alice", + "Lucy", + "Chloe", + "Sophie", + "Lily", + "Hannah", + "Zoe", + "Luna", + "Nova", + "Aria", + "Willow", + "Aurora", + "Sage", + "River", + "Winter", + "Sky", + "Rain", ] - + if self.male_names is None: self.male_names = default_male_names if self.female_names is None: @@ -114,22 +189,19 @@ class FamilyRelationshipsDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: rng = random.Random(self.seed + idx) - + # Generate family tree family = self._generate_family(rng) - + # Select two people and their relationship person1, person2, relationship = self._get_relationship_question(rng, family) - + # Generate story describing the family relationships story = self._generate_story(family) - + # Format question - question = rng.choice(self._templates).format( - person1=person1.name, - person2=person2.name - ) - + question = rng.choice(self._templates).format(person1=person1.name, person2=person2.name) + return { "question": f"{story}\n\n{question}", "answer": relationship.value, @@ -137,8 +209,8 @@ class FamilyRelationshipsDataset(ProceduralDataset): "person1": person1.name, "person2": person2.name, "relationship": relationship.value, - "family_size": len(family) - } + "family_size": len(family), + }, } def _generate_family(self, rng: random.Random) -> Set[Person]: @@ -148,8 +220,7 @@ class FamilyRelationshipsDataset(ProceduralDataset): used_names = set() def get_name(gender: Gender) -> str: - names = (self.config.male_names if gender == Gender.MALE - else self.config.female_names) + names = self.config.male_names if gender == Gender.MALE else self.config.female_names available = [n for n in names if n not in used_names] if not available: return None @@ -159,7 +230,7 @@ class FamilyRelationshipsDataset(ProceduralDataset): # Create ID counter id_counter = count() - + # Create grandparents generation grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter)) grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter)) @@ -192,62 +263,52 @@ class FamilyRelationshipsDataset(ProceduralDataset): ) -> Tuple[Person, Person, Relationship]: """Select two family members and determine their relationship""" person1, person2 = rng.sample(list(family), 2) - + # Determine relationship if person1 in person2.parents: - relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE - else Relationship.FATHER) + relationship = Relationship.MOTHER if person1.gender == Gender.FEMALE else Relationship.FATHER elif person2 in person1.parents: - relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE - else Relationship.SON) + relationship = Relationship.DAUGHTER if person1.gender == Gender.FEMALE else Relationship.SON elif person1.spouse == person2: - relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE - else Relationship.HUSBAND) - elif (person1.parents and person2.parents and - set(person1.parents) == set(person2.parents)): - relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE - else Relationship.BROTHER) - elif (person1 in [p for parent in person2.parents for p in parent.parents]): - relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE - else Relationship.GRANDFATHER) + relationship = Relationship.WIFE if person1.gender == Gender.FEMALE else Relationship.HUSBAND + elif person1.parents and person2.parents and set(person1.parents) == set(person2.parents): + relationship = Relationship.SISTER if person1.gender == Gender.FEMALE else Relationship.BROTHER + elif person1 in [p for parent in person2.parents for p in parent.parents]: + relationship = Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE else Relationship.GRANDFATHER else: # Try again with different people return self._get_relationship_question(rng, family) - + return person1, person2, relationship def _generate_story(self, family: Set[Person]) -> str: """Generate a story describing the family relationships""" story_parts = [] - + # Find married couples couples = set() for person in family: if person.spouse and (person.spouse, person) not in couples: couples.add((person, person.spouse)) - + # Describe marriages and children for each couple described_children = set() # Track which children have been described for person1, person2 in couples: story_parts.append(f"{person1.name} is married to {person2.name}.") - + # Only describe children once per couple children = [c for c in person1.children if c not in described_children] if children: children_names = [c.name for c in children] described_children.update(children) # Mark these children as described - + if len(children_names) == 1: - story_parts.append( - f"They have a child called {children_names[0]}." - ) + story_parts.append(f"They have a child called {children_names[0]}.") else: *first, last = children_names children_str = ", ".join(first) + f" and {last}" - story_parts.append( - f"They have children called {children_str}." - ) - + story_parts.append(f"They have children called {children_str}.") + return " ".join(story_parts) diff --git a/reasoning_gym/relationships/family_relationships.py b/reasoning_gym/relationships/family_relationships.py deleted file mode 100644 index a19d500d..00000000 --- a/reasoning_gym/relationships/family_relationships.py +++ /dev/null @@ -1,242 +0,0 @@ -import random -from dataclasses import dataclass -from typing import Optional, Dict, List, Set, Tuple -from enum import Enum - -from ..dataset import ProceduralDataset - - -class Gender(Enum): - MALE = "male" - FEMALE = "female" - - -class Relationship(Enum): - MOTHER = "Mother" - FATHER = "Father" - SISTER = "Sister" - BROTHER = "Brother" - DAUGHTER = "Daughter" - SON = "Son" - WIFE = "Wife" - HUSBAND = "Husband" - GRANDMOTHER = "Grandmother" - GRANDFATHER = "Grandfather" - - -@dataclass -class Person: - name: str - gender: Gender - spouse: Optional['Person'] = None - parents: List['Person'] = None - children: List['Person'] = None - - def __post_init__(self): - self.parents = self.parents or [] - self.children = self.children or [] - - def add_child(self, child: 'Person'): - if child not in self.children: - self.children.append(child) - if self not in child.parents: - child.parents.append(self) - - def add_spouse(self, spouse: 'Person'): - self.spouse = spouse - spouse.spouse = self - - -@dataclass -class FamilyRelationshipsConfig: - """Configuration for family relationship task generation""" - min_family_size: int = 4 - max_family_size: int = 8 - male_names: List[str] = None - female_names: List[str] = None - seed: Optional[int] = None - size: int = 500 - - def __post_init__(self): - # Default name lists if none provided - self.male_names = self.male_names or [ - "James", "John", "Robert", "Michael", "William", "David", "Richard", - "Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew" - ] - self.female_names = self.female_names or [ - "Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", - "Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna" - ] - - def validate(self): - """Validate configuration parameters""" - assert self.min_family_size >= 3, "min_family_size must be at least 3" - assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size" - assert len(self.male_names) > 0, "must provide male names" - assert len(self.female_names) > 0, "must provide female names" - - -class FamilyRelationshipsDataset(ProceduralDataset): - """Generates family relationship reasoning tasks""" - - def __init__(self, config: FamilyRelationshipsConfig): - self.config = config - self.config.validate() - self._templates = [ - "What is {person1} to {person2}?", - "How is {person1} related to {person2}?", - "What relation is {person1} to {person2}?", - ] - super().__init__(seed=config.seed, size=config.size) - - def __getitem__(self, idx: int) -> dict: - rng = random.Random(self.seed + idx) - - # Generate family tree - family = self._generate_family(rng) - - # Select two people and their relationship - person1, person2, relationship = self._get_relationship_question(rng, family) - - # Generate story describing the family relationships - story = self._generate_story(family) - - # Format question - question = rng.choice(self._templates).format( - person1=person1.name, - person2=person2.name - ) - - return { - "question": f"{story}\n\n{question}", - "answer": relationship.value, - "metadata": { - "person1": person1.name, - "person2": person2.name, - "relationship": relationship.value, - "family_size": len(family) - } - } - - def _generate_family(self, rng: random.Random) -> Set[Person]: - """Generate a random family tree""" - family_size = rng.randint(self.config.min_family_size, self.config.max_family_size) - family = set() - used_names = set() - - def get_name(gender: Gender) -> str: - names = (self.config.male_names if gender == Gender.MALE - else self.config.female_names) - available = [n for n in names if n not in used_names] - if not available: - return None - name = rng.choice(available) - used_names.add(name) - return name - - # Create grandparents generation - grandfather = Person(get_name(Gender.MALE), Gender.MALE) - grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE) - grandfather.add_spouse(grandmother) - family.update([grandfather, grandmother]) - - # Create parents - father = Person(get_name(Gender.MALE), Gender.MALE) - mother = Person(get_name(Gender.FEMALE), Gender.FEMALE) - father.add_spouse(mother) - grandfather.add_child(father) - grandmother.add_child(father) - family.update([father, mother]) - - # Add children - while len(family) < family_size: - gender = rng.choice([Gender.MALE, Gender.FEMALE]) - name = get_name(gender) - if not name: - break - child = Person(name, gender) - father.add_child(child) - mother.add_child(child) - family.add(child) - - return family - - def _get_relationship_question( - self, rng: random.Random, family: Set[Person] - ) -> Tuple[Person, Person, Relationship]: - """Select two family members and determine their relationship""" - person1, person2 = rng.sample(list(family), 2) - - # Determine relationship - if person1 in person2.parents: - relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE - else Relationship.FATHER) - elif person2 in person1.parents: - relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE - else Relationship.SON) - elif person1.spouse == person2: - relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE - else Relationship.HUSBAND) - elif (person1.parents and person2.parents and - set(person1.parents) == set(person2.parents)): - relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE - else Relationship.BROTHER) - elif (person1 in [p for parent in person2.parents for p in parent.parents]): - relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE - else Relationship.GRANDFATHER) - else: - # Try again with different people - return self._get_relationship_question(rng, family) - - return person1, person2, relationship - - def _generate_story(self, family: Set[Person]) -> str: - """Generate a story describing the family relationships""" - story_parts = [] - - # Find married couples - couples = set() - for person in family: - if person.spouse and (person.spouse, person) not in couples: - couples.add((person, person.spouse)) - - # Describe marriages - for person1, person2 in couples: - story_parts.append(f"{person1.name} is married to {person2.name}.") - - # Describe parent-child relationships - for person in family: - if person.children: - children_names = [c.name for c in person.children] - if len(children_names) == 1: - story_parts.append( - f"They have a child called {children_names[0]}." - ) - else: - *first, last = children_names - children_str = ", ".join(first) + f" and {last}" - story_parts.append( - f"They have children called {children_str}." - ) - - return " ".join(story_parts) - - -def family_relationships_dataset( - min_family_size: int = 4, - max_family_size: int = 8, - male_names: List[str] = None, - female_names: List[str] = None, - seed: Optional[int] = None, - size: int = 500, -) -> FamilyRelationshipsDataset: - """Create a FamilyRelationshipsDataset with the given configuration""" - config = FamilyRelationshipsConfig( - min_family_size=min_family_size, - max_family_size=max_family_size, - male_names=male_names, - female_names=female_names, - seed=seed, - size=size, - ) - return FamilyRelationshipsDataset(config) diff --git a/tests/test_family_relationships.py b/tests/test_family_relationships.py index 0898153e..80d2d9f0 100644 --- a/tests/test_family_relationships.py +++ b/tests/test_family_relationships.py @@ -1,35 +1,26 @@ -from reasoning_gym.graphs.family_relationships import ( - family_relationships_dataset, - Gender, - Relationship, -) +import pytest + +from reasoning_gym.graphs.family_relationships import Gender, Relationship, family_relationships_dataset def test_family_relationships_generation(): dataset = family_relationships_dataset(seed=42, size=10) - + for item in dataset: # Check required keys exist assert "question" in item assert "answer" in item assert "metadata" in item - + # Validate story and question format story_and_question = item["question"] assert "is married to" in story_and_question assert "have" in story_and_question - assert any( - prompt in story_and_question - for prompt in [ - "What is", - "How is", - "What relation is" - ] - ) - + assert any(prompt in story_and_question for prompt in ["What is", "How is", "What relation is"]) + # Validate answer is a valid relationship assert item["answer"] in [r.value for r in Relationship] - + # Validate metadata assert "person1" in item["metadata"] assert "person2" in item["metadata"] @@ -42,13 +33,13 @@ def test_family_relationships_config(): # Test invalid config raises assertion with pytest.raises(AssertionError): dataset = family_relationships_dataset(min_family_size=2) - + with pytest.raises(AssertionError): dataset = family_relationships_dataset(max_family_size=3, min_family_size=4) - + with pytest.raises(AssertionError): dataset = family_relationships_dataset(male_names=[]) - + with pytest.raises(AssertionError): dataset = family_relationships_dataset(female_names=[]) @@ -56,7 +47,7 @@ def test_family_relationships_config(): def test_deterministic_generation(): dataset1 = family_relationships_dataset(seed=42, size=5) dataset2 = family_relationships_dataset(seed=42, size=5) - + for i in range(5): assert dataset1[i]["question"] == dataset2[i]["question"] assert dataset1[i]["answer"] == dataset2[i]["answer"] @@ -64,15 +55,23 @@ def test_deterministic_generation(): def test_relationship_consistency(): dataset = family_relationships_dataset(seed=42, size=10) - + for item in dataset: # Check that relationship matches the gender relationship = item["metadata"]["relationship"] - if relationship in [Relationship.MOTHER.value, Relationship.GRANDMOTHER.value, - Relationship.WIFE.value, Relationship.SISTER.value, - Relationship.DAUGHTER.value]: + if relationship in [ + Relationship.MOTHER.value, + Relationship.GRANDMOTHER.value, + Relationship.WIFE.value, + Relationship.SISTER.value, + Relationship.DAUGHTER.value, + ]: assert "married to" in item["question"] or "child" in item["question"] - elif relationship in [Relationship.FATHER.value, Relationship.GRANDFATHER.value, - Relationship.HUSBAND.value, Relationship.BROTHER.value, - Relationship.SON.value]: + elif relationship in [ + Relationship.FATHER.value, + Relationship.GRANDFATHER.value, + Relationship.HUSBAND.value, + Relationship.BROTHER.value, + Relationship.SON.value, + ]: assert "married to" in item["question"] or "child" in item["question"] diff --git a/tests/test_number_sequences.py b/tests/test_number_sequences.py index 5438f4de..69afcff3 100644 --- a/tests/test_number_sequences.py +++ b/tests/test_number_sequences.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.cognition.number_sequences import Operation, PatternRule, NumberSequenceConfig, NumberSequenceDataset +from reasoning_gym.cognition.number_sequences import NumberSequenceConfig, NumberSequenceDataset, Operation, PatternRule def test_sequence_config_validation():