mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
formatting, cleanup
This commit is contained in:
parent
b767e58e48
commit
3dc80be7d2
12 changed files with 189 additions and 376 deletions
|
|
@ -2,7 +2,7 @@
|
||||||
Reasoning Gym - A library of procedural dataset generators for training reasoning models
|
Reasoning Gym - A library of procedural dataset generators for training reasoning models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from . import algorithmic, algebra, arithmetic, cognition, data, games, graphs, logic
|
from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic
|
||||||
|
|
||||||
__version__ = "0.1.1"
|
__version__ = "0.1.1"
|
||||||
__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"]
|
__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"]
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
from .simple_equations import SimpleEquationsDataset, SimpleEquationsConfig, simple_equations_dataset
|
from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset, simple_equations_dataset
|
||||||
|
|
||||||
__all__ = ["SimpleEquationsDataset", "SimpleEquationsConfig", "simple_equations_dataset"]
|
__all__ = ["SimpleEquationsDataset", "SimpleEquationsConfig", "simple_equations_dataset"]
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import random
|
import random
|
||||||
|
import string
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
import string
|
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
from sympy import Symbol, solve, Eq
|
from sympy import Eq, Symbol, solve
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
@ -12,11 +12,12 @@ from ..dataset import ProceduralDataset
|
||||||
@dataclass
|
@dataclass
|
||||||
class SimpleEquationsConfig:
|
class SimpleEquationsConfig:
|
||||||
"""Configuration for simple equation task generation"""
|
"""Configuration for simple equation task generation"""
|
||||||
|
|
||||||
min_terms: int = 2 # Minimum number of terms in expression
|
min_terms: int = 2 # Minimum number of terms in expression
|
||||||
max_terms: int = 4 # Maximum number of terms
|
max_terms: int = 4 # Maximum number of terms
|
||||||
min_value: int = 1 # Minimum value for constants
|
min_value: int = 1 # Minimum value for constants
|
||||||
max_value: int = 100 # Maximum value for constants
|
max_value: int = 100 # Maximum value for constants
|
||||||
operators: tuple = ('+', '-', '*') # Allowed operators
|
operators: tuple = ("+", "-", "*") # Allowed operators
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
||||||
|
|
@ -44,7 +45,7 @@ class SimpleEquationsDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single equation task
|
"""Generate a single equation task
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict with keys:
|
dict with keys:
|
||||||
- question: str, the equation to solve (e.g. "3 * x = 12")
|
- question: str, the equation to solve (e.g. "3 * x = 12")
|
||||||
|
|
@ -52,18 +53,18 @@ class SimpleEquationsDataset(ProceduralDataset):
|
||||||
- metadata: dict with generation parameters
|
- metadata: dict with generation parameters
|
||||||
"""
|
"""
|
||||||
rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
# Get variable and generate equation
|
# Get variable and generate equation
|
||||||
variable = self._get_variable(rng)
|
variable = self._get_variable(rng)
|
||||||
equation, solution = self._generate_equation(rng, variable)
|
equation, solution = self._generate_equation(rng, variable)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation),
|
"question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation),
|
||||||
"answer": str(solution),
|
"answer": str(solution),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"equation": equation,
|
"equation": equation,
|
||||||
"variable": variable,
|
"variable": variable,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_variable(self, rng: random.Random) -> str:
|
def _get_variable(self, rng: random.Random) -> str:
|
||||||
|
|
@ -72,60 +73,60 @@ class SimpleEquationsDataset(ProceduralDataset):
|
||||||
|
|
||||||
def _generate_equation(self, rng: random.Random, variable: str) -> Tuple[str, int]:
|
def _generate_equation(self, rng: random.Random, variable: str) -> Tuple[str, int]:
|
||||||
"""Generate an equation and its solution
|
"""Generate an equation and its solution
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rng: Random number generator
|
rng: Random number generator
|
||||||
variable: Variable symbol to use in equation
|
variable: Variable symbol to use in equation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (equation string, solution integer)
|
Tuple of (equation string, solution integer)
|
||||||
"""
|
"""
|
||||||
x = Symbol(variable)
|
x = Symbol(variable)
|
||||||
|
|
||||||
# Generate terms for left side
|
# Generate terms for left side
|
||||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||||
terms = []
|
terms = []
|
||||||
|
|
||||||
# Generate all constant terms first
|
# Generate all constant terms first
|
||||||
for _ in range(num_terms):
|
for _ in range(num_terms):
|
||||||
value = rng.randint(self.config.min_value, self.config.max_value)
|
value = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
terms.append(value)
|
terms.append(value)
|
||||||
|
|
||||||
# Replace one random term with the variable term
|
# Replace one random term with the variable term
|
||||||
var_pos = rng.randint(0, num_terms - 1)
|
var_pos = rng.randint(0, num_terms - 1)
|
||||||
coef = rng.randint(self.config.min_value, self.config.max_value)
|
coef = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
terms[var_pos] = coef * x
|
terms[var_pos] = coef * x
|
||||||
|
|
||||||
# Apply operators between terms
|
# Apply operators between terms
|
||||||
expr = terms[0]
|
expr = terms[0]
|
||||||
for i in range(1, num_terms):
|
for i in range(1, num_terms):
|
||||||
op = rng.choice(self.config.operators)
|
op = rng.choice(self.config.operators)
|
||||||
if op == '+':
|
if op == "+":
|
||||||
expr = expr + terms[i]
|
expr = expr + terms[i]
|
||||||
elif op == '-':
|
elif op == "-":
|
||||||
expr = expr - terms[i]
|
expr = expr - terms[i]
|
||||||
else: # '*'
|
else: # '*'
|
||||||
expr = expr * terms[i]
|
expr = expr * terms[i]
|
||||||
|
|
||||||
left_side = expr
|
left_side = expr
|
||||||
|
|
||||||
# Generate right side
|
# Generate right side
|
||||||
right_side = rng.randint(self.config.min_value, self.config.max_value)
|
right_side = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
|
|
||||||
# Create equation
|
# Create equation
|
||||||
equation = Eq(left_side, right_side)
|
equation = Eq(left_side, right_side)
|
||||||
solutions = solve(equation, x)
|
solutions = solve(equation, x)
|
||||||
|
|
||||||
# Check if we found any solutions
|
# Check if we found any solutions
|
||||||
if not solutions:
|
if not solutions:
|
||||||
return self._generate_equation(rng, variable)
|
return self._generate_equation(rng, variable)
|
||||||
|
|
||||||
solution = solutions[0]
|
solution = solutions[0]
|
||||||
|
|
||||||
# Only return if solution is a positive integer
|
# Only return if solution is a positive integer
|
||||||
if not (isinstance(solution, sympy.Integer) and solution > 0):
|
if not (isinstance(solution, sympy.Integer) and solution > 0):
|
||||||
return self._generate_equation(rng, variable)
|
return self._generate_equation(rng, variable)
|
||||||
|
|
||||||
return f"{left_side} = {right_side}", int(solution)
|
return f"{left_side} = {right_side}", int(solution)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -134,7 +135,7 @@ def simple_equations_dataset(
|
||||||
max_terms: int = 5,
|
max_terms: int = 5,
|
||||||
min_value: int = 1,
|
min_value: int = 1,
|
||||||
max_value: int = 100,
|
max_value: int = 100,
|
||||||
operators: tuple = ('+', '-', '*'),
|
operators: tuple = ("+", "-", "*"),
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
size: int = 500,
|
size: int = 500,
|
||||||
) -> SimpleEquationsDataset:
|
) -> SimpleEquationsDataset:
|
||||||
|
|
|
||||||
|
|
@ -1,50 +1,44 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .simple_equations import simple_equations_dataset
|
from .simple_equations import simple_equations_dataset
|
||||||
|
|
||||||
|
|
||||||
def test_simple_equations_generation():
|
def test_simple_equations_generation():
|
||||||
dataset = simple_equations_dataset(seed=42, size=10)
|
dataset = simple_equations_dataset(seed=42, size=10)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
# Check required keys exist
|
# Check required keys exist
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Validate answer is a string of digits
|
# Validate answer is a string of digits
|
||||||
assert item["answer"].isdigit()
|
assert item["answer"].isdigit()
|
||||||
|
|
||||||
# Validate equation format
|
# Validate equation format
|
||||||
equation = item["metadata"]["equation"]
|
equation = item["metadata"]["equation"]
|
||||||
variable = item["metadata"]["variable"]
|
variable = item["metadata"]["variable"]
|
||||||
assert "=" in equation
|
assert "=" in equation
|
||||||
assert variable in equation
|
assert variable in equation
|
||||||
|
|
||||||
# Validate question format
|
# Validate question format
|
||||||
question = item["question"]
|
question = item["question"]
|
||||||
assert variable in question
|
assert variable in question
|
||||||
assert equation in question
|
assert equation in question
|
||||||
assert any(
|
assert any(prompt in question for prompt in ["Find the value of", "Solve for", "Determine the value of"])
|
||||||
prompt in question
|
|
||||||
for prompt in [
|
|
||||||
"Find the value of",
|
|
||||||
"Solve for",
|
|
||||||
"Determine the value of"
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_simple_equations_config():
|
def test_simple_equations_config():
|
||||||
# Test invalid config raises assertion
|
# Test invalid config raises assertion
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = simple_equations_dataset(min_terms=0)
|
dataset = simple_equations_dataset(min_terms=0)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = simple_equations_dataset(max_terms=1, min_terms=2)
|
dataset = simple_equations_dataset(max_terms=1, min_terms=2)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = simple_equations_dataset(min_value=0)
|
dataset = simple_equations_dataset(min_value=0)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = simple_equations_dataset(operators=())
|
dataset = simple_equations_dataset(operators=())
|
||||||
|
|
||||||
|
|
@ -52,7 +46,7 @@ def test_simple_equations_config():
|
||||||
def test_deterministic_generation():
|
def test_deterministic_generation():
|
||||||
dataset1 = simple_equations_dataset(seed=42, size=5)
|
dataset1 = simple_equations_dataset(seed=42, size=5)
|
||||||
dataset2 = simple_equations_dataset(seed=42, size=5)
|
dataset2 = simple_equations_dataset(seed=42, size=5)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
assert dataset1[i]["question"] == dataset2[i]["question"]
|
assert dataset1[i]["question"] == dataset2[i]["question"]
|
||||||
assert dataset1[i]["answer"] == dataset2[i]["answer"]
|
assert dataset1[i]["answer"] == dataset2[i]["answer"]
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,10 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
|
||||||
|
|
||||||
from reasoning_gym.data import read_data_file
|
from reasoning_gym.data import read_data_file
|
||||||
|
|
||||||
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LetterCountingConfig:
|
class LetterCountingConfig:
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ class GCDDataset(ProceduralDataset):
|
||||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||||
"""Generate a list of random positive integers and their GCD.
|
"""Generate a list of random positive integers and their GCD.
|
||||||
Will try up to 3 times to find numbers with GCD > 1."""
|
Will try up to 3 times to find numbers with GCD > 1."""
|
||||||
|
|
||||||
# Try up to 3 times to get GCD > 1
|
# Try up to 3 times to get GCD > 1
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
|
|
@ -47,7 +47,7 @@ class GCDDataset(ProceduralDataset):
|
||||||
result = reduce(gcd, numbers)
|
result = reduce(gcd, numbers)
|
||||||
if result > 1:
|
if result > 1:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Return the last generated numbers, whether they met the criteria or not
|
# Return the last generated numbers, whether they met the criteria or not
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ class LCMDataset(ProceduralDataset):
|
||||||
result = reduce(lcm, numbers)
|
result = reduce(lcm, numbers)
|
||||||
if result < calculate_product(numbers):
|
if result < calculate_product(numbers):
|
||||||
break
|
break
|
||||||
|
|
||||||
# Return the last generated numbers, whether they met the criteria or not
|
# Return the last generated numbers, whether they met the criteria or not
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Dict, List, Set, Tuple
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from itertools import count
|
from itertools import count
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
@ -30,9 +30,9 @@ class Person:
|
||||||
name: str
|
name: str
|
||||||
gender: Gender
|
gender: Gender
|
||||||
id: int
|
id: int
|
||||||
spouse: Optional['Person'] = None
|
spouse: Optional["Person"] = None
|
||||||
parents: List['Person'] = None
|
parents: List["Person"] = None
|
||||||
children: List['Person'] = None
|
children: List["Person"] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.parents = self.parents or []
|
self.parents = self.parents or []
|
||||||
|
|
@ -46,13 +46,13 @@ class Person:
|
||||||
return False
|
return False
|
||||||
return self.id == other.id
|
return self.id == other.id
|
||||||
|
|
||||||
def add_child(self, child: 'Person'):
|
def add_child(self, child: "Person"):
|
||||||
if child not in self.children:
|
if child not in self.children:
|
||||||
self.children.append(child)
|
self.children.append(child)
|
||||||
if self not in child.parents:
|
if self not in child.parents:
|
||||||
child.parents.append(self)
|
child.parents.append(self)
|
||||||
|
|
||||||
def add_spouse(self, spouse: 'Person'):
|
def add_spouse(self, spouse: "Person"):
|
||||||
self.spouse = spouse
|
self.spouse = spouse
|
||||||
spouse.spouse = self
|
spouse.spouse = self
|
||||||
|
|
||||||
|
|
@ -60,6 +60,7 @@ class Person:
|
||||||
@dataclass
|
@dataclass
|
||||||
class FamilyRelationshipsConfig:
|
class FamilyRelationshipsConfig:
|
||||||
"""Configuration for family relationship task generation"""
|
"""Configuration for family relationship task generation"""
|
||||||
|
|
||||||
min_family_size: int = 4
|
min_family_size: int = 4
|
||||||
max_family_size: int = 8
|
max_family_size: int = 8
|
||||||
male_names: List[str] = None
|
male_names: List[str] = None
|
||||||
|
|
@ -70,22 +71,96 @@ class FamilyRelationshipsConfig:
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Default name lists if none provided
|
# Default name lists if none provided
|
||||||
default_male_names = [
|
default_male_names = [
|
||||||
"James", "John", "Robert", "Michael", "William", "David", "Richard",
|
"James",
|
||||||
"Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew",
|
"John",
|
||||||
"Christopher", "Andrew", "George", "Edward", "Benjamin", "Henry",
|
"Robert",
|
||||||
"Samuel", "Alexander", "Oliver", "Jack", "Harry", "Jacob",
|
"Michael",
|
||||||
"Noah", "Ethan", "Lucas", "Mason", "Logan", "Sebastian", "Theodore", "Owen",
|
"William",
|
||||||
"Liam", "Aiden", "Kai", "Jayden", "Zion", "Phoenix", "Atlas", "Axel", "Ryder", "Finn"
|
"David",
|
||||||
|
"Richard",
|
||||||
|
"Joseph",
|
||||||
|
"Thomas",
|
||||||
|
"Charles",
|
||||||
|
"Peter",
|
||||||
|
"Daniel",
|
||||||
|
"Matthew",
|
||||||
|
"Christopher",
|
||||||
|
"Andrew",
|
||||||
|
"George",
|
||||||
|
"Edward",
|
||||||
|
"Benjamin",
|
||||||
|
"Henry",
|
||||||
|
"Samuel",
|
||||||
|
"Alexander",
|
||||||
|
"Oliver",
|
||||||
|
"Jack",
|
||||||
|
"Harry",
|
||||||
|
"Jacob",
|
||||||
|
"Noah",
|
||||||
|
"Ethan",
|
||||||
|
"Lucas",
|
||||||
|
"Mason",
|
||||||
|
"Logan",
|
||||||
|
"Sebastian",
|
||||||
|
"Theodore",
|
||||||
|
"Owen",
|
||||||
|
"Liam",
|
||||||
|
"Aiden",
|
||||||
|
"Kai",
|
||||||
|
"Jayden",
|
||||||
|
"Zion",
|
||||||
|
"Phoenix",
|
||||||
|
"Atlas",
|
||||||
|
"Axel",
|
||||||
|
"Ryder",
|
||||||
|
"Finn",
|
||||||
]
|
]
|
||||||
default_female_names = [
|
default_female_names = [
|
||||||
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
|
"Mary",
|
||||||
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna",
|
"Patricia",
|
||||||
"Margaret", "Victoria", "Charlotte", "Sophia", "Isabella", "Olivia",
|
"Jennifer",
|
||||||
"Ava", "Mia", "Emily", "Abigail", "Amelia", "Eleanor", "Grace",
|
"Linda",
|
||||||
"Alice", "Lucy", "Chloe", "Sophie", "Lily", "Hannah", "Zoe",
|
"Elizabeth",
|
||||||
"Luna", "Nova", "Aria", "Willow", "Aurora", "Sage", "River", "Winter", "Sky", "Rain"
|
"Barbara",
|
||||||
|
"Susan",
|
||||||
|
"Jessica",
|
||||||
|
"Sarah",
|
||||||
|
"Karen",
|
||||||
|
"Emma",
|
||||||
|
"Lisa",
|
||||||
|
"Anna",
|
||||||
|
"Margaret",
|
||||||
|
"Victoria",
|
||||||
|
"Charlotte",
|
||||||
|
"Sophia",
|
||||||
|
"Isabella",
|
||||||
|
"Olivia",
|
||||||
|
"Ava",
|
||||||
|
"Mia",
|
||||||
|
"Emily",
|
||||||
|
"Abigail",
|
||||||
|
"Amelia",
|
||||||
|
"Eleanor",
|
||||||
|
"Grace",
|
||||||
|
"Alice",
|
||||||
|
"Lucy",
|
||||||
|
"Chloe",
|
||||||
|
"Sophie",
|
||||||
|
"Lily",
|
||||||
|
"Hannah",
|
||||||
|
"Zoe",
|
||||||
|
"Luna",
|
||||||
|
"Nova",
|
||||||
|
"Aria",
|
||||||
|
"Willow",
|
||||||
|
"Aurora",
|
||||||
|
"Sage",
|
||||||
|
"River",
|
||||||
|
"Winter",
|
||||||
|
"Sky",
|
||||||
|
"Rain",
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.male_names is None:
|
if self.male_names is None:
|
||||||
self.male_names = default_male_names
|
self.male_names = default_male_names
|
||||||
if self.female_names is None:
|
if self.female_names is None:
|
||||||
|
|
@ -114,22 +189,19 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
# Generate family tree
|
# Generate family tree
|
||||||
family = self._generate_family(rng)
|
family = self._generate_family(rng)
|
||||||
|
|
||||||
# Select two people and their relationship
|
# Select two people and their relationship
|
||||||
person1, person2, relationship = self._get_relationship_question(rng, family)
|
person1, person2, relationship = self._get_relationship_question(rng, family)
|
||||||
|
|
||||||
# Generate story describing the family relationships
|
# Generate story describing the family relationships
|
||||||
story = self._generate_story(family)
|
story = self._generate_story(family)
|
||||||
|
|
||||||
# Format question
|
# Format question
|
||||||
question = rng.choice(self._templates).format(
|
question = rng.choice(self._templates).format(person1=person1.name, person2=person2.name)
|
||||||
person1=person1.name,
|
|
||||||
person2=person2.name
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"{story}\n\n{question}",
|
"question": f"{story}\n\n{question}",
|
||||||
"answer": relationship.value,
|
"answer": relationship.value,
|
||||||
|
|
@ -137,8 +209,8 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
"person1": person1.name,
|
"person1": person1.name,
|
||||||
"person2": person2.name,
|
"person2": person2.name,
|
||||||
"relationship": relationship.value,
|
"relationship": relationship.value,
|
||||||
"family_size": len(family)
|
"family_size": len(family),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _generate_family(self, rng: random.Random) -> Set[Person]:
|
def _generate_family(self, rng: random.Random) -> Set[Person]:
|
||||||
|
|
@ -148,8 +220,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
used_names = set()
|
used_names = set()
|
||||||
|
|
||||||
def get_name(gender: Gender) -> str:
|
def get_name(gender: Gender) -> str:
|
||||||
names = (self.config.male_names if gender == Gender.MALE
|
names = self.config.male_names if gender == Gender.MALE else self.config.female_names
|
||||||
else self.config.female_names)
|
|
||||||
available = [n for n in names if n not in used_names]
|
available = [n for n in names if n not in used_names]
|
||||||
if not available:
|
if not available:
|
||||||
return None
|
return None
|
||||||
|
|
@ -159,7 +230,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
|
|
||||||
# Create ID counter
|
# Create ID counter
|
||||||
id_counter = count()
|
id_counter = count()
|
||||||
|
|
||||||
# Create grandparents generation
|
# Create grandparents generation
|
||||||
grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter))
|
grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter))
|
||||||
grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter))
|
grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter))
|
||||||
|
|
@ -192,62 +263,52 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
) -> Tuple[Person, Person, Relationship]:
|
) -> Tuple[Person, Person, Relationship]:
|
||||||
"""Select two family members and determine their relationship"""
|
"""Select two family members and determine their relationship"""
|
||||||
person1, person2 = rng.sample(list(family), 2)
|
person1, person2 = rng.sample(list(family), 2)
|
||||||
|
|
||||||
# Determine relationship
|
# Determine relationship
|
||||||
if person1 in person2.parents:
|
if person1 in person2.parents:
|
||||||
relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE
|
relationship = Relationship.MOTHER if person1.gender == Gender.FEMALE else Relationship.FATHER
|
||||||
else Relationship.FATHER)
|
|
||||||
elif person2 in person1.parents:
|
elif person2 in person1.parents:
|
||||||
relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE
|
relationship = Relationship.DAUGHTER if person1.gender == Gender.FEMALE else Relationship.SON
|
||||||
else Relationship.SON)
|
|
||||||
elif person1.spouse == person2:
|
elif person1.spouse == person2:
|
||||||
relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE
|
relationship = Relationship.WIFE if person1.gender == Gender.FEMALE else Relationship.HUSBAND
|
||||||
else Relationship.HUSBAND)
|
elif person1.parents and person2.parents and set(person1.parents) == set(person2.parents):
|
||||||
elif (person1.parents and person2.parents and
|
relationship = Relationship.SISTER if person1.gender == Gender.FEMALE else Relationship.BROTHER
|
||||||
set(person1.parents) == set(person2.parents)):
|
elif person1 in [p for parent in person2.parents for p in parent.parents]:
|
||||||
relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE
|
relationship = Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE else Relationship.GRANDFATHER
|
||||||
else Relationship.BROTHER)
|
|
||||||
elif (person1 in [p for parent in person2.parents for p in parent.parents]):
|
|
||||||
relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE
|
|
||||||
else Relationship.GRANDFATHER)
|
|
||||||
else:
|
else:
|
||||||
# Try again with different people
|
# Try again with different people
|
||||||
return self._get_relationship_question(rng, family)
|
return self._get_relationship_question(rng, family)
|
||||||
|
|
||||||
return person1, person2, relationship
|
return person1, person2, relationship
|
||||||
|
|
||||||
def _generate_story(self, family: Set[Person]) -> str:
|
def _generate_story(self, family: Set[Person]) -> str:
|
||||||
"""Generate a story describing the family relationships"""
|
"""Generate a story describing the family relationships"""
|
||||||
story_parts = []
|
story_parts = []
|
||||||
|
|
||||||
# Find married couples
|
# Find married couples
|
||||||
couples = set()
|
couples = set()
|
||||||
for person in family:
|
for person in family:
|
||||||
if person.spouse and (person.spouse, person) not in couples:
|
if person.spouse and (person.spouse, person) not in couples:
|
||||||
couples.add((person, person.spouse))
|
couples.add((person, person.spouse))
|
||||||
|
|
||||||
# Describe marriages and children for each couple
|
# Describe marriages and children for each couple
|
||||||
described_children = set() # Track which children have been described
|
described_children = set() # Track which children have been described
|
||||||
for person1, person2 in couples:
|
for person1, person2 in couples:
|
||||||
story_parts.append(f"{person1.name} is married to {person2.name}.")
|
story_parts.append(f"{person1.name} is married to {person2.name}.")
|
||||||
|
|
||||||
# Only describe children once per couple
|
# Only describe children once per couple
|
||||||
children = [c for c in person1.children if c not in described_children]
|
children = [c for c in person1.children if c not in described_children]
|
||||||
if children:
|
if children:
|
||||||
children_names = [c.name for c in children]
|
children_names = [c.name for c in children]
|
||||||
described_children.update(children) # Mark these children as described
|
described_children.update(children) # Mark these children as described
|
||||||
|
|
||||||
if len(children_names) == 1:
|
if len(children_names) == 1:
|
||||||
story_parts.append(
|
story_parts.append(f"They have a child called {children_names[0]}.")
|
||||||
f"They have a child called {children_names[0]}."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
*first, last = children_names
|
*first, last = children_names
|
||||||
children_str = ", ".join(first) + f" and {last}"
|
children_str = ", ".join(first) + f" and {last}"
|
||||||
story_parts.append(
|
story_parts.append(f"They have children called {children_str}.")
|
||||||
f"They have children called {children_str}."
|
|
||||||
)
|
|
||||||
|
|
||||||
return " ".join(story_parts)
|
return " ".join(story_parts)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,242 +0,0 @@
|
||||||
import random
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional, Dict, List, Set, Tuple
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
|
||||||
|
|
||||||
|
|
||||||
class Gender(Enum):
|
|
||||||
MALE = "male"
|
|
||||||
FEMALE = "female"
|
|
||||||
|
|
||||||
|
|
||||||
class Relationship(Enum):
|
|
||||||
MOTHER = "Mother"
|
|
||||||
FATHER = "Father"
|
|
||||||
SISTER = "Sister"
|
|
||||||
BROTHER = "Brother"
|
|
||||||
DAUGHTER = "Daughter"
|
|
||||||
SON = "Son"
|
|
||||||
WIFE = "Wife"
|
|
||||||
HUSBAND = "Husband"
|
|
||||||
GRANDMOTHER = "Grandmother"
|
|
||||||
GRANDFATHER = "Grandfather"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Person:
|
|
||||||
name: str
|
|
||||||
gender: Gender
|
|
||||||
spouse: Optional['Person'] = None
|
|
||||||
parents: List['Person'] = None
|
|
||||||
children: List['Person'] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self.parents = self.parents or []
|
|
||||||
self.children = self.children or []
|
|
||||||
|
|
||||||
def add_child(self, child: 'Person'):
|
|
||||||
if child not in self.children:
|
|
||||||
self.children.append(child)
|
|
||||||
if self not in child.parents:
|
|
||||||
child.parents.append(self)
|
|
||||||
|
|
||||||
def add_spouse(self, spouse: 'Person'):
|
|
||||||
self.spouse = spouse
|
|
||||||
spouse.spouse = self
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FamilyRelationshipsConfig:
|
|
||||||
"""Configuration for family relationship task generation"""
|
|
||||||
min_family_size: int = 4
|
|
||||||
max_family_size: int = 8
|
|
||||||
male_names: List[str] = None
|
|
||||||
female_names: List[str] = None
|
|
||||||
seed: Optional[int] = None
|
|
||||||
size: int = 500
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# Default name lists if none provided
|
|
||||||
self.male_names = self.male_names or [
|
|
||||||
"James", "John", "Robert", "Michael", "William", "David", "Richard",
|
|
||||||
"Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew"
|
|
||||||
]
|
|
||||||
self.female_names = self.female_names or [
|
|
||||||
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
|
|
||||||
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna"
|
|
||||||
]
|
|
||||||
|
|
||||||
def validate(self):
|
|
||||||
"""Validate configuration parameters"""
|
|
||||||
assert self.min_family_size >= 3, "min_family_size must be at least 3"
|
|
||||||
assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size"
|
|
||||||
assert len(self.male_names) > 0, "must provide male names"
|
|
||||||
assert len(self.female_names) > 0, "must provide female names"
|
|
||||||
|
|
||||||
|
|
||||||
class FamilyRelationshipsDataset(ProceduralDataset):
|
|
||||||
"""Generates family relationship reasoning tasks"""
|
|
||||||
|
|
||||||
def __init__(self, config: FamilyRelationshipsConfig):
|
|
||||||
self.config = config
|
|
||||||
self.config.validate()
|
|
||||||
self._templates = [
|
|
||||||
"What is {person1} to {person2}?",
|
|
||||||
"How is {person1} related to {person2}?",
|
|
||||||
"What relation is {person1} to {person2}?",
|
|
||||||
]
|
|
||||||
super().__init__(seed=config.seed, size=config.size)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
|
||||||
rng = random.Random(self.seed + idx)
|
|
||||||
|
|
||||||
# Generate family tree
|
|
||||||
family = self._generate_family(rng)
|
|
||||||
|
|
||||||
# Select two people and their relationship
|
|
||||||
person1, person2, relationship = self._get_relationship_question(rng, family)
|
|
||||||
|
|
||||||
# Generate story describing the family relationships
|
|
||||||
story = self._generate_story(family)
|
|
||||||
|
|
||||||
# Format question
|
|
||||||
question = rng.choice(self._templates).format(
|
|
||||||
person1=person1.name,
|
|
||||||
person2=person2.name
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"question": f"{story}\n\n{question}",
|
|
||||||
"answer": relationship.value,
|
|
||||||
"metadata": {
|
|
||||||
"person1": person1.name,
|
|
||||||
"person2": person2.name,
|
|
||||||
"relationship": relationship.value,
|
|
||||||
"family_size": len(family)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def _generate_family(self, rng: random.Random) -> Set[Person]:
|
|
||||||
"""Generate a random family tree"""
|
|
||||||
family_size = rng.randint(self.config.min_family_size, self.config.max_family_size)
|
|
||||||
family = set()
|
|
||||||
used_names = set()
|
|
||||||
|
|
||||||
def get_name(gender: Gender) -> str:
|
|
||||||
names = (self.config.male_names if gender == Gender.MALE
|
|
||||||
else self.config.female_names)
|
|
||||||
available = [n for n in names if n not in used_names]
|
|
||||||
if not available:
|
|
||||||
return None
|
|
||||||
name = rng.choice(available)
|
|
||||||
used_names.add(name)
|
|
||||||
return name
|
|
||||||
|
|
||||||
# Create grandparents generation
|
|
||||||
grandfather = Person(get_name(Gender.MALE), Gender.MALE)
|
|
||||||
grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE)
|
|
||||||
grandfather.add_spouse(grandmother)
|
|
||||||
family.update([grandfather, grandmother])
|
|
||||||
|
|
||||||
# Create parents
|
|
||||||
father = Person(get_name(Gender.MALE), Gender.MALE)
|
|
||||||
mother = Person(get_name(Gender.FEMALE), Gender.FEMALE)
|
|
||||||
father.add_spouse(mother)
|
|
||||||
grandfather.add_child(father)
|
|
||||||
grandmother.add_child(father)
|
|
||||||
family.update([father, mother])
|
|
||||||
|
|
||||||
# Add children
|
|
||||||
while len(family) < family_size:
|
|
||||||
gender = rng.choice([Gender.MALE, Gender.FEMALE])
|
|
||||||
name = get_name(gender)
|
|
||||||
if not name:
|
|
||||||
break
|
|
||||||
child = Person(name, gender)
|
|
||||||
father.add_child(child)
|
|
||||||
mother.add_child(child)
|
|
||||||
family.add(child)
|
|
||||||
|
|
||||||
return family
|
|
||||||
|
|
||||||
def _get_relationship_question(
|
|
||||||
self, rng: random.Random, family: Set[Person]
|
|
||||||
) -> Tuple[Person, Person, Relationship]:
|
|
||||||
"""Select two family members and determine their relationship"""
|
|
||||||
person1, person2 = rng.sample(list(family), 2)
|
|
||||||
|
|
||||||
# Determine relationship
|
|
||||||
if person1 in person2.parents:
|
|
||||||
relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE
|
|
||||||
else Relationship.FATHER)
|
|
||||||
elif person2 in person1.parents:
|
|
||||||
relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE
|
|
||||||
else Relationship.SON)
|
|
||||||
elif person1.spouse == person2:
|
|
||||||
relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE
|
|
||||||
else Relationship.HUSBAND)
|
|
||||||
elif (person1.parents and person2.parents and
|
|
||||||
set(person1.parents) == set(person2.parents)):
|
|
||||||
relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE
|
|
||||||
else Relationship.BROTHER)
|
|
||||||
elif (person1 in [p for parent in person2.parents for p in parent.parents]):
|
|
||||||
relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE
|
|
||||||
else Relationship.GRANDFATHER)
|
|
||||||
else:
|
|
||||||
# Try again with different people
|
|
||||||
return self._get_relationship_question(rng, family)
|
|
||||||
|
|
||||||
return person1, person2, relationship
|
|
||||||
|
|
||||||
def _generate_story(self, family: Set[Person]) -> str:
|
|
||||||
"""Generate a story describing the family relationships"""
|
|
||||||
story_parts = []
|
|
||||||
|
|
||||||
# Find married couples
|
|
||||||
couples = set()
|
|
||||||
for person in family:
|
|
||||||
if person.spouse and (person.spouse, person) not in couples:
|
|
||||||
couples.add((person, person.spouse))
|
|
||||||
|
|
||||||
# Describe marriages
|
|
||||||
for person1, person2 in couples:
|
|
||||||
story_parts.append(f"{person1.name} is married to {person2.name}.")
|
|
||||||
|
|
||||||
# Describe parent-child relationships
|
|
||||||
for person in family:
|
|
||||||
if person.children:
|
|
||||||
children_names = [c.name for c in person.children]
|
|
||||||
if len(children_names) == 1:
|
|
||||||
story_parts.append(
|
|
||||||
f"They have a child called {children_names[0]}."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
*first, last = children_names
|
|
||||||
children_str = ", ".join(first) + f" and {last}"
|
|
||||||
story_parts.append(
|
|
||||||
f"They have children called {children_str}."
|
|
||||||
)
|
|
||||||
|
|
||||||
return " ".join(story_parts)
|
|
||||||
|
|
||||||
|
|
||||||
def family_relationships_dataset(
|
|
||||||
min_family_size: int = 4,
|
|
||||||
max_family_size: int = 8,
|
|
||||||
male_names: List[str] = None,
|
|
||||||
female_names: List[str] = None,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> FamilyRelationshipsDataset:
|
|
||||||
"""Create a FamilyRelationshipsDataset with the given configuration"""
|
|
||||||
config = FamilyRelationshipsConfig(
|
|
||||||
min_family_size=min_family_size,
|
|
||||||
max_family_size=max_family_size,
|
|
||||||
male_names=male_names,
|
|
||||||
female_names=female_names,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return FamilyRelationshipsDataset(config)
|
|
||||||
|
|
@ -1,35 +1,26 @@
|
||||||
from reasoning_gym.graphs.family_relationships import (
|
import pytest
|
||||||
family_relationships_dataset,
|
|
||||||
Gender,
|
from reasoning_gym.graphs.family_relationships import Gender, Relationship, family_relationships_dataset
|
||||||
Relationship,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_family_relationships_generation():
|
def test_family_relationships_generation():
|
||||||
dataset = family_relationships_dataset(seed=42, size=10)
|
dataset = family_relationships_dataset(seed=42, size=10)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
# Check required keys exist
|
# Check required keys exist
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Validate story and question format
|
# Validate story and question format
|
||||||
story_and_question = item["question"]
|
story_and_question = item["question"]
|
||||||
assert "is married to" in story_and_question
|
assert "is married to" in story_and_question
|
||||||
assert "have" in story_and_question
|
assert "have" in story_and_question
|
||||||
assert any(
|
assert any(prompt in story_and_question for prompt in ["What is", "How is", "What relation is"])
|
||||||
prompt in story_and_question
|
|
||||||
for prompt in [
|
|
||||||
"What is",
|
|
||||||
"How is",
|
|
||||||
"What relation is"
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate answer is a valid relationship
|
# Validate answer is a valid relationship
|
||||||
assert item["answer"] in [r.value for r in Relationship]
|
assert item["answer"] in [r.value for r in Relationship]
|
||||||
|
|
||||||
# Validate metadata
|
# Validate metadata
|
||||||
assert "person1" in item["metadata"]
|
assert "person1" in item["metadata"]
|
||||||
assert "person2" in item["metadata"]
|
assert "person2" in item["metadata"]
|
||||||
|
|
@ -42,13 +33,13 @@ def test_family_relationships_config():
|
||||||
# Test invalid config raises assertion
|
# Test invalid config raises assertion
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = family_relationships_dataset(min_family_size=2)
|
dataset = family_relationships_dataset(min_family_size=2)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = family_relationships_dataset(max_family_size=3, min_family_size=4)
|
dataset = family_relationships_dataset(max_family_size=3, min_family_size=4)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = family_relationships_dataset(male_names=[])
|
dataset = family_relationships_dataset(male_names=[])
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = family_relationships_dataset(female_names=[])
|
dataset = family_relationships_dataset(female_names=[])
|
||||||
|
|
||||||
|
|
@ -56,7 +47,7 @@ def test_family_relationships_config():
|
||||||
def test_deterministic_generation():
|
def test_deterministic_generation():
|
||||||
dataset1 = family_relationships_dataset(seed=42, size=5)
|
dataset1 = family_relationships_dataset(seed=42, size=5)
|
||||||
dataset2 = family_relationships_dataset(seed=42, size=5)
|
dataset2 = family_relationships_dataset(seed=42, size=5)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
assert dataset1[i]["question"] == dataset2[i]["question"]
|
assert dataset1[i]["question"] == dataset2[i]["question"]
|
||||||
assert dataset1[i]["answer"] == dataset2[i]["answer"]
|
assert dataset1[i]["answer"] == dataset2[i]["answer"]
|
||||||
|
|
@ -64,15 +55,23 @@ def test_deterministic_generation():
|
||||||
|
|
||||||
def test_relationship_consistency():
|
def test_relationship_consistency():
|
||||||
dataset = family_relationships_dataset(seed=42, size=10)
|
dataset = family_relationships_dataset(seed=42, size=10)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
# Check that relationship matches the gender
|
# Check that relationship matches the gender
|
||||||
relationship = item["metadata"]["relationship"]
|
relationship = item["metadata"]["relationship"]
|
||||||
if relationship in [Relationship.MOTHER.value, Relationship.GRANDMOTHER.value,
|
if relationship in [
|
||||||
Relationship.WIFE.value, Relationship.SISTER.value,
|
Relationship.MOTHER.value,
|
||||||
Relationship.DAUGHTER.value]:
|
Relationship.GRANDMOTHER.value,
|
||||||
|
Relationship.WIFE.value,
|
||||||
|
Relationship.SISTER.value,
|
||||||
|
Relationship.DAUGHTER.value,
|
||||||
|
]:
|
||||||
assert "married to" in item["question"] or "child" in item["question"]
|
assert "married to" in item["question"] or "child" in item["question"]
|
||||||
elif relationship in [Relationship.FATHER.value, Relationship.GRANDFATHER.value,
|
elif relationship in [
|
||||||
Relationship.HUSBAND.value, Relationship.BROTHER.value,
|
Relationship.FATHER.value,
|
||||||
Relationship.SON.value]:
|
Relationship.GRANDFATHER.value,
|
||||||
|
Relationship.HUSBAND.value,
|
||||||
|
Relationship.BROTHER.value,
|
||||||
|
Relationship.SON.value,
|
||||||
|
]:
|
||||||
assert "married to" in item["question"] or "child" in item["question"]
|
assert "married to" in item["question"] or "child" in item["question"]
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.cognition.number_sequences import Operation, PatternRule, NumberSequenceConfig, NumberSequenceDataset
|
from reasoning_gym.cognition.number_sequences import NumberSequenceConfig, NumberSequenceDataset, Operation, PatternRule
|
||||||
|
|
||||||
|
|
||||||
def test_sequence_config_validation():
|
def test_sequence_config_validation():
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue