diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py new file mode 100644 index 00000000..bee47a89 --- /dev/null +++ b/reasoning_gym/logic/aiw.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass, field +from typing import List, Optional +from enum import Enum +from random import Random +from string import Template + +from ..factory import ProceduralDataset, register_dataset + + +class TaskType(Enum): + """Defines the type of task for the Alice in Wonderland dataset.""" + SIBLINGS = "siblings" + FRIENDS = "friends" + + +class OutputFormat(Enum): + """Defines the output format for the generated questions.""" + PLAIN = "plain" + RESTRICTED = "restricted" + THINKING = "thinking" + + +@dataclass +class AliceInWonderlandConfig: + """Configuration options for the Alice in Wonderland dataset. + + Attributes: + male_names (List[str]): List of male names to use in questions. + female_names (List[str]): List of female names to use in questions. Must include 'Alice'. + task_types (List[TaskType]): List of task types to include in dataset. + output_formats (List[OutputFormat]): List of output formats to include in dataset. + seed (Optional[int]): Seed for random number generation. + size (int): Number of samples in the dataset. + max_entities (int): Max number of siblings/friends in questions. + """ + male_names: List[str] = field( + default_factory=lambda: [ + "James", "John", "Robert", "Michael", "William", "David", + "Richard", "Joseph", "Thomas", "Charles", "Bob" + ] + ) + female_names: List[str] = field( + default_factory=lambda: [ + "Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", + "Barbara", "Susan", "Jessica", "Sarah", "Margaret", "Alice" + ] + ) + task_types: List[TaskType] = field( + default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS] + ) + output_formats: List[OutputFormat] = field( + default_factory=lambda: [ + OutputFormat.PLAIN, + OutputFormat.RESTRICTED, + OutputFormat.THINKING, + ] + ) + seed: Optional[int] = None + size: int = 10 + max_entities: int = 6 # Added max_entities + + def validate(self) -> None: + """Validates the configuration parameters.""" + assert len(self.male_names) > 0, "must provide male names" + assert len(self.female_names) > 0, "must provide female names" + assert "Alice" in self.female_names, "'Alice' must be in female names" + assert len(self.task_types) > 0, "must provide at least one task type" + assert len( + self.output_formats) > 0, "must provide at least one output format" + assert self.max_entities > 0, "max_entities must be positive" + + +class AliceInWonderlandDataset(ProceduralDataset): + """ + A procedural dataset inspired by the "Alice in Wonderland" paper. + + The dataset is inspired by the following paper: + @inproceedings{nezhurina2024alice, + title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and + Basic Reasoning Deficits in State-Of-the-Art Large Language Models}, + author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and + Jenia Jitsev}, + booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding + Deep Learning}, + year={2024}, + url={https://openreview.net/forum?id=Mkl7dzjYiW} + } + + """ + def __init__(self, config: AliceInWonderlandConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.templates = { + TaskType.SIBLINGS: [ + Template( + "$female_name has $num_brothers brothers and she also has " + "$num_sisters sisters. How many sisters does " + "$female_name's brother have?" + ), + Template( + "$female_name has $num_sisters sisters and she also has " + "$num_brothers brothers. How many sisters does " + "$male_name's brother have?" + ), + ], + TaskType.FRIENDS: [ + Template( + "$female_name has $num_male male friends and she also has " + "$num_female female friends. They all are friends with each " + "other and have no other friends aside. How many female " + "friends does $male_name, a male friend of $female_name, " + "have?" + ) + ], + } + + self.format_templates = { + OutputFormat.PLAIN: Template("$question"), + OutputFormat.RESTRICTED: Template( + "$question To answer the question, DO NOT OUTPUT ANY TEXT EXCEPT " + 'following format that contains final answer: "### Answer:"' + ), + OutputFormat.THINKING: Template( + "$question Before providing answer to this problem, think " + "carefully step by step and double check the path to the " + 'correct solution for any mistakes. Provide then the final ' + 'answer in following form: "### Answer:"' + ), + } + + def _get_aiw(self, rng: Random) -> dict: + """Generates a single Alice in Wonderland question. + + Args: + rng (Random): Random number generator. + + Returns: + dict: A dictionary containing the generated question, the right answer + and a description of the example. + """ + task_type = rng.choice(self.config.task_types) + output_format = rng.choice(self.config.output_formats) + female_name = rng.choice(self.config.female_names) + male_name = rng.choice(self.config.male_names) + + if task_type == TaskType.SIBLINGS: + num_brothers = rng.randint(1, self.config.max_entities) + num_sisters = rng.randint(1, self.config.max_entities) + answer = num_sisters + 1 + template = rng.choice(self.templates[TaskType.SIBLINGS]) + question = template.substitute( + female_name=female_name, + male_name=male_name, + num_brothers=num_brothers, + num_sisters=num_sisters, + ) + elif task_type == TaskType.FRIENDS: + num_male = rng.randint(1, self.config.max_entities) + num_female = rng.randint(1, self.config.max_entities) + answer = num_female + 1 + template = rng.choice(self.templates[TaskType.FRIENDS]) + question = template.substitute( + female_name=female_name, + male_name=male_name, + num_male=num_male, + num_female=num_female, + ) + + formatted_question = self.format_templates[output_format].substitute( + question=question + ) + + return { + "prompt": formatted_question, + "right_answer": str(answer), + "description": f"{task_type.value} variation, {output_format.value} format", + } + + def __getitem__(self, idx: int) -> dict: + rng = Random(self.seed + idx) + return self._get_aiw(rng) + + +register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig) \ No newline at end of file diff --git a/tests/test_aiw.py b/tests/test_aiw.py new file mode 100644 index 00000000..5a0fbaf5 --- /dev/null +++ b/tests/test_aiw.py @@ -0,0 +1,107 @@ +import pytest + +from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType, OutputFormat + +def test_aiw_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(male_names=[]) # Empty male names + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(female_names=[]) # Empty female names + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(task_types=[]) # No task types + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(output_formats=[]) # No output formats + config.validate() + +def test_aiw_deterministic(): + """Test that dataset generates same items with same seed""" + config = AliceInWonderlandConfig(seed=42, size=10) + dataset1 = AliceInWonderlandDataset(config) + dataset2 = AliceInWonderlandDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + +def test_aiw_items(): + """Test basic properties of generated items""" + config = AliceInWonderlandConfig(size=50, seed=42) + dataset = AliceInWonderlandDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "prompt" in item + assert "right_answer" in item + assert "description" in item + + # Verify answer is numeric and positive + answer = int(item["right_answer"]) + assert answer > 0 + + # Verify question contains at least one female name + female_names = config.female_names + assert any(name in item["prompt"] for name in female_names) + + # Verify question format + if TaskType.SIBLINGS.value in item["description"]: + assert any(phrase in item["prompt"] for phrase in ["brothers", "sisters"]) + elif TaskType.FRIENDS.value in item["description"]: + assert "friends" in item["prompt"] + + # Verify output format + if OutputFormat.RESTRICTED.value in item["description"]: + assert "DO NOT OUTPUT ANY TEXT EXCEPT" in item["prompt"] + elif OutputFormat.THINKING.value in item["description"]: + assert "think carefully step by step" in item["prompt"] + +def test_aiw_iteration(): + """Test that iteration works correctly""" + config = AliceInWonderlandConfig(size=5, seed=42) + dataset = AliceInWonderlandDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size + + # Test list conversion + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same results + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items + +def test_aiw_random_ranges(): + """Test that generated numbers stay within expected ranges""" + config = AliceInWonderlandConfig(size=30, seed=42, max_entities=12) + dataset = AliceInWonderlandDataset(config) + + for item in dataset: + prompt = item["prompt"] + numbers = [int(n) for n in prompt.split() if n.isdigit()] + + # Check all numbers are in reasonable range (1-6 as per implementation) + assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}" + +def test_output_format_is_correct(): + """Test that the output format adheres to the user input""" + config = AliceInWonderlandConfig(size=30, seed=42, output_formats=[OutputFormat.THINKING]) + dataset = AliceInWonderlandDataset(config) + + for item in dataset: + prompt = item["prompt"] + assert "think carefully step by step" in item["prompt"]