mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
refactor: Use field default_factory TimeIntervalsConfig, AdvancedGeometryConfig
This commit is contained in:
parent
8202f234be
commit
4e9fc4baad
3 changed files with 29 additions and 124 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from datetime import date, datetime, time, timedelta
|
from datetime import date, datetime, time, timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
@ -19,14 +19,12 @@ class TimeIntervalsConfig:
|
||||||
min_date: date = date(1900, 1, 1)
|
min_date: date = date(1900, 1, 1)
|
||||||
max_date: date = date(3000, 1, 1)
|
max_date: date = date(3000, 1, 1)
|
||||||
max_date_difference_days: int = 100
|
max_date_difference_days: int = 100
|
||||||
task_types: List[str] = None
|
task_types: List[str] = field(
|
||||||
|
default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
|
||||||
|
)
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.task_types is None:
|
|
||||||
self.task_types = ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
assert self.size > 0, "size must be positive"
|
assert self.size > 0, "size must be positive"
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
|
|
@ -21,16 +21,11 @@ class AdvancedGeometryConfig:
|
||||||
|
|
||||||
# Probability or list of tasks we want to generate
|
# Probability or list of tasks we want to generate
|
||||||
# For demonstration, we have three categories:
|
# For demonstration, we have three categories:
|
||||||
task_types: List[str] = None
|
task_types: List[str] = field(default_factory=lambda: [
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.task_types is None:
|
|
||||||
# Default set of advanced tasks
|
|
||||||
self.task_types = [
|
|
||||||
"orthocenter",
|
"orthocenter",
|
||||||
"incircle_radius",
|
"incircle_radius",
|
||||||
"angle_measure",
|
"angle_measure",
|
||||||
]
|
])
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
assert self.min_coord < self.max_coord, "min_coord must be < max_coord."
|
assert self.min_coord < self.max_coord, "min_coord must be < max_coord."
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
@ -37,12 +37,8 @@ class Person:
|
||||||
gender: Gender
|
gender: Gender
|
||||||
id: int
|
id: int
|
||||||
spouse: Optional["Person"] = None
|
spouse: Optional["Person"] = None
|
||||||
parents: List["Person"] = None
|
parents: List["Person"] = field(default_factory=list)
|
||||||
children: List["Person"] = None
|
children: List["Person"] = field(default_factory=list)
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self.parents = self.parents or []
|
|
||||||
self.children = self.children or []
|
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return self.id
|
return self.id
|
||||||
|
|
@ -69,109 +65,25 @@ class FamilyRelationshipsConfig:
|
||||||
|
|
||||||
min_family_size: int = 4
|
min_family_size: int = 4
|
||||||
max_family_size: int = 8
|
max_family_size: int = 8
|
||||||
male_names: List[str] = None
|
male_names: List[str] = field(default_factory=lambda: [
|
||||||
female_names: List[str] = None
|
"James", "John", "Robert", "Michael", "William", "David", "Richard", "Joseph",
|
||||||
|
"Thomas", "Charles", "Peter", "Daniel", "Matthew", "Christopher", "Andrew",
|
||||||
|
"George", "Edward", "Benjamin", "Henry", "Samuel", "Alexander", "Oliver",
|
||||||
|
"Jack", "Harry", "Jacob", "Noah", "Ethan", "Lucas", "Mason", "Logan",
|
||||||
|
"Sebastian", "Theodore", "Owen", "Liam", "Aiden", "Kai", "Jayden", "Zion",
|
||||||
|
"Phoenix", "Atlas", "Axel", "Ryder", "Finn"
|
||||||
|
])
|
||||||
|
female_names: List[str] = field(default_factory=lambda: [
|
||||||
|
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
|
||||||
|
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", "Margaret", "Victoria",
|
||||||
|
"Charlotte", "Sophia", "Isabella", "Olivia", "Ava", "Mia", "Emily",
|
||||||
|
"Abigail", "Amelia", "Eleanor", "Grace", "Alice", "Lucy", "Chloe",
|
||||||
|
"Sophie", "Lily", "Hannah", "Zoe", "Luna", "Nova", "Aria", "Willow",
|
||||||
|
"Aurora", "Sage", "River", "Winter", "Sky", "Rain"
|
||||||
|
])
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# Default name lists if none provided
|
|
||||||
default_male_names = [
|
|
||||||
"James",
|
|
||||||
"John",
|
|
||||||
"Robert",
|
|
||||||
"Michael",
|
|
||||||
"William",
|
|
||||||
"David",
|
|
||||||
"Richard",
|
|
||||||
"Joseph",
|
|
||||||
"Thomas",
|
|
||||||
"Charles",
|
|
||||||
"Peter",
|
|
||||||
"Daniel",
|
|
||||||
"Matthew",
|
|
||||||
"Christopher",
|
|
||||||
"Andrew",
|
|
||||||
"George",
|
|
||||||
"Edward",
|
|
||||||
"Benjamin",
|
|
||||||
"Henry",
|
|
||||||
"Samuel",
|
|
||||||
"Alexander",
|
|
||||||
"Oliver",
|
|
||||||
"Jack",
|
|
||||||
"Harry",
|
|
||||||
"Jacob",
|
|
||||||
"Noah",
|
|
||||||
"Ethan",
|
|
||||||
"Lucas",
|
|
||||||
"Mason",
|
|
||||||
"Logan",
|
|
||||||
"Sebastian",
|
|
||||||
"Theodore",
|
|
||||||
"Owen",
|
|
||||||
"Liam",
|
|
||||||
"Aiden",
|
|
||||||
"Kai",
|
|
||||||
"Jayden",
|
|
||||||
"Zion",
|
|
||||||
"Phoenix",
|
|
||||||
"Atlas",
|
|
||||||
"Axel",
|
|
||||||
"Ryder",
|
|
||||||
"Finn",
|
|
||||||
]
|
|
||||||
default_female_names = [
|
|
||||||
"Mary",
|
|
||||||
"Patricia",
|
|
||||||
"Jennifer",
|
|
||||||
"Linda",
|
|
||||||
"Elizabeth",
|
|
||||||
"Barbara",
|
|
||||||
"Susan",
|
|
||||||
"Jessica",
|
|
||||||
"Sarah",
|
|
||||||
"Karen",
|
|
||||||
"Emma",
|
|
||||||
"Lisa",
|
|
||||||
"Anna",
|
|
||||||
"Margaret",
|
|
||||||
"Victoria",
|
|
||||||
"Charlotte",
|
|
||||||
"Sophia",
|
|
||||||
"Isabella",
|
|
||||||
"Olivia",
|
|
||||||
"Ava",
|
|
||||||
"Mia",
|
|
||||||
"Emily",
|
|
||||||
"Abigail",
|
|
||||||
"Amelia",
|
|
||||||
"Eleanor",
|
|
||||||
"Grace",
|
|
||||||
"Alice",
|
|
||||||
"Lucy",
|
|
||||||
"Chloe",
|
|
||||||
"Sophie",
|
|
||||||
"Lily",
|
|
||||||
"Hannah",
|
|
||||||
"Zoe",
|
|
||||||
"Luna",
|
|
||||||
"Nova",
|
|
||||||
"Aria",
|
|
||||||
"Willow",
|
|
||||||
"Aurora",
|
|
||||||
"Sage",
|
|
||||||
"River",
|
|
||||||
"Winter",
|
|
||||||
"Sky",
|
|
||||||
"Rain",
|
|
||||||
]
|
|
||||||
|
|
||||||
if self.male_names is None:
|
|
||||||
self.male_names = default_male_names
|
|
||||||
if self.female_names is None:
|
|
||||||
self.female_names = default_female_names
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
assert self.min_family_size >= 3, "min_family_size must be at least 3"
|
assert self.min_family_size >= 3, "min_family_size must be at least 3"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue