diff --git a/reasoning_gym/algorithmic/palindrome_generation.py b/reasoning_gym/algorithmic/palindrome_generation.py index c17e8751..ce2db322 100644 --- a/reasoning_gym/algorithmic/palindrome_generation.py +++ b/reasoning_gym/algorithmic/palindrome_generation.py @@ -5,6 +5,23 @@ from typing import Any, Dict, Optional from ..factory import ProceduralDataset, register_dataset +QUESTION_TEMPLATE = """Your task is, given a list of letters, to form a valid palindrome. + +A palindrome is a phrase that reads the same forwards and backwards. + +If there are multiple possible answers, only respond with one of them. You must use all the letters provided. + +Example: +- Input: Form a valid palindrome using the following letters: a, a, b +- Output: aba +- Explanation: + - The phrase aba reads the same forwards and backwards. + - The output answer is a valid palindrome using all the letters provided. + - The answer is a string, rather than a list of characters. + +Now, form a valid palindrome using the following letters: {letters} +""" + @dataclass class PalindromeConfig: @@ -51,16 +68,8 @@ class PalindromeDataset(ProceduralDataset): letters = self._generate_palindrome_letters(rng, length) scrambled_letters = rng.sample(letters, len(letters)) # Scramble the order palindrome = self._assemble_palindrome(letters) - - question_str = ( - "Rearrange these letters to form a palindrome. A palindrome is a word, phrase, or sequence that reads the same forward and backward. If there are multiple answers, only respond with one of them.\n\n" - "For example, if the letters are: a, a, b — a valid palindrome is: aba.\n\n" - f"Your letters: {', '.join(scrambled_letters)}\n\n" - "What palindrome can you form from these letters?" 
- ) - return { - "question": question_str, + "question": QUESTION_TEMPLATE.format(letters=", ".join(scrambled_letters)), "answer": palindrome, "metadata": { "letters": scrambled_letters, diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py index b96698bb..98c3f000 100644 --- a/reasoning_gym/arc/arc_agi.py +++ b/reasoning_gym/arc/arc_agi.py @@ -27,6 +27,7 @@ class ArcAgiConfig: default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"] ) # empty list for no mirrors use_color_permutation: bool = True + shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle seed: Optional[int] = None size: int = 500 @@ -87,8 +88,8 @@ def cmap(board: Board, colors: list[int]) -> Board: return [[colors[c] for c in row] for row in board] -ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270] -MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror] +# ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270] +# MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror] class ArcAgiDataset(ProceduralDataset): @@ -156,6 +157,9 @@ class ArcAgiDataset(ProceduralDataset): for p in train: augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])}) + if self.config.shuffle_example_order: + rng.shuffle(augmented_train) + examples = [ format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(augmented_train) diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index 58b62b1a..e2278b1c 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -54,14 +54,29 @@ ANIMALS = { "woodlouse": 14, } +QUESTION_TEMPLATE = """Your task is to count how many legs there are in total when given a list of animals. + +Example: +- Input: How many legs are there in total if you have 1 duck, 2 deers, 1 spider, 3 cows? 
+- Output: 30 +- Explanation: + - Ducks have 2 legs each, so 1 duck has 2 legs. + - Deers have 4 legs each, so 2 deers have 8 legs. + - Spiders have 8 legs each, so 1 spider has 8 legs. + - Cows have 4 legs each, so 3 cows have 12 legs. + - Therefore, the total number of legs is 2 + 8 + 8 + 12 = 30 + +Now, how many legs are there in total if you have {animals}? +""" + @dataclass class LegCountingConfig: """Configuration for leg counting task generation""" - min_animals: int = 2 # Minimum number of animals in problem - max_animals: int = 5 # Maximum number of animals - max_instances: int = 3 # Maximum instances of each animal + min_animals: int = 3 # Minimum number of animals in problem + max_animals: int = 10 # Maximum number of animals + max_instances: int = 15 # Maximum instances of each animal seed: Optional[int] = None size: int = 500 # Virtual dataset size @@ -106,10 +121,8 @@ class LegCountingDataset(ProceduralDataset): for animal, count in animals.items(): animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}") - question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?" 
- return { - "question": question, + "question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)), "answer": str(total_legs), "metadata": { "difficulty": { diff --git a/tests/test_arc_agi.py b/tests/test_arc_agi.py index da43e6ab..cfeecf21 100644 --- a/tests/test_arc_agi.py +++ b/tests/test_arc_agi.py @@ -137,3 +137,32 @@ def test_arc_agi_dataset_modes(): both_ds = ArcAgiDataset(both_config) assert len(both_ds._task_ids) > len(train_ds._task_ids) assert len(both_ds._task_ids) > len(eval_ds._task_ids) + + +def test_arc_agi_shuffled_order(): + config_unshuffled = ArcAgiConfig( + shuffle_example_order=False, + use_train=True, + use_eval=False, + rotations=[], + mirrors=[], + use_color_permutation=False, + size=3, + seed=42, + ) + config_shuffled = ArcAgiConfig( + shuffle_example_order=True, + use_train=True, + use_eval=False, + rotations=[], + mirrors=[], + use_color_permutation=False, + size=3, + seed=42, + ) + unshuffled = ArcAgiDataset(config_unshuffled) + shuffled = ArcAgiDataset(config_shuffled) + + for a, b in zip(shuffled, unshuffled): + assert a["question"] != b["question"] + assert a["answer"] == b["answer"]