formatting, cleanup

This commit is contained in:
Andreas Koepf 2025-01-24 17:12:42 +01:00
parent b767e58e48
commit 3dc80be7d2
12 changed files with 189 additions and 376 deletions

View file

@ -0,0 +1,11 @@
from .family_relationships import (
FamilyRelationshipsDataset,
FamilyRelationshipsConfig,
family_relationships_dataset
)
# Public names of this package: exactly the three objects imported from
# the .family_relationships submodule above.
__all__ = [
"FamilyRelationshipsDataset",
"FamilyRelationshipsConfig",
"family_relationships_dataset"
]

View file

@ -1,8 +1,8 @@
import random
from dataclasses import dataclass
from typing import Optional, Dict, List, Set, Tuple
from enum import Enum
from itertools import count
from typing import Dict, List, Optional, Set, Tuple
from ..dataset import ProceduralDataset
@ -30,9 +30,9 @@ class Person:
name: str
gender: Gender
id: int
spouse: Optional['Person'] = None
parents: List['Person'] = None
children: List['Person'] = None
spouse: Optional["Person"] = None
parents: List["Person"] = None
children: List["Person"] = None
def __post_init__(self):
self.parents = self.parents or []
@ -46,13 +46,13 @@ class Person:
return False
return self.id == other.id
def add_child(self, child: 'Person'):
def add_child(self, child: "Person"):
if child not in self.children:
self.children.append(child)
if self not in child.parents:
child.parents.append(self)
def add_spouse(self, spouse: 'Person'):
def add_spouse(self, spouse: "Person"):
self.spouse = spouse
spouse.spouse = self
@ -60,6 +60,7 @@ class Person:
@dataclass
class FamilyRelationshipsConfig:
"""Configuration for family relationship task generation"""
min_family_size: int = 4
max_family_size: int = 8
male_names: List[str] = None
@ -70,22 +71,96 @@ class FamilyRelationshipsConfig:
def __post_init__(self):
# Default name lists if none provided
default_male_names = [
"James", "John", "Robert", "Michael", "William", "David", "Richard",
"Joseph", "Thomas", "Charles", "Peter", "Daniel", "Matthew",
"Christopher", "Andrew", "George", "Edward", "Benjamin", "Henry",
"Samuel", "Alexander", "Oliver", "Jack", "Harry", "Jacob",
"Noah", "Ethan", "Lucas", "Mason", "Logan", "Sebastian", "Theodore", "Owen",
"Liam", "Aiden", "Kai", "Jayden", "Zion", "Phoenix", "Atlas", "Axel", "Ryder", "Finn"
"James",
"John",
"Robert",
"Michael",
"William",
"David",
"Richard",
"Joseph",
"Thomas",
"Charles",
"Peter",
"Daniel",
"Matthew",
"Christopher",
"Andrew",
"George",
"Edward",
"Benjamin",
"Henry",
"Samuel",
"Alexander",
"Oliver",
"Jack",
"Harry",
"Jacob",
"Noah",
"Ethan",
"Lucas",
"Mason",
"Logan",
"Sebastian",
"Theodore",
"Owen",
"Liam",
"Aiden",
"Kai",
"Jayden",
"Zion",
"Phoenix",
"Atlas",
"Axel",
"Ryder",
"Finn",
]
default_female_names = [
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna",
"Margaret", "Victoria", "Charlotte", "Sophia", "Isabella", "Olivia",
"Ava", "Mia", "Emily", "Abigail", "Amelia", "Eleanor", "Grace",
"Alice", "Lucy", "Chloe", "Sophie", "Lily", "Hannah", "Zoe",
"Luna", "Nova", "Aria", "Willow", "Aurora", "Sage", "River", "Winter", "Sky", "Rain"
"Mary",
"Patricia",
"Jennifer",
"Linda",
"Elizabeth",
"Barbara",
"Susan",
"Jessica",
"Sarah",
"Karen",
"Emma",
"Lisa",
"Anna",
"Margaret",
"Victoria",
"Charlotte",
"Sophia",
"Isabella",
"Olivia",
"Ava",
"Mia",
"Emily",
"Abigail",
"Amelia",
"Eleanor",
"Grace",
"Alice",
"Lucy",
"Chloe",
"Sophie",
"Lily",
"Hannah",
"Zoe",
"Luna",
"Nova",
"Aria",
"Willow",
"Aurora",
"Sage",
"River",
"Winter",
"Sky",
"Rain",
]
if self.male_names is None:
self.male_names = default_male_names
if self.female_names is None:
@ -114,22 +189,19 @@ class FamilyRelationshipsDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict:
rng = random.Random(self.seed + idx)
# Generate family tree
family = self._generate_family(rng)
# Select two people and their relationship
person1, person2, relationship = self._get_relationship_question(rng, family)
# Generate story describing the family relationships
story = self._generate_story(family)
# Format question
question = rng.choice(self._templates).format(
person1=person1.name,
person2=person2.name
)
question = rng.choice(self._templates).format(person1=person1.name, person2=person2.name)
return {
"question": f"{story}\n\n{question}",
"answer": relationship.value,
@ -137,8 +209,8 @@ class FamilyRelationshipsDataset(ProceduralDataset):
"person1": person1.name,
"person2": person2.name,
"relationship": relationship.value,
"family_size": len(family)
}
"family_size": len(family),
},
}
def _generate_family(self, rng: random.Random) -> Set[Person]:
@ -148,8 +220,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
used_names = set()
def get_name(gender: Gender) -> str:
names = (self.config.male_names if gender == Gender.MALE
else self.config.female_names)
names = self.config.male_names if gender == Gender.MALE else self.config.female_names
available = [n for n in names if n not in used_names]
if not available:
return None
@ -159,7 +230,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
# Create ID counter
id_counter = count()
# Create grandparents generation
grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter))
grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter))
@ -192,62 +263,52 @@ class FamilyRelationshipsDataset(ProceduralDataset):
) -> Tuple[Person, Person, Relationship]:
"""Select two family members and determine their relationship"""
person1, person2 = rng.sample(list(family), 2)
# Determine relationship
if person1 in person2.parents:
relationship = (Relationship.MOTHER if person1.gender == Gender.FEMALE
else Relationship.FATHER)
relationship = Relationship.MOTHER if person1.gender == Gender.FEMALE else Relationship.FATHER
elif person2 in person1.parents:
relationship = (Relationship.DAUGHTER if person1.gender == Gender.FEMALE
else Relationship.SON)
relationship = Relationship.DAUGHTER if person1.gender == Gender.FEMALE else Relationship.SON
elif person1.spouse == person2:
relationship = (Relationship.WIFE if person1.gender == Gender.FEMALE
else Relationship.HUSBAND)
elif (person1.parents and person2.parents and
set(person1.parents) == set(person2.parents)):
relationship = (Relationship.SISTER if person1.gender == Gender.FEMALE
else Relationship.BROTHER)
elif (person1 in [p for parent in person2.parents for p in parent.parents]):
relationship = (Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE
else Relationship.GRANDFATHER)
relationship = Relationship.WIFE if person1.gender == Gender.FEMALE else Relationship.HUSBAND
elif person1.parents and person2.parents and set(person1.parents) == set(person2.parents):
relationship = Relationship.SISTER if person1.gender == Gender.FEMALE else Relationship.BROTHER
elif person1 in [p for parent in person2.parents for p in parent.parents]:
relationship = Relationship.GRANDMOTHER if person1.gender == Gender.FEMALE else Relationship.GRANDFATHER
else:
# Try again with different people
return self._get_relationship_question(rng, family)
return person1, person2, relationship
def _generate_story(self, family: Set[Person]) -> str:
    """Generate a story describing the family relationships.

    For every married couple found in ``family`` the story contains one
    "<A> is married to <B>." sentence, followed by a sentence naming the
    couple's children when there are any that have not been named yet.

    Args:
        family: The people to describe. Each member is read for its
            ``name``, ``spouse`` and ``children`` attributes.

    Returns:
        All sentences joined by single spaces, or an empty string when
        nobody in ``family`` has a spouse. Couples are visited in the
        iteration order of an internal set, so sentence order between
        couples follows set ordering.
    """
    story_parts = []

    # Find married couples. A pair is stored only once: a person is
    # skipped when the mirrored (spouse, person) tuple is already recorded.
    couples = set()
    for person in family:
        if person.spouse and (person.spouse, person) not in couples:
            couples.add((person, person.spouse))

    # Describe the marriage and the children of each couple.
    described_children = set()  # children already named in the story
    for person1, person2 in couples:
        story_parts.append(f"{person1.name} is married to {person2.name}.")

        # Children are taken from the first partner's list; anyone already
        # named for an earlier couple is left out.
        children = [c for c in person1.children if c not in described_children]
        if not children:
            continue

        described_children.update(children)
        children_names = [c.name for c in children]

        if len(children_names) == 1:
            story_parts.append(f"They have a child called {children_names[0]}.")
        else:
            # "A, B and C": comma-separate all but the last name.
            *first, last = children_names
            children_str = ", ".join(first) + f" and {last}"
            story_parts.append(f"They have children called {children_str}.")

    return " ".join(story_parts)