diff --git a/play.py b/play.py new file mode 100644 index 00000000..77d3820f --- /dev/null +++ b/play.py @@ -0,0 +1,6 @@ +import reasoning_gym + +data = reasoning_gym.create_dataset("list_functions", size=3, seed=42) +for i, x in enumerate(data): + print(f"{i}: q={x['question']}, a={x['answer']}") + print("metadata:", x["metadata"]) diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index 9429682d..c0c90c82 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -2,7 +2,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasoning models """ -from . import algebra, algorithmic, arc, arithmetic, code, cognition, data, games, geometry, graphs, logic +from . import algebra, algorithmic, arc, arithmetic, code, cognition, data, games, geometry, graphs, induction, logic from .factory import create_dataset, register_dataset __version__ = "0.1.8" @@ -18,6 +18,7 @@ __all__ = [ "geometry", "graphs", "logic", + "induction", "create_dataset", "register_dataset", ] diff --git a/reasoning_gym/induction/__init__.py b/reasoning_gym/induction/__init__.py new file mode 100644 index 00000000..f6dc1504 --- /dev/null +++ b/reasoning_gym/induction/__init__.py @@ -0,0 +1,10 @@ +""" +Arithmetic tasks for training reasoning capabilities: +""" + +from .list_functions import ListFunctionsDataset, ListFunctionsDatasetConfig + +__all__ = [ + "ListFunctionsDataset", + "ListFunctionsDatasetConfig", +] diff --git a/reasoning_gym/induction/list_functions/__init__.py b/reasoning_gym/induction/list_functions/__init__.py new file mode 100644 index 00000000..d99a03b9 --- /dev/null +++ b/reasoning_gym/induction/list_functions/__init__.py @@ -0,0 +1,6 @@ +from .list_functions import ListFunctionsDataset, ListFunctionsDatasetConfig + +__all__ = [ + "ListFunctionsDatasetConfig", + "ListFunctionsDataset", +] diff --git a/reasoning_gym/induction/list_functions/generators.py b/reasoning_gym/induction/list_functions/generators.py new file mode 100644 index 00000000..3e1966b8 --- /dev/null +++ b/reasoning_gym/induction/list_functions/generators.py @@ -0,0 +1,363 @@ +import random +from random import Random +from typing import Any, Dict + +NUM_OF_PAIRS_GENERATED = 5 + + +def create_random_list(rng: Random): + length = rng.randint(3, 10) + return [rng.randint(1, 100) for _ in range(length)] + + +def create_list_of_fives(rng: Random): + length = rng.randint(1, 7) # Random length between 1 and 7 + return [5] * length + + +def sort_integers(lst, order="ascending"): + """ + Sorts a list of integers in ascending or descending order. + + Parameters: + lst (list): The list of integers to sort. + order (str): The order to sort in. Options are 'ascending' or 'descending'. + + Returns: + list: The sorted list. + """ + if order == "ascending": + return sorted(lst) # Sort in ascending order + elif order == "descending": + return sorted(lst, reverse=True) # Sort in descending order + else: + raise ValueError("Invalid order. Use 'ascending' or 'descending'.") + + +def create_random_odd_numbers(count, start, end): + """ + Generates a list of random odd numbers. + + Parameters: + count (int): The number of odd numbers to generate. + start (int): The lower bound of the range (inclusive). + end (int): The upper bound of the range (inclusive). + + Returns: + list: A list of random odd numbers. + """ + odd_numbers = [] + while len(odd_numbers) < count: + num = random.randint(start, end) # Generate a random number + if num % 2 != 0: # Check if the number is odd + odd_numbers.append(num) + return odd_numbers + + +def create_numbers_divisible_by_five_or_ten(rng: Random): + result = [] + for i in range(NUM_OF_PAIRS_GENERATED): + if i % 2 == 0: + num = create_random_odd_numbers(1, 1, 1000)[0] * 5 # Random multiple of 5 + else: + num = rng.randint(1, 100) * 10 # Random multiple of 10 + result.append(num) + return result + + +def generate_0(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where input remains unchanged""" + pairs = {} + + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + input = str(input) + output = input + pairs[input] = output + + return pairs + + +def generate_1(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a list of the third element + after removing all other elements + """ + pairs = {} + + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + target_idx = 2 + output = [input[target_idx]] + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_2(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a reversed list of the input""" + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + output = input[::-1] + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_3(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is the sum of unique elements in the list less than 30""" + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + unique_input = list(set(input)) + + total_sum = 0 + for num in unique_input: + if num < 30: + total_sum += num + + input = str(input) + output = str([total_sum]) + pairs[input] = output + + return pairs + + +def generate_4(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is the count of elements equal to 5""" + pairs = {} + for i in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + + if i % 2 == 0: + input += create_list_of_fives(rng) + + # Shuffle the new input with fives + rng.shuffle(input) + + total_count = 0 + for num in input: + if num == 5: + total_count += 1 + + input = str(input) + output = str([total_count]) + pairs[input] = output + + return pairs + + +def generate_5(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a list of elements that are followed by an even number + + NOTE: This is suppose to be a relatively hard problem + """ + pairs = {} + for i in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + output = [] + for i in range(1, len(input)): + + # If the current element is an even number, we then add previous element into output + if input[i] % 2 == 0: + output.append(input[i - 1]) + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_6(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a list of elements where each element in input is added to its position(Using zero-indexing)""" + pairs = {} + for i in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + output = [] + for i, num in enumerate(input): + element = i + num + output.append(element) + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_7(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a list of element whose position is indicated by the last element in the input + + EXAMPLE: + 1. [26, 88, 60, 1, 17, 75, 97, 89, 1] -> [88] + 2. [49, 71, 2, 61, 3]: [61] + """ + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + # Create a chosen index between the bounds of the size of the input + chosen_index = rng.randint(0, len(input) - 1) + # Replace the last element with chosen_index + input[-1] = chosen_index + output = [input[chosen_index]] + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_8(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is count of elements in the input""" + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + output = [len(input)] + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_9(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is sum total of elements in the input""" + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + output = [sum(input)] + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_10(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a list of the elements in ascending order""" + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + output = sort_integers(input, order="ascending") + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_11(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a list of the elements in descending order""" + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + output = sort_integers(input, order="descending") + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_12(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a list of the elements where the first and last element in input are replaced by their + successor. Example, for an integer 4, successor is 5 + """ + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + # Create successor for first and last element using a copy of input + output = input.copy() + output[0] += 1 + output[-1] += 1 + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_13(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is [1] if list of input elements is in ascending order, [0] in descending order""" + pairs = {} + for i in range(NUM_OF_PAIRS_GENERATED): + input = create_random_list(rng) + if i % 2 == 0: + input = sort_integers(input, order="ascending") + output = [1] + else: + input = sort_integers(input, order="descending") + output = [0] + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_14(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is [1] if input element is divisible by 10, [0] if divisible by 5""" + pairs = {} + + nums = create_numbers_divisible_by_five_or_ten(rng) + for num in nums: + if num % 10 == 0: + input = [num] + output = [1] + else: + input = [num] + output = [0] + + input = str(input) + output = str(output) + pairs[input] = output + + return pairs + + +def generate_15(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is a twice the amount of last element in the input""" + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + starter_input = create_random_list(rng) + length = len(starter_input) + first_element = rng.choice(starter_input) + input = [first_element] + + for _ in range(1, length): + prev = input[-1] + input.append(prev * 2) + + # Create output here to prevent building on strings + output = str([input[-1] * 2]) + input = str(input) + pairs[input] = output + + return pairs + + +def generate_16(rng: Random) -> Dict[str, Any]: + """Generate input and output pairs where output is built from a function 2x - 4 + NOTE: This is suppose to be amazingly hard for the LLM. + """ + pairs = {} + for _ in range(NUM_OF_PAIRS_GENERATED): + starter_input = create_random_list(rng) + first_element = rng.choice(starter_input) + output = (2 * first_element) - 4 + input = str([first_element]) + pairs[input] = str([output]) + + return pairs diff --git a/reasoning_gym/induction/list_functions/list_functions.py b/reasoning_gym/induction/list_functions/list_functions.py new file mode 100644 index 00000000..cf2d98e7 --- /dev/null +++ b/reasoning_gym/induction/list_functions/list_functions.py @@ -0,0 +1,99 @@ +"""List functions generators""" + +from dataclasses import dataclass +from random import Random +from typing import Any, Callable, Optional + +from reasoning_gym.factory import ProceduralDataset, register_dataset + + +@dataclass +class ListFunctionsDatasetConfig: + """Configuration for List function generators.""" + + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.size > 0, "size must be positive" + + +tasks = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, +] + + +class ListFunctionsDataset(ProceduralDataset): + + def __init__(self, config: ListFunctionsDatasetConfig): + super().__init__(config, config.seed, config.size) + self._generators: dict[int, Callable[[Random, float], dict[str, Any]]] = None # initially None, lazy loading + self.task_indices = Random(self.seed).choices(tasks, k=self.size) + self.prompt_template = """You are an expert at inductive reasoning. Generate an output corresponding to the given input. +The output is generated by applying the same rule that maps input to output for the examples provided. Your answer should be a list of element/elements +Examples: +{examples} + +Input: {input} +Output: +""" + + @property + def generators(self) -> dict[int, Callable[[Random, float], dict[str, Any]]]: + """Lazy load generators only when first accessed""" + if self._generators is None: + self._generators = self._load_generators() + return self._generators + + def _load_generators(self): + """ + Generates mapper from task identifiers (keys) to example generator functions + """ + from . import generators + + def strip_prefix(s: str, prefix: str) -> str: + return s[len(prefix) :] + + prefix = "generate_" + gs = {} + for n in dir(generators): + if n.startswith(prefix): + gs[int(strip_prefix(n, prefix))] = getattr(generators, n) + return gs + + def __getitem__(self, idx: int) -> dict: + """Generate a single induction-based list function dataset""" + rng = Random(self.seed + idx) + generator_idx = self.task_indices[idx] + generator = self.generators[generator_idx] + examples = generator(rng) + entry = examples.popitem() + input = entry[0] + output = entry[1] + formatted_examples = "" + for index, key in enumerate(examples): + formatted_examples += f"""Input {index + 1}: {key} +Output {index + 1}: {examples[key]} +""" + question = self.prompt_template.format(examples=formatted_examples, input=input) + return {"question": question, "answer": output, "metadata": {}} + + +register_dataset("list_functions", ListFunctionsDataset, ListFunctionsDatasetConfig) diff --git a/tests/test_list_functions.py b/tests/test_list_functions.py new file mode 100644 index 00000000..6e7ee0c0 --- /dev/null +++ b/tests/test_list_functions.py @@ -0,0 +1,84 @@ +from random import Random + +import pytest + +from reasoning_gym.induction.list_functions import ListFunctionsDataset, ListFunctionsDatasetConfig + + +def test_list_functions_config_validation(): + """Test that config validation works""" + config = ListFunctionsDatasetConfig(size=-1) + with pytest.raises(AssertionError): + config.validate() + + +def test_list_functions_deterministic(): + """Test that dataset generates same items with same seed""" + config = ListFunctionsDatasetConfig(seed=42, size=10) + dataset1 = ListFunctionsDataset(config) + dataset2 = ListFunctionsDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_list_functions_items(): + """Test basic properties of generated items""" + config = ListFunctionsDatasetConfig(size=50, seed=42) + dataset = ListFunctionsDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert isinstance(item["question"], str) + assert isinstance(item["answer"], str) + + +def test_list_functions_iteration(): + """Test that iteration respects dataset size""" + config = ListFunctionsDatasetConfig(size=5, seed=42) # Small size for testing + dataset = ListFunctionsDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test list conversion + items = list(dataset) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test multiple iterations + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_list_functions_generators(): + """Test generator loading and access""" + config = ListFunctionsDatasetConfig() + dataset = ListFunctionsDataset(config) + + # Test lazy loading + assert dataset._generators is None + _ = dataset.generators # Access to trigger loading + assert dataset._generators is not None + + # Test generator mapping + assert isinstance(dataset.generators, dict) + assert len(dataset.generators) > 0 + i = 0 + rng = Random(18) + for key in sorted(dataset.generators.keys()): + generator = dataset.generators[key] + assert callable(generator) + + print(i, key) + for _ in range(10): + x = generator(rng) + assert isinstance(x, dict) + + i += 1