diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 18ebd713..9a6ba7fa 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -1,8 +1,8 @@ import random -import uuid -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Optional, Dict, List, Set, Tuple from enum import Enum +from itertools import count from ..dataset import ProceduralDataset @@ -29,22 +29,22 @@ class Relationship(Enum): class Person: name: str gender: Gender + id: int spouse: Optional['Person'] = None parents: List['Person'] = None children: List['Person'] = None - _id: uuid.UUID = field(default_factory=uuid.uuid4, compare=False) def __post_init__(self): self.parents = self.parents or [] self.children = self.children or [] def __hash__(self): - return hash(self._id) + return self.id def __eq__(self, other): if not isinstance(other, Person): return False - return self._id == other._id + return self.id == other.id def add_child(self, child: 'Person'): if child not in self.children: @@ -149,15 +149,18 @@ class FamilyRelationshipsDataset(ProceduralDataset): used_names.add(name) return name + # Create ID counter + id_counter = count() + # Create grandparents generation - grandfather = Person(get_name(Gender.MALE), Gender.MALE) - grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE) + grandfather = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter)) + grandmother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter)) grandfather.add_spouse(grandmother) family.update([grandfather, grandmother]) # Create parents - father = Person(get_name(Gender.MALE), Gender.MALE) - mother = Person(get_name(Gender.FEMALE), Gender.FEMALE) + father = Person(get_name(Gender.MALE), Gender.MALE, next(id_counter)) + mother = Person(get_name(Gender.FEMALE), Gender.FEMALE, next(id_counter)) father.add_spouse(mother) grandfather.add_child(father) grandmother.add_child(father) @@ -169,7 +172,7 @@ class FamilyRelationshipsDataset(ProceduralDataset): name = get_name(gender) if not name: break - child = Person(name, gender) + child = Person(name, gender, next(id_counter)) father.add_child(child) mother.add_child(child) family.add(child)