diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index 816b5ae3..2ce3a13b 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -29,7 +29,6 @@ class AliceInWonderlandConfig: male_names (List[str]): List of male names to use in questions. female_names (List[str]): List of female names to use in questions. Must include 'Alice'. task_types (List[TaskType]): List of task types to include in dataset. - output_formats (List[OutputFormat]): List of output formats to include in dataset. seed (Optional[int]): Seed for random number generation. size (int): Number of samples in the dataset. max_entities (int): Max number of siblings/friends/colleagues in questions. @@ -47,14 +46,8 @@ class AliceInWonderlandConfig: ] ) task_types: List[TaskType] = field( - default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues - ) - output_formats: List[OutputFormat] = field( default_factory=lambda: [ - OutputFormat.PLAIN, - OutputFormat.RESTRICTED, - OutputFormat.THINKING, - ] + TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues ) seed: Optional[int] = None size: int = 10 @@ -66,8 +59,6 @@ class AliceInWonderlandConfig: assert len(self.female_names) > 0, "must provide female names" assert "Alice" in self.female_names, "'Alice' must be in female names" assert len(self.task_types) > 0, "must provide at least one task type" - assert len( - self.output_formats) > 0, "must provide at least one output format" assert self.max_entities > 0, "max_entities must be positive" @@ -88,6 +79,7 @@ class AliceInWonderlandDataset(ProceduralDataset): } """ + def __init__(self, config: AliceInWonderlandConfig): super().__init__(config=config, seed=config.seed, size=config.size) self.templates = { @@ -118,12 +110,12 @@ class AliceInWonderlandDataset(ProceduralDataset): "$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. 
" "All these mentioned persons around $female_name are colleagues of each other. " "$male_name has $num_male_colleagues_bob_circle male colleagues " - "and $num_female_colleagues_bob_circle female colleagues in total. " + "and $num_female_colleagues_bob_circle female colleagues in total. " "All these mentioned persons around $male_name are colleagues of each other. " - "The people in the circle around $male_name do not have " - "other colleagues aside - with the only exception of Matilda. " + "The people in the circle around $male_name do not have " + "other colleagues aside - with the only exception of Matilda. " "She is colleague of $male_name and she is also colleague of $female_name, " - "being part of $female_name's circle. How many female colleagues does Matilda have?" + "being part of $female_name's circle. How many female colleagues does Matilda have?" ), ], } @@ -153,7 +145,6 @@ class AliceInWonderlandDataset(ProceduralDataset): and a description of the example. """ task_type = rng.choice(self.config.task_types) - output_format = rng.choice(self.config.output_formats) female_name = rng.choice(self.config.female_names) male_name = rng.choice(self.config.male_names) @@ -182,10 +173,14 @@ class AliceInWonderlandDataset(ProceduralDataset): num_female=num_female, ) elif task_type == TaskType.COLLEAGUES: - num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities) - num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities) - num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities) - num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities) + num_male_colleagues_alice_circle = rng.randint( + 1, self.config.max_entities) + num_female_colleagues_alice_circle = rng.randint( + 1, self.config.max_entities) + num_male_colleagues_bob_circle = rng.randint( + 1, self.config.max_entities) + num_female_colleagues_bob_circle = rng.randint( + 1, self.config.max_entities) answer = 
num_female_colleagues_alice_circle + 1 template = rng.choice(self.templates[TaskType.COLLEAGUES]) @@ -198,14 +193,12 @@ class AliceInWonderlandDataset(ProceduralDataset): num_female_colleagues_bob_circle=num_female_colleagues_bob_circle ) - formatted_question = self.format_templates[output_format].substitute( - question=question - ) - return { - "prompt": formatted_question, - "right_answer": str(answer), - "description": f"{task_type.value} variation, {output_format.value} format", + "question": question, + "answer": answer, + "metadata": { + "task_type": task_type.value + } } def __getitem__(self, idx: int) -> dict: @@ -213,4 +206,4 @@ class AliceInWonderlandDataset(ProceduralDataset): return self._get_aiw(rng) -register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig) \ No newline at end of file +register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig) diff --git a/tests/test_aiw.py b/tests/test_aiw.py index 5a0fbaf5..279fcc2c 100644 --- a/tests/test_aiw.py +++ b/tests/test_aiw.py @@ -2,6 +2,7 @@ import pytest from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType, OutputFormat + def test_aiw_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): @@ -11,18 +12,16 @@ def test_aiw_config_validation(): with pytest.raises(AssertionError): config = AliceInWonderlandConfig(female_names=[]) # Empty female names config.validate() - + with pytest.raises(AssertionError): - config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice + config = AliceInWonderlandConfig( + female_names=["Mary", "Jane"]) # No Alice config.validate() - + with pytest.raises(AssertionError): config = AliceInWonderlandConfig(task_types=[]) # No task types config.validate() - - with pytest.raises(AssertionError): - config = AliceInWonderlandConfig(output_formats=[]) # No output formats - config.validate() + def test_aiw_deterministic(): """Test that 
dataset generates same items with same seed""" @@ -33,6 +32,7 @@ def test_aiw_deterministic(): for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] + def test_aiw_items(): """Test basic properties of generated items""" config = AliceInWonderlandConfig(size=50, seed=42) @@ -41,29 +41,28 @@ def test_aiw_items(): for i in range(len(dataset)): item = dataset[i] assert isinstance(item, dict) - assert "prompt" in item - assert "right_answer" in item - assert "description" in item - + assert "question" in item + assert "answer" in item + assert "metadata" in item + # Verify answer is numeric and positive - answer = int(item["right_answer"]) + answer = int(item["answer"]) assert answer > 0 - + # Verify question contains at least one female name female_names = config.female_names - assert any(name in item["prompt"] for name in female_names) + assert any(name in item["question"] for name in female_names) + + # Verify question task type characteristics + task_type = item["metadata"]["task_type"] + if task_type == TaskType.SIBLINGS.value: + assert any(phrase in item["question"] + for phrase in ["brothers", "sisters"]) + elif task_type == TaskType.FRIENDS.value: + assert "friends" in item["question"] + elif task_type == TaskType.COLLEAGUES.value: + assert "colleagues" in item["question"] - # Verify question format - if TaskType.SIBLINGS.value in item["description"]: - assert any(phrase in item["prompt"] for phrase in ["brothers", "sisters"]) - elif TaskType.FRIENDS.value in item["description"]: - assert "friends" in item["prompt"] - - # Verify output format - if OutputFormat.RESTRICTED.value in item["description"]: - assert "DO NOT OUTPUT ANY TEXT EXCEPT" in item["prompt"] - elif OutputFormat.THINKING.value in item["description"]: - assert "think carefully step by step" in item["prompt"] def test_aiw_iteration(): """Test that iteration works correctly""" @@ -85,23 +84,16 @@ def test_aiw_iteration(): second_items = list(dataset) assert first_items == second_items + def 
test_aiw_random_ranges(): """Test that generated numbers stay within expected ranges""" config = AliceInWonderlandConfig(size=30, seed=42, max_entities=12) dataset = AliceInWonderlandDataset(config) for item in dataset: - prompt = item["prompt"] - numbers = [int(n) for n in prompt.split() if n.isdigit()] - + question = item["question"] + numbers = [int(n) for n in question.split() if n.isdigit()] + # Check all numbers are in reasonable range (1-6 as per implementation) - assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}" - -def test_output_format_is_correct(): - """Test that the output format adheres to the user input""" - config = AliceInWonderlandConfig(size=30, seed=42, output_formats=[OutputFormat.THINKING]) - dataset = AliceInWonderlandDataset(config) - - for item in dataset: - prompt = item["prompt"] - assert "think carefully step by step" in item["prompt"] + assert all( + 1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"