post merge lint

This commit is contained in:
Andreas Koepf 2025-02-02 10:04:18 +01:00
parent 02cfa9556a
commit f396d3df60
6 changed files with 197 additions and 63 deletions

View file

@ -1,14 +1,15 @@
from dataclasses import dataclass, field
from typing import List, Optional
from enum import Enum
from random import Random
from string import Template
from typing import List, Optional
from ..factory import ProceduralDataset, register_dataset
class TaskType(Enum):
"""Defines the type of task for the Alice in Wonderland dataset."""
SIBLINGS = "siblings"
FRIENDS = "friends"
COLLEAGUES = "colleagues" # Added colleagues task
@ -26,21 +27,39 @@ class AliceInWonderlandConfig:
size (int): Number of samples in the dataset.
max_entities (int): Max number of siblings/friends/colleagues in questions.
"""
male_names: List[str] = field(
default_factory=lambda: [
"James", "John", "Robert", "Michael", "William", "David",
"Richard", "Joseph", "Thomas", "Charles", "Bob"
"James",
"John",
"Robert",
"Michael",
"William",
"David",
"Richard",
"Joseph",
"Thomas",
"Charles",
"Bob",
]
)
female_names: List[str] = field(
default_factory=lambda: [
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth",
"Barbara", "Susan", "Jessica", "Sarah", "Margaret", "Alice"
"Mary",
"Patricia",
"Jennifer",
"Linda",
"Elizabeth",
"Barbara",
"Susan",
"Jessica",
"Sarah",
"Margaret",
"Alice",
]
)
task_types: List[TaskType] = field(
default_factory=lambda: [
TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
)
seed: Optional[int] = None
size: int = 10
@ -57,19 +76,19 @@ class AliceInWonderlandConfig:
class AliceInWonderlandDataset(ProceduralDataset):
"""
A procedural dataset inspired by the "Alice in Wonderland" paper.
A procedural dataset inspired by the "Alice in Wonderland" paper.
The dataset is inspired by the following paper:
@inproceedings{nezhurina2024alice,
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
Jenia Jitsev},
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
Deep Learning},
year={2024},
url={https://openreview.net/forum?id=Mkl7dzjYiW}
}
The dataset is inspired by the following paper:
@inproceedings{nezhurina2024alice,
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
Jenia Jitsev},
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
Deep Learning},
year={2024},
url={https://openreview.net/forum?id=Mkl7dzjYiW}
}
"""
@ -152,14 +171,10 @@ class AliceInWonderlandDataset(ProceduralDataset):
num_female=num_female,
)
elif task_type == TaskType.COLLEAGUES:
num_male_colleagues_alice_circle = rng.randint(
1, self.config.max_entities)
num_female_colleagues_alice_circle = rng.randint(
1, self.config.max_entities)
num_male_colleagues_bob_circle = rng.randint(
1, self.config.max_entities)
num_female_colleagues_bob_circle = rng.randint(
1, self.config.max_entities)
num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
answer = num_female_colleagues_alice_circle + 1
template = rng.choice(self.templates[TaskType.COLLEAGUES])
@ -169,16 +184,10 @@ class AliceInWonderlandDataset(ProceduralDataset):
num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
)
return {
"question": question,
"answer": answer,
"metadata": {
"task_type": task_type.value
}
}
return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}}
def __getitem__(self, idx: int) -> dict:
rng = Random(self.seed + idx)