mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
post merge lint
This commit is contained in:
parent
02cfa9556a
commit
f396d3df60
6 changed files with 197 additions and 63 deletions
|
|
@ -21,11 +21,13 @@ class AdvancedGeometryConfig:
|
|||
|
||||
# Probability or list of tasks we want to generate
|
||||
# For demonstration, we have three categories:
|
||||
task_types: List[str] = field(default_factory=lambda: [
|
||||
"orthocenter",
|
||||
"incircle_radius",
|
||||
"angle_measure",
|
||||
])
|
||||
task_types: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"orthocenter",
|
||||
"incircle_radius",
|
||||
"angle_measure",
|
||||
]
|
||||
)
|
||||
|
||||
def validate(self):
|
||||
assert self.min_coord < self.max_coord, "min_coord must be < max_coord."
|
||||
|
|
|
|||
|
|
@ -65,22 +65,100 @@ class FamilyRelationshipsConfig:
|
|||
|
||||
min_family_size: int = 4
|
||||
max_family_size: int = 8
|
||||
male_names: List[str] = field(default_factory=lambda: [
|
||||
"James", "John", "Robert", "Michael", "William", "David", "Richard", "Joseph",
|
||||
"Thomas", "Charles", "Peter", "Daniel", "Matthew", "Christopher", "Andrew",
|
||||
"George", "Edward", "Benjamin", "Henry", "Samuel", "Alexander", "Oliver",
|
||||
"Jack", "Harry", "Jacob", "Noah", "Ethan", "Lucas", "Mason", "Logan",
|
||||
"Sebastian", "Theodore", "Owen", "Liam", "Aiden", "Kai", "Jayden", "Zion",
|
||||
"Phoenix", "Atlas", "Axel", "Ryder", "Finn"
|
||||
])
|
||||
female_names: List[str] = field(default_factory=lambda: [
|
||||
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
|
||||
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", "Margaret", "Victoria",
|
||||
"Charlotte", "Sophia", "Isabella", "Olivia", "Ava", "Mia", "Emily",
|
||||
"Abigail", "Amelia", "Eleanor", "Grace", "Alice", "Lucy", "Chloe",
|
||||
"Sophie", "Lily", "Hannah", "Zoe", "Luna", "Nova", "Aria", "Willow",
|
||||
"Aurora", "Sage", "River", "Winter", "Sky", "Rain"
|
||||
])
|
||||
male_names: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"James",
|
||||
"John",
|
||||
"Robert",
|
||||
"Michael",
|
||||
"William",
|
||||
"David",
|
||||
"Richard",
|
||||
"Joseph",
|
||||
"Thomas",
|
||||
"Charles",
|
||||
"Peter",
|
||||
"Daniel",
|
||||
"Matthew",
|
||||
"Christopher",
|
||||
"Andrew",
|
||||
"George",
|
||||
"Edward",
|
||||
"Benjamin",
|
||||
"Henry",
|
||||
"Samuel",
|
||||
"Alexander",
|
||||
"Oliver",
|
||||
"Jack",
|
||||
"Harry",
|
||||
"Jacob",
|
||||
"Noah",
|
||||
"Ethan",
|
||||
"Lucas",
|
||||
"Mason",
|
||||
"Logan",
|
||||
"Sebastian",
|
||||
"Theodore",
|
||||
"Owen",
|
||||
"Liam",
|
||||
"Aiden",
|
||||
"Kai",
|
||||
"Jayden",
|
||||
"Zion",
|
||||
"Phoenix",
|
||||
"Atlas",
|
||||
"Axel",
|
||||
"Ryder",
|
||||
"Finn",
|
||||
]
|
||||
)
|
||||
female_names: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"Mary",
|
||||
"Patricia",
|
||||
"Jennifer",
|
||||
"Linda",
|
||||
"Elizabeth",
|
||||
"Barbara",
|
||||
"Susan",
|
||||
"Jessica",
|
||||
"Sarah",
|
||||
"Karen",
|
||||
"Emma",
|
||||
"Lisa",
|
||||
"Anna",
|
||||
"Margaret",
|
||||
"Victoria",
|
||||
"Charlotte",
|
||||
"Sophia",
|
||||
"Isabella",
|
||||
"Olivia",
|
||||
"Ava",
|
||||
"Mia",
|
||||
"Emily",
|
||||
"Abigail",
|
||||
"Amelia",
|
||||
"Eleanor",
|
||||
"Grace",
|
||||
"Alice",
|
||||
"Lucy",
|
||||
"Chloe",
|
||||
"Sophie",
|
||||
"Lily",
|
||||
"Hannah",
|
||||
"Zoe",
|
||||
"Luna",
|
||||
"Nova",
|
||||
"Aria",
|
||||
"Willow",
|
||||
"Aurora",
|
||||
"Sage",
|
||||
"River",
|
||||
"Winter",
|
||||
"Sky",
|
||||
"Rain",
|
||||
]
|
||||
)
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
|
|
|
|||
|
|
@ -6,10 +6,13 @@ Logic tasks for training reasoning capabilities:
|
|||
- Syllogisms
|
||||
"""
|
||||
|
||||
from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
|
||||
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
|
||||
from .syllogisms import SyllogismConfig, SyllogismDataset, Term
|
||||
|
||||
__all__ = [
|
||||
"AliceInWonderlandConfig",
|
||||
"AliceInWonderlandDataset",
|
||||
"PropositionalLogicConfig",
|
||||
"PropositionalLogicDataset",
|
||||
"SyllogismConfig",
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
from random import Random
|
||||
from string import Template
|
||||
from typing import List, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""Defines the type of task for the Alice in Wonderland dataset."""
|
||||
|
||||
SIBLINGS = "siblings"
|
||||
FRIENDS = "friends"
|
||||
COLLEAGUES = "colleagues" # Added colleagues task
|
||||
|
|
@ -26,21 +27,39 @@ class AliceInWonderlandConfig:
|
|||
size (int): Number of samples in the dataset.
|
||||
max_entities (int): Max number of siblings/friends/colleagues in questions.
|
||||
"""
|
||||
|
||||
male_names: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"James", "John", "Robert", "Michael", "William", "David",
|
||||
"Richard", "Joseph", "Thomas", "Charles", "Bob"
|
||||
"James",
|
||||
"John",
|
||||
"Robert",
|
||||
"Michael",
|
||||
"William",
|
||||
"David",
|
||||
"Richard",
|
||||
"Joseph",
|
||||
"Thomas",
|
||||
"Charles",
|
||||
"Bob",
|
||||
]
|
||||
)
|
||||
female_names: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth",
|
||||
"Barbara", "Susan", "Jessica", "Sarah", "Margaret", "Alice"
|
||||
"Mary",
|
||||
"Patricia",
|
||||
"Jennifer",
|
||||
"Linda",
|
||||
"Elizabeth",
|
||||
"Barbara",
|
||||
"Susan",
|
||||
"Jessica",
|
||||
"Sarah",
|
||||
"Margaret",
|
||||
"Alice",
|
||||
]
|
||||
)
|
||||
task_types: List[TaskType] = field(
|
||||
default_factory=lambda: [
|
||||
TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
|
||||
default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
|
||||
)
|
||||
seed: Optional[int] = None
|
||||
size: int = 10
|
||||
|
|
@ -57,19 +76,19 @@ class AliceInWonderlandConfig:
|
|||
|
||||
class AliceInWonderlandDataset(ProceduralDataset):
|
||||
"""
|
||||
A procedural dataset inspired by the "Alice in Wonderland" paper.
|
||||
A procedural dataset inspired by the "Alice in Wonderland" paper.
|
||||
|
||||
The dataset is inspired by the following paper:
|
||||
@inproceedings{nezhurina2024alice,
|
||||
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
|
||||
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
|
||||
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
|
||||
Jenia Jitsev},
|
||||
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
|
||||
Deep Learning},
|
||||
year={2024},
|
||||
url={https://openreview.net/forum?id=Mkl7dzjYiW}
|
||||
}
|
||||
The dataset is inspired by the following paper:
|
||||
@inproceedings{nezhurina2024alice,
|
||||
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
|
||||
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
|
||||
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
|
||||
Jenia Jitsev},
|
||||
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
|
||||
Deep Learning},
|
||||
year={2024},
|
||||
url={https://openreview.net/forum?id=Mkl7dzjYiW}
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
|
|
@ -152,14 +171,10 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
|||
num_female=num_female,
|
||||
)
|
||||
elif task_type == TaskType.COLLEAGUES:
|
||||
num_male_colleagues_alice_circle = rng.randint(
|
||||
1, self.config.max_entities)
|
||||
num_female_colleagues_alice_circle = rng.randint(
|
||||
1, self.config.max_entities)
|
||||
num_male_colleagues_bob_circle = rng.randint(
|
||||
1, self.config.max_entities)
|
||||
num_female_colleagues_bob_circle = rng.randint(
|
||||
1, self.config.max_entities)
|
||||
num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
|
||||
num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
|
||||
num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
|
||||
num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
|
||||
|
||||
answer = num_female_colleagues_alice_circle + 1
|
||||
template = rng.choice(self.templates[TaskType.COLLEAGUES])
|
||||
|
|
@ -169,16 +184,10 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
|||
num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
|
||||
num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
|
||||
num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
|
||||
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle
|
||||
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
|
||||
)
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"task_type": task_type.value
|
||||
}
|
||||
}
|
||||
return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}}
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
rng = Random(self.seed + idx)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue