post merge lint

This commit is contained in:
Andreas Koepf 2025-02-02 10:04:18 +01:00
parent 02cfa9556a
commit f396d3df60
6 changed files with 197 additions and 63 deletions

View file

@ -3,6 +3,7 @@ This gallery shows examples from all available datasets using their default conf
## Available Datasets ## Available Datasets
- [advanced_geometry](#advanced_geometry) - [advanced_geometry](#advanced_geometry)
- [aiw](#aiw)
- [base_conversion](#base_conversion) - [base_conversion](#base_conversion)
- [basic_arithmetic](#basic_arithmetic) - [basic_arithmetic](#basic_arithmetic)
- [bf](#bf) - [bf](#bf)
@ -73,6 +74,50 @@ Metadata: {'A': (6, 7), 'B': (-7, -5), 'C': (2, -3), 'incircle_radius_exact': 's
```` ````
### aiw
A procedural dataset inspired by the "Alice in Wonderland" paper.
The dataset is inspired by the following paper:
@inproceedings{nezhurina2024alice,
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
Jenia Jitsev},
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
Deep Learning},
year={2024},
url={https://openreview.net/forum?id=Mkl7dzjYiW}
}
Default configuration:
```python
male_names = ['James', 'John', 'Robert', 'Michael', 'William', 'David', 'Richard', 'Joseph', 'Thomas', 'Charles', 'Bob']
female_names = ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah', 'Margaret', 'Alice']
task_types = [<TaskType.SIBLINGS: 'siblings'>, <TaskType.FRIENDS: 'friends'>, <TaskType.COLLEAGUES: 'colleagues'>]
seed = 42
size = 10
max_entities = 6
```
Example tasks:
````
Example 1:
Question: Patricia has 6 male colleagues and she also has 3 female colleagues. These are all colleagues that Patricia has. All these mentioned persons around Patricia are colleagues of each other. James has 2 male colleagues and 2 female colleagues in total. All these mentioned persons around James are colleagues of each other. The people in the circle around James do not have other colleagues aside - with the only exception of Matilda. She is colleague of James and she is also colleague of Patricia, being part of Patricia's circle. How many female colleagues does Matilda have?
Answer: 4
Metadata: {'task_type': 'colleagues'}
Example 2:
Question: Elizabeth has 4 brothers and she also has 3 sisters. How many sisters does Elizabeth's brother have?
Answer: 4
Metadata: {'task_type': 'siblings'}
Example 3:
Question: Sarah has 6 male friends and she also has 1 female friends. They all are friends with each other and have no other friends aside. How many female friends does Thomas, a male friend of Sarah, have?
Answer: 2
Metadata: {'task_type': 'friends'}
````
### base_conversion ### base_conversion
Generates base conversion tasks Generates base conversion tasks
@ -1548,7 +1593,7 @@ Metadata: {'task_type': 'datetime_tz', 'start_time': datetime.datetime(2964, 6,
Example 2: Example 2:
Question: A video call started at 09:44 and ended at 12:22. How long was the call? Answer in HH:MM. Question: A video call started at 09:44 and ended at 12:22. How long was the call? Answer in HH:MM.
Answer: 02:38 Answer: 02:38
Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 1, 9, 44), 'end_time': datetime.datetime(2025, 2, 1, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'} Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 2, 9, 44), 'end_time': datetime.datetime(2025, 2, 2, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'}
Example 3: Example 3:
Question: Calculate the time difference between Sat Dec 22 2677 and Thu Mar 21 2678. Express the result in D days. Question: Calculate the time difference between Sat Dec 22 2677 and Thu Mar 21 2678. Express the result in D days.

View file

@ -21,11 +21,13 @@ class AdvancedGeometryConfig:
# Probability or list of tasks we want to generate # Probability or list of tasks we want to generate
# For demonstration, we have three categories: # For demonstration, we have three categories:
task_types: List[str] = field(default_factory=lambda: [ task_types: List[str] = field(
default_factory=lambda: [
"orthocenter", "orthocenter",
"incircle_radius", "incircle_radius",
"angle_measure", "angle_measure",
]) ]
)
def validate(self): def validate(self):
assert self.min_coord < self.max_coord, "min_coord must be < max_coord." assert self.min_coord < self.max_coord, "min_coord must be < max_coord."

View file

@ -65,22 +65,100 @@ class FamilyRelationshipsConfig:
min_family_size: int = 4 min_family_size: int = 4
max_family_size: int = 8 max_family_size: int = 8
male_names: List[str] = field(default_factory=lambda: [ male_names: List[str] = field(
"James", "John", "Robert", "Michael", "William", "David", "Richard", "Joseph", default_factory=lambda: [
"Thomas", "Charles", "Peter", "Daniel", "Matthew", "Christopher", "Andrew", "James",
"George", "Edward", "Benjamin", "Henry", "Samuel", "Alexander", "Oliver", "John",
"Jack", "Harry", "Jacob", "Noah", "Ethan", "Lucas", "Mason", "Logan", "Robert",
"Sebastian", "Theodore", "Owen", "Liam", "Aiden", "Kai", "Jayden", "Zion", "Michael",
"Phoenix", "Atlas", "Axel", "Ryder", "Finn" "William",
]) "David",
female_names: List[str] = field(default_factory=lambda: [ "Richard",
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Joseph",
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", "Margaret", "Victoria", "Thomas",
"Charlotte", "Sophia", "Isabella", "Olivia", "Ava", "Mia", "Emily", "Charles",
"Abigail", "Amelia", "Eleanor", "Grace", "Alice", "Lucy", "Chloe", "Peter",
"Sophie", "Lily", "Hannah", "Zoe", "Luna", "Nova", "Aria", "Willow", "Daniel",
"Aurora", "Sage", "River", "Winter", "Sky", "Rain" "Matthew",
]) "Christopher",
"Andrew",
"George",
"Edward",
"Benjamin",
"Henry",
"Samuel",
"Alexander",
"Oliver",
"Jack",
"Harry",
"Jacob",
"Noah",
"Ethan",
"Lucas",
"Mason",
"Logan",
"Sebastian",
"Theodore",
"Owen",
"Liam",
"Aiden",
"Kai",
"Jayden",
"Zion",
"Phoenix",
"Atlas",
"Axel",
"Ryder",
"Finn",
]
)
female_names: List[str] = field(
default_factory=lambda: [
"Mary",
"Patricia",
"Jennifer",
"Linda",
"Elizabeth",
"Barbara",
"Susan",
"Jessica",
"Sarah",
"Karen",
"Emma",
"Lisa",
"Anna",
"Margaret",
"Victoria",
"Charlotte",
"Sophia",
"Isabella",
"Olivia",
"Ava",
"Mia",
"Emily",
"Abigail",
"Amelia",
"Eleanor",
"Grace",
"Alice",
"Lucy",
"Chloe",
"Sophie",
"Lily",
"Hannah",
"Zoe",
"Luna",
"Nova",
"Aria",
"Willow",
"Aurora",
"Sage",
"River",
"Winter",
"Sky",
"Rain",
]
)
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500

View file

@ -6,10 +6,13 @@ Logic tasks for training reasoning capabilities:
- Syllogisms - Syllogisms
""" """
from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
from .syllogisms import SyllogismConfig, SyllogismDataset, Term from .syllogisms import SyllogismConfig, SyllogismDataset, Term
__all__ = [ __all__ = [
"AliceInWonderlandConfig",
"AliceInWonderlandDataset",
"PropositionalLogicConfig", "PropositionalLogicConfig",
"PropositionalLogicDataset", "PropositionalLogicDataset",
"SyllogismConfig", "SyllogismConfig",

View file

@ -1,14 +1,15 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional
from enum import Enum from enum import Enum
from random import Random from random import Random
from string import Template from string import Template
from typing import List, Optional
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
class TaskType(Enum): class TaskType(Enum):
"""Defines the type of task for the Alice in Wonderland dataset.""" """Defines the type of task for the Alice in Wonderland dataset."""
SIBLINGS = "siblings" SIBLINGS = "siblings"
FRIENDS = "friends" FRIENDS = "friends"
COLLEAGUES = "colleagues" # Added colleagues task COLLEAGUES = "colleagues" # Added colleagues task
@ -26,21 +27,39 @@ class AliceInWonderlandConfig:
size (int): Number of samples in the dataset. size (int): Number of samples in the dataset.
max_entities (int): Max number of siblings/friends/colleagues in questions. max_entities (int): Max number of siblings/friends/colleagues in questions.
""" """
male_names: List[str] = field( male_names: List[str] = field(
default_factory=lambda: [ default_factory=lambda: [
"James", "John", "Robert", "Michael", "William", "David", "James",
"Richard", "Joseph", "Thomas", "Charles", "Bob" "John",
"Robert",
"Michael",
"William",
"David",
"Richard",
"Joseph",
"Thomas",
"Charles",
"Bob",
] ]
) )
female_names: List[str] = field( female_names: List[str] = field(
default_factory=lambda: [ default_factory=lambda: [
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Mary",
"Barbara", "Susan", "Jessica", "Sarah", "Margaret", "Alice" "Patricia",
"Jennifer",
"Linda",
"Elizabeth",
"Barbara",
"Susan",
"Jessica",
"Sarah",
"Margaret",
"Alice",
] ]
) )
task_types: List[TaskType] = field( task_types: List[TaskType] = field(
default_factory=lambda: [ default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
) )
seed: Optional[int] = None seed: Optional[int] = None
size: int = 10 size: int = 10
@ -152,14 +171,10 @@ class AliceInWonderlandDataset(ProceduralDataset):
num_female=num_female, num_female=num_female,
) )
elif task_type == TaskType.COLLEAGUES: elif task_type == TaskType.COLLEAGUES:
num_male_colleagues_alice_circle = rng.randint( num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
1, self.config.max_entities) num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_alice_circle = rng.randint( num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
1, self.config.max_entities) num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
num_male_colleagues_bob_circle = rng.randint(
1, self.config.max_entities)
num_female_colleagues_bob_circle = rng.randint(
1, self.config.max_entities)
answer = num_female_colleagues_alice_circle + 1 answer = num_female_colleagues_alice_circle + 1
template = rng.choice(self.templates[TaskType.COLLEAGUES]) template = rng.choice(self.templates[TaskType.COLLEAGUES])
@ -169,16 +184,10 @@ class AliceInWonderlandDataset(ProceduralDataset):
num_male_colleagues_alice_circle=num_male_colleagues_alice_circle, num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
num_female_colleagues_alice_circle=num_female_colleagues_alice_circle, num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
num_male_colleagues_bob_circle=num_male_colleagues_bob_circle, num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
) )
return { return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}}
"question": question,
"answer": answer,
"metadata": {
"task_type": task_type.value
}
}
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
rng = Random(self.seed + idx) rng = Random(self.seed + idx)

View file

@ -14,8 +14,7 @@ def test_aiw_config_validation():
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = AliceInWonderlandConfig( config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice
female_names=["Mary", "Jane"]) # No Alice
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
@ -56,8 +55,7 @@ def test_aiw_items():
# Verify question task type characteristics # Verify question task type characteristics
task_type = item["metadata"]["task_type"] task_type = item["metadata"]["task_type"]
if task_type == TaskType.SIBLINGS.value: if task_type == TaskType.SIBLINGS.value:
assert any(phrase in item["question"] assert any(phrase in item["question"] for phrase in ["brothers", "sisters"])
for phrase in ["brothers", "sisters"])
elif task_type == TaskType.FRIENDS.value: elif task_type == TaskType.FRIENDS.value:
assert "friends" in item["question"] assert "friends" in item["question"]
elif task_type == TaskType.COLLEAGUES: elif task_type == TaskType.COLLEAGUES:
@ -95,5 +93,4 @@ def test_aiw_random_ranges():
numbers = [int(n) for n in question.split() if n.isdigit()] numbers = [int(n) for n in question.split() if n.isdigit()]
# Check all numbers are in reasonable range (1-6 as per implementation) # Check all numbers are in reasonable range (1-6 as per implementation)
assert all( assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"
1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"