mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
formatting, cleanup
This commit is contained in:
parent
b767e58e48
commit
3dc80be7d2
12 changed files with 189 additions and 376 deletions
11
reasoning_gym/graphs/__init__.py
Normal file
11
reasoning_gym/graphs/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from .family_relationships import (
|
||||
FamilyRelationshipsDataset,
|
||||
FamilyRelationshipsConfig,
|
||||
family_relationships_dataset
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FamilyRelationshipsDataset",
|
||||
"FamilyRelationshipsConfig",
|
||||
"family_relationships_dataset"
|
||||
]
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, List, Set, Tuple
|
||||
from enum import Enum
|
||||
from itertools import count
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
|
||||
|
|
@ -30,9 +30,9 @@ class Person:
|
|||
name: str
|
||||
gender: Gender
|
||||
id: int
|
||||
spouse: Optional['Person'] = None
|
||||
parents: List['Person'] = None
|
||||
children: List['Person'] = None
|
||||
spouse: Optional["Person"] = None
|
||||
parents: List["Person"] = None
|
||||
children: List["Person"] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.parents = self.parents or []
|
||||
|
|
@ -46,13 +46,13 @@ class Person:
|
|||
return False
|
||||
return self.id == other.id
|
||||
|
||||
def add_child(self, child: "Person"):
    """Link ``child`` to this person as parent and child.

    Both directions are maintained: ``child`` is appended to
    ``self.children`` and ``self`` is appended to ``child.parents``.
    Each append is skipped when the entry is already present, so calling
    this repeatedly with the same child does not create duplicates.
    """
    offspring = self.children
    if child not in offspring:
        offspring.append(child)

    ancestors = child.parents
    if self not in ancestors:
        ancestors.append(self)
|
||||
|
||||
def add_spouse(self, spouse: "Person"):
    """Marry this person to ``spouse`` by pointing each one's ``spouse`` field at the other.

    Any value previously stored in either ``spouse`` field is overwritten.
    """
    # Targets are assigned left to right, so self.spouse is set first.
    self.spouse, spouse.spouse = spouse, self
|
||||
|
||||
|
|
@ -60,6 +60,7 @@ class Person:
|
|||
@dataclass
|
||||
class FamilyRelationshipsConfig:
|
||||
"""Configuration for family relationship task generation"""
|
||||
|
||||
min_family_size: int = 4
|
||||
max_family_size: int = 8
|
||||
male_names: List[str] = None
|
||||
|
|
@ -70,22 +71,96 @@ class FamilyRelationshipsConfig:
|
|||
def __post_init__(self):
|
||||
# Default name lists if none provided
|
||||
default_male_names = [
|
||||
"James", "John", "Robert", "Michael", "William", "David", "Richard",
|
||||
"Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew",
|
||||
"Christopher", "Andrew", "George", "Edward", "Benjamin", "Henry",
|
||||
"Samuel", "Alexander", "Oliver", "Jack", "Harry", "Jacob",
|
||||
"Noah", "Ethan", "Lucas", "Mason", "Logan", "Sebastian", "Theodore", "Owen",
|
||||
"Liam", "Aiden", "Kai", "Jayden", "Zion", "Phoenix", "Atlas", "Axel", "Ryder", "Finn"
|
||||
"James",
|
||||
"John",
|
||||
"Robert",
|
||||
"Michael",
|
||||
"William",
|
||||
"David",
|
||||
"Richard",
|
||||
"Joseph",
|
||||
"Thomas",
|
||||
"Charles",
|
||||
"Peter",
|
||||
"Daniel",
|
||||
"Matthew",
|
||||
"Christopher",
|
||||
"Andrew",
|
||||
"George",
|
||||
"Edward",
|
||||
"Benjamin",
|
||||
"Henry",
|
||||
"Samuel",
|
||||
"Alexander",
|
||||
"Oliver",
|
||||
"Jack",
|
||||
"Harry",
|
||||
"Jacob",
|
||||
"Noah",
|
||||
"Ethan",
|
||||
"Lucas",
|
||||
"Mason",
|
||||
"Logan",
|
||||
"Sebastian",
|
||||
"Theodore",
|
||||
"Owen",
|
||||
"Liam",
|
||||
"Aiden",
|
||||
"Kai",
|
||||
"Jayden",
|
||||
"Zion",
|
||||
"Phoenix",
|
||||
"Atlas",
|
||||
"Axel",
|
||||
"Ryder",
|
||||
"Finn",
|
||||
]
|
||||
default_female_names = [
|
||||
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
|
||||
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna",
|
||||
"Margaret", "Victoria", "Charlotte", "Sophia", "Isabella", "Olivia",
|
||||
"Ava", "Mia", "Emily", "Abigail", "Amelia", "Eleanor", "Grace",
|
||||
"Alice", "Lucy", "Chloe", "Sophie", "Lily", "Hannah", "Zoe",
|
||||
"Luna", "Nova", "Aria", "Willow", "Aurora", "Sage", "River", "Winter", "Sky", "Rain"
|
||||
"Mary",
|
||||
"Patricia",
|
||||
"Jennifer",
|
||||
"Linda",
|
||||
"Elizabeth",
|
||||
"Barbara",
|
||||
"Susan",
|
||||
"Jessica",
|
||||
"Sarah",
|
||||
"Karen",
|
||||
"Emma",
|
||||
"Lisa",
|
||||
"Anna",
|
||||
"Margaret",
|
||||
"Victoria",
|
||||
"Charlotte",
|
||||
"Sophia",
|
||||
"Isabella",
|
||||
"Olivia",
|
||||
"Ava",
|
||||
"Mia",
|
||||
"Emily",
|
||||
"Abigail",
|
||||
"Amelia",
|
||||
"Eleanor",
|
||||
"Grace",
|
||||
"Alice",
|
||||
"Lucy",
|
||||
"Chloe",
|
||||
"Sophie",
|
||||
"Lily",
|
||||
"Hannah",
|
||||
"Zoe",
|
||||
"Luna",
|
||||
"Nova",
|
||||
"Aria",
|
||||
"Willow",
|
||||
"Aurora",
|
||||
"Sage",
|
||||
"River",
|
||||
"Winter",
|
||||
"Sky",
|
||||
"Rain",
|
||||
]
|
||||
|
||||
|
||||
if self.male_names is None:
|
||||
self.male_names = default_male_names
|
||||
if self.female_names is None:
|
||||
|
|
@ -114,22 +189,19 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
||||
|
||||
# Generate family tree
|
||||
family = self._generate_family(rng)
|
||||
|
||||
|
||||
# Select two people and their relationship
|
||||
person1, person2, relationship = self._get_relationship_question(rng, family)
|
||||
|
||||
|
||||
# Generate story describing the family relationships
|
||||
story = self._generate_story(family)
|
||||
|
||||
|
||||
# Format question
|
||||
question = rng.choice(self._templates).format(
|
||||
person1=person1.name,
|
||||
person2=person2.name
|
||||
)
|
||||
|
||||
question = rng.choice(self._templates).format(person1=person1.name, person2=person2.name)
|
||||
|
||||
return {
|
||||
"question": f"{story}\n\n{question}",
|
||||
"answer": relationship.value,
|
||||
|
|
@ -137,8 +209,8 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
"person1": person1.name,
|
||||
"person2": person2.name,
|
||||
"relationship": relationship.value,
|
||||
"family_size": len(family)
|
||||
}
|
||||
"family_size": len(family),
|
||||
},
|
||||
}
|
||||
|
||||
def _generate_family(self, rng: random.Random) -> Set[Person]:
|
||||
|
|
@ -148,8 +220,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
used_names = set()
|
||||
|
||||
def get_name(gender: Gender) -> str:
|
||||
names = (self.config.male_names if gender == Gender.MALE
|
||||
else self.config.female_names)
|
||||
names = self.config.male_names if gender == Gender.MALE else self.config.female_names
|
||||
available = [n for n in names if n not in used_names]
|
||||
if not available:
|
||||
return None
|
||||
|
|
@ -159,7 +230,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
|
||||
# Create ID counter
|
||||
id_counter = count()
|
||||
|
||||
|
||||
# Create grandparents generation
|
||||
grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter))
|
||||
grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter))
|
||||
|
|
@ -192,62 +263,52 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
) -> Tuple[Person, Person, Relationship]:
|
||||
"""Select two family members and determine their relationship"""
|
||||
person1, person2 = rng.sample(list(family), 2)
|
||||
|
||||
|
||||
# Determine relationship
|
||||
if person1 in person2.parents:
|
||||
relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE
|
||||
else Relationship.FATHER)
|
||||
relationship = Relationship.MOTHER if person1.gender == Gender.FEMALE else Relationship.FATHER
|
||||
elif person2 in person1.parents:
|
||||
relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE
|
||||
else Relationship.SON)
|
||||
relationship = Relationship.DAUGHTER if person1.gender == Gender.FEMALE else Relationship.SON
|
||||
elif person1.spouse == person2:
|
||||
relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE
|
||||
else Relationship.HUSBAND)
|
||||
elif (person1.parents and person2.parents and
|
||||
set(person1.parents) == set(person2.parents)):
|
||||
relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE
|
||||
else Relationship.BROTHER)
|
||||
elif (person1 in [p for parent in person2.parents for p in parent.parents]):
|
||||
relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE
|
||||
else Relationship.GRANDFATHER)
|
||||
relationship = Relationship.WIFE if person1.gender == Gender.FEMALE else Relationship.HUSBAND
|
||||
elif person1.parents and person2.parents and set(person1.parents) == set(person2.parents):
|
||||
relationship = Relationship.SISTER if person1.gender == Gender.FEMALE else Relationship.BROTHER
|
||||
elif person1 in [p for parent in person2.parents for p in parent.parents]:
|
||||
relationship = Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE else Relationship.GRANDFATHER
|
||||
else:
|
||||
# Try again with different people
|
||||
return self._get_relationship_question(rng, family)
|
||||
|
||||
|
||||
return person1, person2, relationship
|
||||
|
||||
def _generate_story(self, family: Set[Person]) -> str:
|
||||
"""Generate a story describing the family relationships"""
|
||||
story_parts = []
|
||||
|
||||
|
||||
# Find married couples
|
||||
couples = set()
|
||||
for person in family:
|
||||
if person.spouse and (person.spouse, person) not in couples:
|
||||
couples.add((person, person.spouse))
|
||||
|
||||
|
||||
# Describe marriages and children for each couple
|
||||
described_children = set() # Track which children have been described
|
||||
for person1, person2 in couples:
|
||||
story_parts.append(f"{person1.name} is married to {person2.name}.")
|
||||
|
||||
|
||||
# Only describe children once per couple
|
||||
children = [c for c in person1.children if c not in described_children]
|
||||
if children:
|
||||
children_names = [c.name for c in children]
|
||||
described_children.update(children) # Mark these children as described
|
||||
|
||||
|
||||
if len(children_names) == 1:
|
||||
story_parts.append(
|
||||
f"They have a child called {children_names[0]}."
|
||||
)
|
||||
story_parts.append(f"They have a child called {children_names[0]}.")
|
||||
else:
|
||||
*first, last = children_names
|
||||
children_str = ", ".join(first) + f" and {last}"
|
||||
story_parts.append(
|
||||
f"They have children called {children_str}."
|
||||
)
|
||||
|
||||
story_parts.append(f"They have children called {children_str}.")
|
||||
|
||||
return " ".join(story_parts)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue