mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
post merge lint
This commit is contained in:
parent
02cfa9556a
commit
f396d3df60
6 changed files with 197 additions and 63 deletions
47
GALLERY.md
47
GALLERY.md
|
|
@ -3,6 +3,7 @@ This gallery shows examples from all available datasets using their default conf
|
||||||
|
|
||||||
## Available Datasets
|
## Available Datasets
|
||||||
- [advanced_geometry](#advanced_geometry)
|
- [advanced_geometry](#advanced_geometry)
|
||||||
|
- [aiw](#aiw)
|
||||||
- [base_conversion](#base_conversion)
|
- [base_conversion](#base_conversion)
|
||||||
- [basic_arithmetic](#basic_arithmetic)
|
- [basic_arithmetic](#basic_arithmetic)
|
||||||
- [bf](#bf)
|
- [bf](#bf)
|
||||||
|
|
@ -73,6 +74,50 @@ Metadata: {'A': (6, 7), 'B': (-7, -5), 'C': (2, -3), 'incircle_radius_exact': 's
|
||||||
|
|
||||||
````
|
````
|
||||||
|
|
||||||
|
### aiw
|
||||||
|
A procedural dataset inspired by the "Alice in Wonderland" paper.
|
||||||
|
|
||||||
|
The dataset is inspired by the following paper:
|
||||||
|
@inproceedings{nezhurina2024alice,
|
||||||
|
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
|
||||||
|
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
|
||||||
|
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
|
||||||
|
Jenia Jitsev},
|
||||||
|
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
|
||||||
|
Deep Learning},
|
||||||
|
year={2024},
|
||||||
|
url={https://openreview.net/forum?id=Mkl7dzjYiW}
|
||||||
|
}
|
||||||
|
|
||||||
|
Default configuration:
|
||||||
|
```python
|
||||||
|
male_names = ['James', 'John', 'Robert', 'Michael', 'William', 'David', 'Richard', 'Joseph', 'Thomas', 'Charles', 'Bob']
|
||||||
|
female_names = ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah', 'Margaret', 'Alice']
|
||||||
|
task_types = [<TaskType.SIBLINGS: 'siblings'>, <TaskType.FRIENDS: 'friends'>, <TaskType.COLLEAGUES: 'colleagues'>]
|
||||||
|
seed = 42
|
||||||
|
size = 10
|
||||||
|
max_entities = 6
|
||||||
|
```
|
||||||
|
|
||||||
|
Example tasks:
|
||||||
|
````
|
||||||
|
Example 1:
|
||||||
|
Question: Patricia has 6 male colleagues and she also has 3 female colleagues. These are all colleagues that Patricia has. All these mentioned persons around Patricia are colleagues of each other. James has 2 male colleagues and 2 female colleagues in total. All these mentioned persons around James are colleagues of each other. The people in the circle around James do not have other colleagues aside - with the only exception of Matilda. She is colleague of James and she is also colleague of Patricia, being part of Patricia's circle. How many female colleagues does Matilda have?
|
||||||
|
Answer: 4
|
||||||
|
Metadata: {'task_type': 'colleagues'}
|
||||||
|
|
||||||
|
Example 2:
|
||||||
|
Question: Elizabeth has 4 brothers and she also has 3 sisters. How many sisters does Elizabeth's brother have?
|
||||||
|
Answer: 4
|
||||||
|
Metadata: {'task_type': 'siblings'}
|
||||||
|
|
||||||
|
Example 3:
|
||||||
|
Question: Sarah has 6 male friends and she also has 1 female friends. They all are friends with each other and have no other friends aside. How many female friends does Thomas, a male friend of Sarah, have?
|
||||||
|
Answer: 2
|
||||||
|
Metadata: {'task_type': 'friends'}
|
||||||
|
|
||||||
|
````
|
||||||
|
|
||||||
### base_conversion
|
### base_conversion
|
||||||
Generates base conversion tasks
|
Generates base conversion tasks
|
||||||
|
|
||||||
|
|
@ -1548,7 +1593,7 @@ Metadata: {'task_type': 'datetime_tz', 'start_time': datetime.datetime(2964, 6,
|
||||||
Example 2:
|
Example 2:
|
||||||
Question: A video call started at 09:44 and ended at 12:22. How long was the call? Answer in HH:MM.
|
Question: A video call started at 09:44 and ended at 12:22. How long was the call? Answer in HH:MM.
|
||||||
Answer: 02:38
|
Answer: 02:38
|
||||||
Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 1, 9, 44), 'end_time': datetime.datetime(2025, 2, 1, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'}
|
Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 2, 9, 44), 'end_time': datetime.datetime(2025, 2, 2, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'}
|
||||||
|
|
||||||
Example 3:
|
Example 3:
|
||||||
Question: Calculate the time difference between Sat Dec 22 2677 and Thu Mar 21 2678. Express the result in D days.
|
Question: Calculate the time difference between Sat Dec 22 2677 and Thu Mar 21 2678. Express the result in D days.
|
||||||
|
|
|
||||||
|
|
@ -21,11 +21,13 @@ class AdvancedGeometryConfig:
|
||||||
|
|
||||||
# Probability or list of tasks we want to generate
|
# Probability or list of tasks we want to generate
|
||||||
# For demonstration, we have three categories:
|
# For demonstration, we have three categories:
|
||||||
task_types: List[str] = field(default_factory=lambda: [
|
task_types: List[str] = field(
|
||||||
"orthocenter",
|
default_factory=lambda: [
|
||||||
"incircle_radius",
|
"orthocenter",
|
||||||
"angle_measure",
|
"incircle_radius",
|
||||||
])
|
"angle_measure",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
assert self.min_coord < self.max_coord, "min_coord must be < max_coord."
|
assert self.min_coord < self.max_coord, "min_coord must be < max_coord."
|
||||||
|
|
|
||||||
|
|
@ -65,22 +65,100 @@ class FamilyRelationshipsConfig:
|
||||||
|
|
||||||
min_family_size: int = 4
|
min_family_size: int = 4
|
||||||
max_family_size: int = 8
|
max_family_size: int = 8
|
||||||
male_names: List[str] = field(default_factory=lambda: [
|
male_names: List[str] = field(
|
||||||
"James", "John", "Robert", "Michael", "William", "David", "Richard", "Joseph",
|
default_factory=lambda: [
|
||||||
"Thomas", "Charles", "Peter", "Daniel", "Matthew", "Christopher", "Andrew",
|
"James",
|
||||||
"George", "Edward", "Benjamin", "Henry", "Samuel", "Alexander", "Oliver",
|
"John",
|
||||||
"Jack", "Harry", "Jacob", "Noah", "Ethan", "Lucas", "Mason", "Logan",
|
"Robert",
|
||||||
"Sebastian", "Theodore", "Owen", "Liam", "Aiden", "Kai", "Jayden", "Zion",
|
"Michael",
|
||||||
"Phoenix", "Atlas", "Axel", "Ryder", "Finn"
|
"William",
|
||||||
])
|
"David",
|
||||||
female_names: List[str] = field(default_factory=lambda: [
|
"Richard",
|
||||||
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
|
"Joseph",
|
||||||
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", "Margaret", "Victoria",
|
"Thomas",
|
||||||
"Charlotte", "Sophia", "Isabella", "Olivia", "Ava", "Mia", "Emily",
|
"Charles",
|
||||||
"Abigail", "Amelia", "Eleanor", "Grace", "Alice", "Lucy", "Chloe",
|
"Peter",
|
||||||
"Sophie", "Lily", "Hannah", "Zoe", "Luna", "Nova", "Aria", "Willow",
|
"Daniel",
|
||||||
"Aurora", "Sage", "River", "Winter", "Sky", "Rain"
|
"Matthew",
|
||||||
])
|
"Christopher",
|
||||||
|
"Andrew",
|
||||||
|
"George",
|
||||||
|
"Edward",
|
||||||
|
"Benjamin",
|
||||||
|
"Henry",
|
||||||
|
"Samuel",
|
||||||
|
"Alexander",
|
||||||
|
"Oliver",
|
||||||
|
"Jack",
|
||||||
|
"Harry",
|
||||||
|
"Jacob",
|
||||||
|
"Noah",
|
||||||
|
"Ethan",
|
||||||
|
"Lucas",
|
||||||
|
"Mason",
|
||||||
|
"Logan",
|
||||||
|
"Sebastian",
|
||||||
|
"Theodore",
|
||||||
|
"Owen",
|
||||||
|
"Liam",
|
||||||
|
"Aiden",
|
||||||
|
"Kai",
|
||||||
|
"Jayden",
|
||||||
|
"Zion",
|
||||||
|
"Phoenix",
|
||||||
|
"Atlas",
|
||||||
|
"Axel",
|
||||||
|
"Ryder",
|
||||||
|
"Finn",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
female_names: List[str] = field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
"Mary",
|
||||||
|
"Patricia",
|
||||||
|
"Jennifer",
|
||||||
|
"Linda",
|
||||||
|
"Elizabeth",
|
||||||
|
"Barbara",
|
||||||
|
"Susan",
|
||||||
|
"Jessica",
|
||||||
|
"Sarah",
|
||||||
|
"Karen",
|
||||||
|
"Emma",
|
||||||
|
"Lisa",
|
||||||
|
"Anna",
|
||||||
|
"Margaret",
|
||||||
|
"Victoria",
|
||||||
|
"Charlotte",
|
||||||
|
"Sophia",
|
||||||
|
"Isabella",
|
||||||
|
"Olivia",
|
||||||
|
"Ava",
|
||||||
|
"Mia",
|
||||||
|
"Emily",
|
||||||
|
"Abigail",
|
||||||
|
"Amelia",
|
||||||
|
"Eleanor",
|
||||||
|
"Grace",
|
||||||
|
"Alice",
|
||||||
|
"Lucy",
|
||||||
|
"Chloe",
|
||||||
|
"Sophie",
|
||||||
|
"Lily",
|
||||||
|
"Hannah",
|
||||||
|
"Zoe",
|
||||||
|
"Luna",
|
||||||
|
"Nova",
|
||||||
|
"Aria",
|
||||||
|
"Willow",
|
||||||
|
"Aurora",
|
||||||
|
"Sage",
|
||||||
|
"River",
|
||||||
|
"Winter",
|
||||||
|
"Sky",
|
||||||
|
"Rain",
|
||||||
|
]
|
||||||
|
)
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,13 @@ Logic tasks for training reasoning capabilities:
|
||||||
- Syllogisms
|
- Syllogisms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
|
||||||
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
|
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
|
||||||
from .syllogisms import SyllogismConfig, SyllogismDataset, Term
|
from .syllogisms import SyllogismConfig, SyllogismDataset, Term
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AliceInWonderlandConfig",
|
||||||
|
"AliceInWonderlandDataset",
|
||||||
"PropositionalLogicConfig",
|
"PropositionalLogicConfig",
|
||||||
"PropositionalLogicDataset",
|
"PropositionalLogicDataset",
|
||||||
"SyllogismConfig",
|
"SyllogismConfig",
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,15 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from random import Random
|
from random import Random
|
||||||
from string import Template
|
from string import Template
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
class TaskType(Enum):
|
class TaskType(Enum):
|
||||||
"""Defines the type of task for the Alice in Wonderland dataset."""
|
"""Defines the type of task for the Alice in Wonderland dataset."""
|
||||||
|
|
||||||
SIBLINGS = "siblings"
|
SIBLINGS = "siblings"
|
||||||
FRIENDS = "friends"
|
FRIENDS = "friends"
|
||||||
COLLEAGUES = "colleagues" # Added colleagues task
|
COLLEAGUES = "colleagues" # Added colleagues task
|
||||||
|
|
@ -26,21 +27,39 @@ class AliceInWonderlandConfig:
|
||||||
size (int): Number of samples in the dataset.
|
size (int): Number of samples in the dataset.
|
||||||
max_entities (int): Max number of siblings/friends/colleagues in questions.
|
max_entities (int): Max number of siblings/friends/colleagues in questions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
male_names: List[str] = field(
|
male_names: List[str] = field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
"James", "John", "Robert", "Michael", "William", "David",
|
"James",
|
||||||
"Richard", "Joseph", "Thomas", "Charles", "Bob"
|
"John",
|
||||||
|
"Robert",
|
||||||
|
"Michael",
|
||||||
|
"William",
|
||||||
|
"David",
|
||||||
|
"Richard",
|
||||||
|
"Joseph",
|
||||||
|
"Thomas",
|
||||||
|
"Charles",
|
||||||
|
"Bob",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
female_names: List[str] = field(
|
female_names: List[str] = field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth",
|
"Mary",
|
||||||
"Barbara", "Susan", "Jessica", "Sarah", "Margaret", "Alice"
|
"Patricia",
|
||||||
|
"Jennifer",
|
||||||
|
"Linda",
|
||||||
|
"Elizabeth",
|
||||||
|
"Barbara",
|
||||||
|
"Susan",
|
||||||
|
"Jessica",
|
||||||
|
"Sarah",
|
||||||
|
"Margaret",
|
||||||
|
"Alice",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
task_types: List[TaskType] = field(
|
task_types: List[TaskType] = field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
|
||||||
TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
|
|
||||||
)
|
)
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 10
|
size: int = 10
|
||||||
|
|
@ -57,19 +76,19 @@ class AliceInWonderlandConfig:
|
||||||
|
|
||||||
class AliceInWonderlandDataset(ProceduralDataset):
|
class AliceInWonderlandDataset(ProceduralDataset):
|
||||||
"""
|
"""
|
||||||
A procedural dataset inspired by the "Alice in Wonderland" paper.
|
A procedural dataset inspired by the "Alice in Wonderland" paper.
|
||||||
|
|
||||||
The dataset is inspired by the following paper:
|
The dataset is inspired by the following paper:
|
||||||
@inproceedings{nezhurina2024alice,
|
@inproceedings{nezhurina2024alice,
|
||||||
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
|
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
|
||||||
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
|
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
|
||||||
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
|
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
|
||||||
Jenia Jitsev},
|
Jenia Jitsev},
|
||||||
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
|
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
|
||||||
Deep Learning},
|
Deep Learning},
|
||||||
year={2024},
|
year={2024},
|
||||||
url={https://openreview.net/forum?id=Mkl7dzjYiW}
|
url={https://openreview.net/forum?id=Mkl7dzjYiW}
|
||||||
}
|
}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -152,14 +171,10 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
||||||
num_female=num_female,
|
num_female=num_female,
|
||||||
)
|
)
|
||||||
elif task_type == TaskType.COLLEAGUES:
|
elif task_type == TaskType.COLLEAGUES:
|
||||||
num_male_colleagues_alice_circle = rng.randint(
|
num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
|
||||||
1, self.config.max_entities)
|
num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
|
||||||
num_female_colleagues_alice_circle = rng.randint(
|
num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
|
||||||
1, self.config.max_entities)
|
num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
|
||||||
num_male_colleagues_bob_circle = rng.randint(
|
|
||||||
1, self.config.max_entities)
|
|
||||||
num_female_colleagues_bob_circle = rng.randint(
|
|
||||||
1, self.config.max_entities)
|
|
||||||
|
|
||||||
answer = num_female_colleagues_alice_circle + 1
|
answer = num_female_colleagues_alice_circle + 1
|
||||||
template = rng.choice(self.templates[TaskType.COLLEAGUES])
|
template = rng.choice(self.templates[TaskType.COLLEAGUES])
|
||||||
|
|
@ -169,16 +184,10 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
||||||
num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
|
num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
|
||||||
num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
|
num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
|
||||||
num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
|
num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
|
||||||
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle
|
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}}
|
||||||
"question": question,
|
|
||||||
"answer": answer,
|
|
||||||
"metadata": {
|
|
||||||
"task_type": task_type.value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,7 @@ def test_aiw_config_validation():
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = AliceInWonderlandConfig(
|
config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice
|
||||||
female_names=["Mary", "Jane"]) # No Alice
|
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
|
|
@ -56,8 +55,7 @@ def test_aiw_items():
|
||||||
# Verify question task type characteristics
|
# Verify question task type characteristics
|
||||||
task_type = item["metadata"]["task_type"]
|
task_type = item["metadata"]["task_type"]
|
||||||
if task_type == TaskType.SIBLINGS.value:
|
if task_type == TaskType.SIBLINGS.value:
|
||||||
assert any(phrase in item["question"]
|
assert any(phrase in item["question"] for phrase in ["brothers", "sisters"])
|
||||||
for phrase in ["brothers", "sisters"])
|
|
||||||
elif task_type == TaskType.FRIENDS.value:
|
elif task_type == TaskType.FRIENDS.value:
|
||||||
assert "friends" in item["question"]
|
assert "friends" in item["question"]
|
||||||
elif task_type == TaskType.COLLEAGUES:
|
elif task_type == TaskType.COLLEAGUES:
|
||||||
|
|
@ -95,5 +93,4 @@ def test_aiw_random_ranges():
|
||||||
numbers = [int(n) for n in question.split() if n.isdigit()]
|
numbers = [int(n) for n in question.split() if n.isdigit()]
|
||||||
|
|
||||||
# Check all numbers are in reasonable range (1-6 as per implementation)
|
# Check all numbers are in reasonable range (1-6 as per implementation)
|
||||||
assert all(
|
assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"
|
||||||
1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue