diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index ad76199f..550b67f6 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -6,6 +6,7 @@ Algorithmic tasks for training reasoning capabilities: - Pattern matching """ +from .ab import ABConfig, ABDataset from .base_conversion import BaseConversionConfig, BaseConversionDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset @@ -66,4 +67,6 @@ __all__ = [ "ManipulateMatrixDataset", "BinaryMatrixConfig", "BinaryMatrixDataset", + "ABConfig", + "ABDataset", ] diff --git a/reasoning_gym/algorithmic/ab.py b/reasoning_gym/algorithmic/ab.py new file mode 100644 index 00000000..c0070fb0 --- /dev/null +++ b/reasoning_gym/algorithmic/ab.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from random import Random +from typing import Dict, Optional + +from ..factory import ProceduralDataset, register_dataset + + +def generate_program(length, rng): + """Generates a random initial program of a given length.""" + elements = ["A#", "B#", "#A", "#B"] + return [rng.choice(elements) for _ in range(length)] + + +def compute_steps(program, max_steps=100): + """Computes the transformation steps and detects if the program does not halt.""" + steps = [program.copy()] + seen_states = {tuple(program)} + + for step in range(max_steps): + current = steps[-1] + new_program = None + + for i in range(len(current) - 1): + a, b = current[i], current[i + 1] + if a == "A#" and b == "#A": + new_program = current[:i] + current[i + 2 :] + elif a == "A#" and b == "#B": + new_program = current[:i] + ["#B", "A#"] + current[i + 2 :] + elif a == "B#" and b == "#A": + new_program = current[:i] + ["#A", "B#"] + current[i + 2 :] + elif a == "B#" and b == "#B": + new_program = current[:i] + current[i + 2 :] + + if new_program is not None: + break + + if new_program is None: + # No more transformations possible + return steps, False + + if tuple(new_program) in seen_states: + # Detected a loop, meaning non-halting behavior + return steps, True + + steps.append(new_program) + seen_states.add(tuple(new_program)) + + return steps, True # Reached max steps, assume non-halting + + +@dataclass +class ABConfig: + """Configuration for A::B task generation""" + + seed: Optional[int] = None + size: int = 500 + length: int = 10 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.length > 0, "difficulty must be greater than 0" + + +class ABDataset(ProceduralDataset): + """Generates A::B tasks, as described by @VictorTaelin [here](https://x.com/VictorTaelin/status/1776096481704804789)""" + + def __init__(self, config: ABConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single AB task + + Returns: + dict with keys: + - question: str, the task description with AB program + - answer: str, the result of this AB program ABI execution + - metadata: dict with generation parameters + """ + rng = Random(self.seed + idx) + + while True: + initial_program = generate_program(self.config.length, rng) + steps, non_halting = compute_steps(initial_program) + if not non_halting: + break + + # Via: + # https://x.com/VictorTaelin/status/1776248021858111542 + # https://gist.github.com/VictorTaelin/e514844f4df9e5f182b28e5a07e44b17 + prompt = f""" +A::B is a system with 4 tokens: `A#`, `#A`, `B#` and `#B`. + +An A::B program is a sequence of tokens. Example: + + B# A# #B #A B# + +To *compute* a program, we must rewrite neighbor tokens, using the rules: + + A# #A ... becomes ... nothing + A# #B ... becomes ... #B A# + B# #A ... becomes ... #A B# + B# #B ... becomes ... nothing + +In other words, whenever two neighbor tokens have their '#' facing each-other, +they must be rewritten according to the corresponding rule. For example, the +first example shown here is computed as: + + B# A# #B #A B# = + B# #B A# #A B# = + A# #A B# = + B# + +The steps were: +1. We replaced `A# #B` by `#B A#`. +2. We replaced `B# #B` by nothing. +3. We replaced `A# #A` by nothing. +The final result was just `B#`. + +Now, consider the following program: + +{' '.join(initial_program)} + +Return the final state of the program. +""" + + return { + "question": prompt, + "answer": " ".join(steps[-1]), + "metadata": {}, + } + + def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: + """Determine if the solution provided solves the AB task. + + The function awards 1.0 for a correct answer. + + Args: + answer (Optional[str]): The user's answer. + entry (Dict[str, any]): The original dataset entry containing the correct answer. + + Returns: + float: The computed score between 0.0 and 1.0. + """ + + if answer == None: + return 0.0 + if answer != entry["answer"]: + return 0.01 + else: + return 1.0 # Yay + + +# Register the dataset +register_dataset("ab", ABDataset, ABConfig) diff --git a/tests/test_ab.py b/tests/test_ab.py new file mode 100644 index 00000000..caa67207 --- /dev/null +++ b/tests/test_ab.py @@ -0,0 +1,31 @@ +import pytest + +from reasoning_gym.algorithmic.ab import ABConfig, ABDataset + + +def test_zebra_puzzles(): + """Test basic properties and solution of generated items""" + config = ABConfig(seed=42, size=10, length=5) + dataset = ABDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + + config = ABConfig(seed=42, size=10, length=15) + dataset = ABDataset(config) + for item in dataset: + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + + config = ABConfig(seed=42, size=10, length=25) + dataset = ABDataset(config) + for item in dataset: + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0