diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py new file mode 100644 index 00000000..8b25e98f --- /dev/null +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -0,0 +1,117 @@ +"""Leg counting task generator""" +from dataclasses import dataclass +from random import Random +from typing import Dict, List, Optional, Tuple + +ANIMALS = { + "spider": 8, + "insect": 6, + "dog": 4, + "chicken": 2, + "snake": 0, + "cat": 4, + "bird": 2, + "cow": 4, + "ant": 6, + "scorpion": 8, + "human": 2, + "horse": 4, + "duck": 2, + "butterfly": 6, + "centipede": 100, +} + +@dataclass +class LegCountingConfig: + """Configuration for leg counting task generation""" + min_animals: int = 2 # Minimum number of animals in problem + max_animals: int = 5 # Maximum number of animals + max_instances: int = 3 # Maximum instances of each animal + seed: Optional[int] = None + size: int = 500 # Virtual dataset size + + def validate(self): + """Validate configuration parameters""" + assert self.min_animals > 0, "min_animals must be positive" + assert self.max_animals >= self.min_animals, "max_animals must be >= min_animals" + assert self.max_instances > 0, "max_instances must be positive" + + +class LegCountingDataset: + """Generates leg counting arithmetic tasks""" + + def __init__(self, config: LegCountingConfig): + self.config = config + self.config.validate() + self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) + + def __len__(self) -> int: + return self.config.size + + def __iter__(self): + self._current_idx = 0 + return self + + def __next__(self): + if self._current_idx >= self.config.size: + raise StopIteration + item = self[self._current_idx] + self._current_idx += 1 + return item + + def _generate_animals(self, rng: Random) -> Dict[str, int]: + """Generate a random set of animals and their counts""" + num_types = rng.randint(self.config.min_animals, self.config.max_animals) + animals = {} + + # Select random animals + selected_animals = rng.sample(list(ANIMALS.keys()), num_types) + for animal in selected_animals: + count = rng.randint(1, self.config.max_instances) + animals[animal] = count + + return animals + + def __getitem__(self, idx: int) -> dict: + """Generate a single leg counting task""" + rng = Random(self.seed + idx) + + # Generate random animals and their counts + animals = self._generate_animals(rng) + + # Calculate total legs + total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items()) + + # Format animal counts for question + animal_list = [] + for animal, count in animals.items(): + animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}") + + question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?" + + return { + "question": question, + "answer": str(total_legs), + "metadata": { + "animals": animals, + "total_legs": total_legs + } + } + + +def leg_counting_dataset( + min_animals: int = 2, + max_animals: int = 5, + max_instances: int = 3, + seed: Optional[int] = None, + size: int = 500, +) -> LegCountingDataset: + """Create a LegCountingDataset with the given configuration.""" + config = LegCountingConfig( + min_animals=min_animals, + max_animals=max_animals, + max_instances=max_instances, + seed=seed, + size=size, + ) + return LegCountingDataset(config)