From 8772041afbeb84fbe32e99ec8bd9f61fca17cd19 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Mon, 10 Feb 2025 18:58:07 +0100 Subject: [PATCH 01/11] Add attributes for curriculum Co-authored-by: EduardDurech <39579228+EduardDurech@users.noreply.github.com> --- reasoning_gym/arithmetic/chain_sum.py | 32 ++++++ reasoning_gym/coaching/__init__.py | 14 +++ reasoning_gym/coaching/attributes.py | 73 ++++++++++++ reasoning_gym/coaching/base_curriculum.py | 108 ++++++++++++++++++ .../{coaching.py => coaching/coach.py} | 2 +- tests/test_chain_sum.py | 18 +++ 6 files changed, 246 insertions(+), 1 deletion(-) create mode 100644 reasoning_gym/coaching/__init__.py create mode 100644 reasoning_gym/coaching/attributes.py create mode 100644 reasoning_gym/coaching/base_curriculum.py rename reasoning_gym/{coaching.py => coaching/coach.py} (99%) diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 30dcb0c4..01d387a6 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -2,6 +2,7 @@ import random from dataclasses import dataclass from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -112,5 +113,36 @@ class ChainSum(ProceduralDataset): return expression, result +class ChainSumCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(ChainSumCurriculum.__name__, ChainSumConfig) + + # Define attributes + self._define_attributes( + ( + RangeAttributeDefinition( + name="num_terms", + levels=[2, 3, 4, 5], + default_level=0, # Start with 2 terms + description="Maximum number of terms in the expression", + attr_type=AttributeType.APPEND, + min_value=2, # Ensure at least 2 terms + lower_field_name="min_terms", + upper_field_name="max_terms", + ), + RangeAttributeDefinition( + name="num_digits", + levels=[1, 2, 4, 10], + default_level=0, # Start with 1-digit numbers + description="Number of digits in each operand", + attr_type=AttributeType.APPEND, + min_value=1, # Ensure numbers are at least 1 digit + lower_field_name="min_digits", + upper_field_name="max_digits", + ), + ) + ) + + # Register the dataset register_dataset("chain_sum", ChainSum, ChainSumConfig) diff --git a/reasoning_gym/coaching/__init__.py b/reasoning_gym/coaching/__init__.py new file mode 100644 index 00000000..50683d66 --- /dev/null +++ b/reasoning_gym/coaching/__init__.py @@ -0,0 +1,14 @@ +from .attributes import AttributeDefinition, AttributeType, RangeAttributeDefinition +from .base_curriculum import BaseCurriculum +from .coach import Coach, GroupedScores, ScoreBoard, ScoreStats + +__all__ = [ + "AttributeType", + "AttributeDefinition", + "RangeAttributeDefinition", + "BaseCurriculum", + "Coach", + "ScoreBoard", + "GroupedScores", + "ScoreStats", +] diff --git a/reasoning_gym/coaching/attributes.py b/reasoning_gym/coaching/attributes.py new file mode 100644 index 00000000..33fc053b --- /dev/null +++ b/reasoning_gym/coaching/attributes.py @@ -0,0 +1,73 @@ +from collections import abc +from dataclasses import dataclass +from enum import StrEnum +from typing import Any, Optional + + +class AttributeType(StrEnum): + """Defines how attribute levels should be interpreted""" + + STATIC = "static" # Each level is independent + UBOUND = "ubound" # Each level is an upper bound + APPEND = "append" # Each level includes all previous levels + + +@dataclass(kw_only=True) +class AttributeDefinition: + name: str + levels: list + default_level: int + description: Optional[str] = None + attr_type: AttributeType = AttributeType.STATIC # Default to static + min_value: Optional[int | float] = None # Minimum value for numeric attributes + + def validate_level(self, level: int, curriculum: str) -> None: + """ + Validate that a level is valid for an attribute. + Args: + level: Level to validate + curriculum: Name of the curriculum + Raises: + ValueError: If level is invalid + """ + # TODO: if > set as [-1], if <0 set as [0] + if not 0 <= level < len(self.levels): + raise ValueError( + f"Invalid level: {level} for attribute '{curriculum}.{self.name}'. " + f"Must be between 0 and {len(self.levels)-1}" + ) + + def get_level_value(self, level: int, curriculum: str) -> Any: + """ + Get the value for an attribute at a specific level based on its type. + Args: + attr: The attribute definition + level: Level to get value for + Returns: + Value for the attribute based on its level and type + """ + if self.attr_type == AttributeType.STATIC: + return self.levels[level] + elif self.attr_type == AttributeType.UBOUND: + return self.levels[level] + elif self.attr_type == AttributeType.APPEND: + return self.levels[: level + 1] + + raise ValueError(f"Unknown attribute type: {self.attr_type} for attribute '{curriculum}.{self.name}'") + + +@dataclass(kw_only=True) +class ScalarAttributeDefinition(AttributeDefinition): + field_name: str + + +@dataclass(kw_only=True) +class RangeAttributeDefinition(AttributeDefinition): + lower_field_name: str + upper_field_name: str + + def get_level_value(self, level: int, curriculum: str) -> Any: + v = super().get_level_value(level, curriculum) + if not isinstance(v, abc.Iterable): + return [v] + return v diff --git a/reasoning_gym/coaching/base_curriculum.py b/reasoning_gym/coaching/base_curriculum.py new file mode 100644 index 00000000..8d619869 --- /dev/null +++ b/reasoning_gym/coaching/base_curriculum.py @@ -0,0 +1,108 @@ +from typing import Any, Iterable, Optional + +from ..factory import ConfigT +from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition + + +class BaseCurriculum: + def __init__(self, name: str, config_cls: ConfigT): + self.name = name + self._config_cls = config_cls + self._attributes: dict[str, AttributeDefinition] = {} + self._current_levels: dict[str, int] = {} + + def generate_configuration(self, defaults: Optional[dict[str, any]] = None) -> ConfigT: + config_args = defaults.copy() if defaults is not None else {} + for attr in self._attributes.values(): + if isinstance(attr, RangeAttributeDefinition): + vals = self.get_attr_value(attr.name) + config_args[attr.lower_field_name] = min(vals) + config_args[attr.upper_field_name] = max(vals) + elif isinstance(attr, ScalarAttributeDefinition): + val = self.get_attr_value(attr.name) + config_args[attr.field_name] = val + print(config_args) + return self._config_cls(**config_args) + + @property + def attributes(self) -> dict[str, AttributeDefinition]: + """Get the curriculum's attributes""" + return self._attributes + + def get_attribute(self, attr_name: str) -> AttributeDefinition: + if attr_name not in self._attributes: + raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist") + return self._attributes[attr_name] + + def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None: + for attr in attrs: + if attr.name in self.attributes: + raise RuntimeError(f"Attribute with name {attr.name} is already defined.") + self.attributes[attr.name] = attr + + def get_attr_level(self, attr_name: str) -> int: + """ + Get the current level for an attribute. + Args: + attr_name: Name of the attribute + Returns: + Current level index for the attribute + """ + attr = self.get_attribute(attr_name) + return self._current_levels.get(attr_name, attr.default_level) + + def get_attr_value(self, attr_name: str) -> Any: + """ + Get the current value for an attribute based on its level. + Args: + attr_name: Name of the attribute + Returns: + Current value for the attribute based on its level and type + """ + attr = self.get_attribute(attr_name) + level = self.get_attr_level(attr_name) + return attr.get_level_value(level, curriculum=self.name) + + def set_attr_level(self, attr_name: str, level: int) -> None: + """ + Set the level for an attribute. + Args: + attr_name: Name of the attribute + level: New level index + """ + attr = self.get_attribute(attr_name) + attr.validate_level(level, curriculum=self.name) + self._current_levels[attr_name] = level + + def increment_attr_level(self, attr_name: str) -> bool: + """ + Increment the level of an attribute if possible. + Args: + attr_name: Name of the attribute to increment + Returns: + bool: True if level was incremented, False if already at max level + Raises: + KeyError: If attribute doesn't exist + """ + attr = self.get_attribute(attr_name) + current_level = self.get_attr_level(attr_name) + if current_level < len(attr.levels) - 1: + self.set_attr_level(attr_name, current_level + 1) + return True + return False + + def decrement_attr_level(self, attr_name: str) -> bool: + """ + Decrement the level of an attribute if possible. + Args: + attr_name: Name of the attribute to decrement + Returns: + bool: True if level was decremented, False if already at min level + Raises: + KeyError: If attribute doesn't exist + """ + current_level = self.get_attr_level(attr_name) + if current_level > 0: + self.set_attr_level(attr_name, current_level - 1) + return True + return False diff --git a/reasoning_gym/coaching.py b/reasoning_gym/coaching/coach.py similarity index 99% rename from reasoning_gym/coaching.py rename to reasoning_gym/coaching/coach.py index ad14077a..eeeab5d1 100644 --- a/reasoning_gym/coaching.py +++ b/reasoning_gym/coaching/coach.py @@ -8,7 +8,7 @@ from pathlib import Path from statistics import mean, stdev from typing import Any, Dict, List, Optional, Tuple, Union -from .dataset import ProceduralDataset +from ..dataset import ProceduralDataset @dataclass diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index c1ddf641..ed90429f 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -1,6 +1,7 @@ import pytest from reasoning_gym.arithmetic import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum def test_chain_sum_config_validation(): @@ -127,3 +128,20 @@ def test_chain_sum_iteration(): first_items = list(dataset) second_items = list(dataset) assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_chain_sum_curriculum(): + c = ChainSumCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: ChainSumConfig = c.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1 + assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2 + + c.increment_attr_level("num_terms") + c.increment_attr_level("num_digits") + + config2 = c.generate_configuration() From 074f46780d8174535d6eff756bf01bdabbc72e1e Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Mon, 10 Feb 2025 22:09:18 +0100 Subject: [PATCH 02/11] add chain_sum curriculum unit test --- tests/test_chain_sum.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index ed90429f..36b0185c 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -131,17 +131,27 @@ def test_chain_sum_iteration(): def test_chain_sum_curriculum(): - c = ChainSumCurriculum() + curriculum = ChainSumCurriculum() base_value = {"size": 150, "seed": 1} - base_cfg: ChainSumConfig = c.generate_configuration(base_value) + base_cfg: ChainSumConfig = curriculum.generate_configuration(base_value) assert base_cfg.seed == 1 assert base_cfg.size == 150 assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1 assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2 - c.increment_attr_level("num_terms") - c.increment_attr_level("num_digits") + # test incrementing attribute levels for num_terms & num_digits attributes + curriculum.increment_attr_level("num_terms") + curriculum.increment_attr_level("num_digits") - config2 = c.generate_configuration() + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2 + assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3 + + # test decrementing attribute level for num_digits again + curriculum.decrement_attr_level("num_digits") + + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1 + assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3 From 852ddfcea3647d9ee350f1de2d3ddc06f0911132 Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Tue, 11 Feb 2025 12:53:13 +0100 Subject: [PATCH 03/11] add dice dataset --- reasoning_gym/arithmetic/__init__.py | 3 + reasoning_gym/arithmetic/dice.py | 154 +++++++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 reasoning_gym/arithmetic/dice.py diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 05d321da..94de4880 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -14,6 +14,7 @@ from .leg_counting import LegCountingConfig, LegCountingDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset +from .dice import DiceConfig, DiceDataset __all__ = [ "BasicArithmeticDataset", @@ -38,4 +39,6 @@ __all__ = [ "TimeIntervalsDataset", "CountBitsConfig", "CountBitsDataset", + "DiceConfig", + "DiceDataset", ] diff --git a/reasoning_gym/arithmetic/dice.py b/reasoning_gym/arithmetic/dice.py new file mode 100644 index 00000000..a3203dd9 --- /dev/null +++ b/reasoning_gym/arithmetic/dice.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from random import Random +from typing import Dict, Optional +from math import gcd +from functools import reduce + +from ..factory import ProceduralDataset, register_dataset + + +def compute_probability(dice, target): + """ + Computes the probability of rolling a total of at least `target` + when rolling dice specified in the list `dice`. Each element in dice + is the number of sides on that die. The computation is done via dynamic programming. + Returns the probability as a fraction (numerator, denominator) and as a float. + """ + # dp[i][s] = number of ways to get sum s using the first i dice. + # We use only one dictionary for the current dp state. + dp = {0: 1} + for sides in dice: + new_dp = {} + for current_sum, count in dp.items(): + # Each die gives a number from 1 to sides. + for face in range(1, sides + 1): + new_sum = current_sum + face + new_dp[new_sum] = new_dp.get(new_sum, 0) + count + dp = new_dp + + total_outcomes = reduce(lambda a, b: a * b, dice, 1) + ways = sum(count for s, count in dp.items() if s >= target) + + # Simplify the fraction (ways / total_outcomes) + def simplify(n, d): + common = gcd(n, d) + return n // common, d // common + + frac = simplify(ways, total_outcomes) + return frac, ways / total_outcomes + +def generate_puzzle(num_dice, max_dice_size, rng): + """ + Generates a puzzle: + - It forces one die to have max_dice_size. + - The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1. + - The dice are then shuffled. + - The target total is chosen roughly in the middle (but you can adjust the method). + + It then computes the probability of rolling a total at least the target. + Finally, it prints out the puzzle statement and the answer. + """ + + # Guarantee one die is the maximum. + dice = [max_dice_size] + for _ in range(num_dice - 1): + # Choose a die size randomly from 2 up to max_dice_size-1. + # (If max_dice_size == 2 then all dice are 2-sided.) + if max_dice_size > 2: + die = rng.randint(2, max_dice_size - 1) + else: + die = 2 + dice.append(die) + + # Optionally, sort dice in descending order (as is common in puzzles) + dice.sort(reverse=True) + + # Compute minimum and maximum possible totals. + min_total = num_dice # each die gives at least 1 + max_total = sum(dice) + + # Choose a target total. For an interesting puzzle, + # we choose a target somewhere in the middle third of the range. + low_target = min_total + (max_total - min_total) // 3 + high_target = min_total + 2 * (max_total - min_total) // 3 + target = rng.randint(low_target, high_target) + + # Compute probability. + (num, den), prob = compute_probability(dice, target) + + # Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc. + dice_str = ", ".join(f"1d{s}" for s in dice) + + # Return the puzzle. + return { + 'dice_str': dice_str, + 'target': target, + 'num': num, + 'den': den + } + + +@dataclass +class DiceConfig: + """Configuration for dice puzzle generation""" + + num_dice: int = 4 + max_dice_size: int = 20 + seed: Optional[int] = None + size: int = 500 + + def validate(self): + """Validate configuration parameters""" + assert self.num_dice >= 1, "num_dice must be gte 1" + assert self.max_dice_size >= 2, "max_dice_size must be gte 2" + + +class DiceDataset(ProceduralDataset): + """Generates Dice-based puzzles with configurable parameters""" + + def __init__(self, config: DiceConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single Dice task + + Returns: + dict with keys: + - question: str, the task description + - answer: str, a solution string + - metadata: dict with generation parameters + """ + rng = Random(self.seed + idx) + puzzle = generate_puzzle(self.config.num_dice, self.config.max_dice_size, rng) + puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? Please respond with a reduced fraction representing the probability [ex., 1/60]." + answer_str = f"{puzzle['num']}/{puzzle['den']}" + + return { + "question": puzzle_str, + "answer": answer_str, + "metadata": { + }, + } + + def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: + """Determine if the solution provided solves the Dice task. + + The function awards 1.0 for a correct answer. + + Args: + answer (Optional[str]): The user's answer. + entry (Dict[str, any]): The original dataset entry containing the correct answer. + + Returns: + float: The computed score between 0.0 and 1.0. + """ + + if answer == None: + return 0.0 + if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""): + return 0.01 + else: + return 1.0 # Yay + + +register_dataset("dice", DiceDataset, DiceConfig) From 945207da43bfa9ebc473301eea4c7f9fb70af758 Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Tue, 11 Feb 2025 12:54:23 +0100 Subject: [PATCH 04/11] fmt --- reasoning_gym/arithmetic/__init__.py | 2 +- reasoning_gym/arithmetic/dice.py | 31 ++++++++++++---------------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 94de4880..cc94bf56 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSum, ChainSumConfig from .count_bits import CountBitsConfig, CountBitsDataset +from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig @@ -14,7 +15,6 @@ from .leg_counting import LegCountingConfig, LegCountingDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset -from .dice import DiceConfig, DiceDataset __all__ = [ "BasicArithmeticDataset", diff --git a/reasoning_gym/arithmetic/dice.py b/reasoning_gym/arithmetic/dice.py index a3203dd9..6fd464b9 100644 --- a/reasoning_gym/arithmetic/dice.py +++ b/reasoning_gym/arithmetic/dice.py @@ -1,8 +1,8 @@ from dataclasses import dataclass +from functools import reduce +from math import gcd from random import Random from typing import Dict, Optional -from math import gcd -from functools import reduce from ..factory import ProceduralDataset, register_dataset @@ -37,6 +37,7 @@ def compute_probability(dice, target): frac = simplify(ways, total_outcomes) return frac, ways / total_outcomes + def generate_puzzle(num_dice, max_dice_size, rng): """ Generates a puzzle: @@ -44,11 +45,11 @@ def generate_puzzle(num_dice, max_dice_size, rng): - The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1. - The dice are then shuffled. - The target total is chosen roughly in the middle (but you can adjust the method). - + It then computes the probability of rolling a total at least the target. Finally, it prints out the puzzle statement and the answer. """ - + # Guarantee one die is the maximum. dice = [max_dice_size] for _ in range(num_dice - 1): @@ -59,33 +60,28 @@ def generate_puzzle(num_dice, max_dice_size, rng): else: die = 2 dice.append(die) - + # Optionally, sort dice in descending order (as is common in puzzles) dice.sort(reverse=True) - + # Compute minimum and maximum possible totals. min_total = num_dice # each die gives at least 1 max_total = sum(dice) - + # Choose a target total. For an interesting puzzle, # we choose a target somewhere in the middle third of the range. low_target = min_total + (max_total - min_total) // 3 high_target = min_total + 2 * (max_total - min_total) // 3 target = rng.randint(low_target, high_target) - + # Compute probability. (num, den), prob = compute_probability(dice, target) - + # Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc. dice_str = ", ".join(f"1d{s}" for s in dice) - + # Return the puzzle. - return { - 'dice_str': dice_str, - 'target': target, - 'num': num, - 'den': den - } + return {"dice_str": dice_str, "target": target, "num": num, "den": den} @dataclass @@ -126,8 +122,7 @@ class DiceDataset(ProceduralDataset): return { "question": puzzle_str, "answer": answer_str, - "metadata": { - }, + "metadata": {}, } def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: From f7cf015e3b6dd5b7a8022b9e10893b78dcbe1ce2 Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Tue, 11 Feb 2025 12:59:16 +0100 Subject: [PATCH 05/11] commit test --- tests/test_dice.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/test_dice.py diff --git a/tests/test_dice.py b/tests/test_dice.py new file mode 100644 index 00000000..a351fdb6 --- /dev/null +++ b/tests/test_dice.py @@ -0,0 +1,35 @@ +import pytest + +from reasoning_gym.arithmetic.dice import DiceConfig, DiceDataset + + +def test_dice(): + """Test basic properties and solution of generated items""" + config = DiceConfig(seed=42, size=50, num_dice=8, max_dice_size=24) + dataset = DiceDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + + # Easy + config = DiceConfig(seed=42, size=1, num_dice=1, max_dice_size=2) + dataset = DiceDataset(config) + + for item in dataset: + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + + # Hard + config = DiceConfig(seed=42, size=1, num_dice=40, max_dice_size=40) + dataset = DiceDataset(config) + + for item in dataset: + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 \ No newline at end of file From b208dc664e784fde5701e9c8c58b5a7bc8c0de70 Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Tue, 11 Feb 2025 13:00:12 +0100 Subject: [PATCH 06/11] lint again --- tests/test_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dice.py b/tests/test_dice.py index a351fdb6..8a3bd991 100644 --- a/tests/test_dice.py +++ b/tests/test_dice.py @@ -32,4 +32,4 @@ def test_dice(): for item in dataset: assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer=None, entry=item) == 0.0 \ No newline at end of file + assert dataset.score_answer(answer=None, entry=item) == 0.0 From 21b845ebefc1d283a3656a777d36fab7c1a5b716 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Tue, 11 Feb 2025 13:54:54 +0100 Subject: [PATCH 07/11] simplify rotate method --- reasoning_gym/algorithmic/rotate_matrix.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/reasoning_gym/algorithmic/rotate_matrix.py b/reasoning_gym/algorithmic/rotate_matrix.py index 4fdf651e..adeaa47c 100644 --- a/reasoning_gym/algorithmic/rotate_matrix.py +++ b/reasoning_gym/algorithmic/rotate_matrix.py @@ -60,22 +60,16 @@ class RotateMatrixDataset(ProceduralDataset): matrix = [numbers[i * n : (i + 1) * n] for i in range(n)] return matrix + def _rot90(self, matrix: list[list[int]]) -> list[list[int]]: + """quarter clockwise rotation""" + return [list(row) for row in zip(*matrix[::-1])] + def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]: """Rotate the matrix K times by 90 degrees clockwise""" num_rotations %= 4 - n = len(matrix) output = deepcopy(matrix) - for _ in range(num_rotations): - for l in range(n // 2): - for i in range(l, n - 1 - l): - (output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = ( - output[n - 1 - i][l], - output[l][i], - output[i][n - 1 - l], - output[n - 1 - l][n - 1 - i], - ) - + output = self._rot90(output) return output def _matrix_to_str(self, matrix: list[list[int]]) -> str: From c2fb8bb6cc441559f28260855096380306b6f8ef Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Tue, 11 Feb 2025 13:56:27 +0100 Subject: [PATCH 08/11] add rectangle count dataset --- reasoning_gym/cognition/__init__.py | 3 + reasoning_gym/cognition/rectangle_count.py | 135 +++++++++++++++++++++ tests/test_rectangle_count.py | 19 +++ 3 files changed, 157 insertions(+) create mode 100644 reasoning_gym/cognition/rectangle_count.py create mode 100644 tests/test_rectangle_count.py diff --git a/reasoning_gym/cognition/__init__.py b/reasoning_gym/cognition/__init__.py index 5c0b7f8b..473fee97 100644 --- a/reasoning_gym/cognition/__init__.py +++ b/reasoning_gym/cognition/__init__.py @@ -5,6 +5,7 @@ Cognition tasks for training reasoning capabilities. from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset from .figlet_fonts import FigletFontConfig, FigletFontDataset from .number_sequences import NumberSequenceConfig, NumberSequenceDataset +from .rectangle_count import RectangleCountConfig, RectangleCountDataset from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset __all__ = [ @@ -16,4 +17,6 @@ __all__ = [ "NumberSequenceDataset", "RubiksCubeConfig", "RubiksCubeDataset", + "RectangleCountConfig", + "RectangleCountDataset", ] diff --git a/reasoning_gym/cognition/rectangle_count.py b/reasoning_gym/cognition/rectangle_count.py new file mode 100644 index 00000000..959dc25f --- /dev/null +++ b/reasoning_gym/cognition/rectangle_count.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass +from random import Random +from typing import Dict, Optional + +from ..factory import ProceduralDataset, register_dataset + + +def draw_rectangles_with_overlap(n, width, height, rng): + # Create a grid that holds a count of how many times a cell is drawn. + grid = [[0 for _ in range(width)] for _ in range(height)] + rectangles = [] + + max_attempts = 100000 # Prevent infinite loops in case of a crowded grid + attempts = 0 + + while len(rectangles) < n and attempts < max_attempts: + attempts += 1 + # Ensure minimum width and height of 3. + # For a rectangle to be at least 3 cells wide, right must be at least left + 2. + # Similarly, bottom must be at least top + 2. + left = rng.randint(0, width - 3) + right = rng.randint(left + 2, width - 1) + top = rng.randint(0, height - 3) + bottom = rng.randint(top + 2, height - 1) + + # Prepare a list of all the cells that would be updated. + cells_to_update = [] + + # Top edge: + for col in range(left, right + 1): + cells_to_update.append((top, col)) + # Bottom edge: + for col in range(left, right + 1): + cells_to_update.append((bottom, col)) + # Left edge (excluding corners already drawn): + for row in range(top + 1, bottom): + cells_to_update.append((row, left)) + # Right edge (excluding corners already drawn): + for row in range(top + 1, bottom): + cells_to_update.append((row, right)) + + # Check if drawing this rectangle would cause any cell to exceed a count of 2. + conflict = False + for r, c in cells_to_update: + if grid[r][c] >= 2: + conflict = True + break + if conflict: + continue # Skip this rectangle candidate + + # No conflict: update the grid counts. + for r, c in cells_to_update: + grid[r][c] += 1 + + # Save the rectangle (stored as (left, right, top, bottom)). + rectangles.append((left, right, top, bottom)) + + if len(rectangles) < n: + print(f"Only placed {len(rectangles)} rectangles after {attempts} attempts.") + + # Print the grid. + # Use ' ' for an untouched cell, '#' for a single hit, and '█' for exactly two hits. + lines = "" + for row in grid: + line = "".join(" " if count == 0 else ("#" if count == 1 else "█") for count in row) + lines = lines + line + "\n" + return lines, len(rectangles) + + +@dataclass +class RectangleCountConfig: + """Configuration for RectangleCount puzzle generation""" + + max_rectangles: int = 10 + width: int = 80 + height: int = 80 + seed: Optional[int] = None + size: int = 500 + + def validate(self): + """Validate configuration parameters""" + assert self.width >= 10, "width must be gte 10" + assert self.height >= 10, "height must be gte 10" + + +class RectangleCountDataset(ProceduralDataset): + """Generates [RectangleCount Puzzles](https://en.wikipedia.org/wiki/RectangleCount_Puzzle) with configurable parameters""" + + def __init__(self, config: RectangleCountConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single RectangleCount task + + Returns: + dict with keys: + - question: str, the task description + - answer: str, a solution string + - metadata: dict with generation parameters + """ + rng = Random(self.seed + idx) + + target = rng.randint(1, self.config.max_rectangles) + puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng) + + puzz = f"How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. \n\n {puzzle}" + + return { + "question": puzz, + "answer": str(answer), + "metadata": {}, + } + + def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: + """Determine if the solution provided solves the RectangleCount task. + + The function awards 1.0 for a correct answer. + + Args: + answer (Optional[str]): The user's answer. + entry (Dict[str, any]): The original dataset entry containing the correct answer. + + Returns: + float: The computed score between 0.0 and 1.0. + """ + + if answer == None: + return 0.0 + if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""): + return 0.01 + else: + return 1.0 # Yay + + +register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig) diff --git a/tests/test_rectangle_count.py b/tests/test_rectangle_count.py new file mode 100644 index 00000000..fcbb09f0 --- /dev/null +++ b/tests/test_rectangle_count.py @@ -0,0 +1,19 @@ +import pytest + +from reasoning_gym.cognition.rectangle_count import RectangleCountConfig, RectangleCountDataset + + +def test_dice(): + """Test basic properties and solution of generated items""" + config = RectangleCountConfig(seed=42, size=50, max_rectangles=15, width=40, height=40) + dataset = RectangleCountDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 From 1dc0a29eae810bfec9dacda9f003709a92a30e1d Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Tue, 11 Feb 2025 14:44:38 +0100 Subject: [PATCH 09/11] count primes --- reasoning_gym/algorithmic/__init__.py | 3 + reasoning_gym/algorithmic/count_primes.py | 60 +++++++++++++++ tests/test_count_primes.py | 89 +++++++++++++++++++++++ 3 files changed, 152 insertions(+) create mode 100644 reasoning_gym/algorithmic/count_primes.py create mode 100644 tests/test_count_primes.py diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index ad76199f..dd3c6758 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -25,6 +25,7 @@ from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset +from .count_primes import CountPrimesConfig, CountPrimesDataset __all__ = [ "SpellBackwardConfig", @@ -66,4 +67,6 @@ __all__ = [ "ManipulateMatrixDataset", "BinaryMatrixConfig", "BinaryMatrixDataset", + "CountPrimesConfig", + "CountPrimesDataset", ] diff --git a/reasoning_gym/algorithmic/count_primes.py b/reasoning_gym/algorithmic/count_primes.py new file mode 100644 index 00000000..3335371c --- /dev/null +++ b/reasoning_gym/algorithmic/count_primes.py @@ -0,0 +1,60 @@ +"""Count prime numbers in a given interval. + +Solution obtained with Sieve of Eratosthenes: +https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes +""" + +import math +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?""" + +@dataclass +class CountPrimesConfig: + """Configuration for Count Primes dataset generation""" + + max_n: int = 10_000 # Upper bound for the interval + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 1 <= self.max_n, "max_n must be at least 1" + +class CountPrimesDataset(ProceduralDataset): + """Generates Count Primes exercises with configurable difficulty""" + + def __init__(self, config: CountPrimesConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.primes = self._get_primes(config.max_n + 1) + + def _get_primes(self, n: int) -> list[bool]: + if n <= 1: + return [] + primes = [True] * n + primes[0] = primes[1] = False + for i in range(2, int(math.sqrt(n))+1): + if primes[i]: + for j in range(2*i, n, i): + primes[j] = False + return primes + + def __getitem__(self, idx: int) -> dict: + """Generate a single Count Primes question""" + rng = Random(self.seed + idx) + start = rng.randint(1, self.config.max_n) + end = rng.randint(start, self.config.max_n) + primes = self.primes[start:end+1] + answer = sum(primes) + return { + "question": QUESTION_TEMPLATE.format(start=start, end=end), + "answer": str(answer), + "metadata": {"start": start, "end": end, "primes": primes, "solution": answer}, + } + +register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig) diff --git a/tests/test_count_primes.py b/tests/test_count_primes.py new file mode 100644 index 00000000..d88981e9 --- /dev/null +++ b/tests/test_count_primes.py @@ -0,0 +1,89 @@ +"""Tests for Count Primes questions generation""" + +import pytest + +from reasoning_gym.algorithmic.count_primes import CountPrimesConfig, CountPrimesDataset + + +def test_count_primes_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = CountPrimesConfig(max_n=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = CountPrimesConfig(max_n=0) # Zero not allowed + config.validate() + + + +def test_count_primes_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = CountPrimesConfig(seed=42, size=10) + dataset1 = CountPrimesDataset(config) + dataset2 = CountPrimesDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_count_primes_dataset_items(): + """Test basic properties of generated items""" + config = CountPrimesConfig(max_n=10, size=10, seed=42) + dataset = CountPrimesDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "start" in item["metadata"] + assert "end" in item["metadata"] + assert "primes" in item["metadata"] + assert "solution" in item["metadata"] + + start = item["metadata"]["start"] + end = item["metadata"]["end"] + primes = item["metadata"]["primes"] + + assert start <= end + assert len(primes) <= end - start + 1 + + +def test_count_primes_dataset_iteration(): + """Test that iteration respects dataset size""" + config = CountPrimesConfig(size=5, seed=42) + dataset = CountPrimesDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_count_primes_answer(): + """Test the _get_primes method""" + config = CountPrimesConfig(seed=42) + dataset = CountPrimesDataset(config) + + # Base cases + assert dataset._get_primes(n=0) == [] + assert dataset._get_primes(n=1) == [] + assert dataset._get_primes(n=2) == [False, False] + + # Test primes up to 10 + primes = dataset._get_primes(n=11) + assert primes[2] == True + assert primes[3] == True + assert primes[4] == False + assert primes[5] == True + assert primes[6] == False + assert primes[7] == True + assert primes[8] == False + assert primes[9] == False + assert primes[10] == False \ No newline at end of file From 5a8ce7d2af38ddf21372c3e9fd840af12724fc7e Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Tue, 11 Feb 2025 14:44:46 +0100 Subject: [PATCH 10/11] lint --- reasoning_gym/algorithmic/__init__.py | 2 +- reasoning_gym/algorithmic/count_primes.py | 11 +++++++---- tests/test_count_primes.py | 3 +-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index dd3c6758..bf17ebe8 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -9,6 +9,7 @@ Algorithmic tasks for training reasoning capabilities: from .base_conversion import BaseConversionConfig, BaseConversionDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset +from .count_primes import CountPrimesConfig, CountPrimesDataset from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset from .letter_counting import LetterCountingConfig, LetterCountingDataset @@ -25,7 +26,6 @@ from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset -from .count_primes import CountPrimesConfig, CountPrimesDataset __all__ = [ "SpellBackwardConfig", diff --git a/reasoning_gym/algorithmic/count_primes.py b/reasoning_gym/algorithmic/count_primes.py index 3335371c..0a553c7f 100644 --- a/reasoning_gym/algorithmic/count_primes.py +++ b/reasoning_gym/algorithmic/count_primes.py @@ -13,6 +13,7 @@ from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?""" + @dataclass class CountPrimesConfig: """Configuration for Count Primes dataset generation""" @@ -26,21 +27,22 @@ class CountPrimesConfig: """Validate configuration parameters""" assert 1 <= self.max_n, "max_n must be at least 1" + class CountPrimesDataset(ProceduralDataset): """Generates Count Primes exercises with configurable difficulty""" def __init__(self, config: CountPrimesConfig): super().__init__(config=config, seed=config.seed, size=config.size) self.primes = self._get_primes(config.max_n + 1) - + def _get_primes(self, n: int) -> list[bool]: if n <= 1: return [] primes = [True] * n primes[0] = primes[1] = False - for i in range(2, int(math.sqrt(n))+1): + for i in range(2, int(math.sqrt(n)) + 1): if primes[i]: - for j in range(2*i, n, i): + for j in range(2 * i, n, i): primes[j] = False return primes @@ -49,7 +51,7 @@ class CountPrimesDataset(ProceduralDataset): rng = Random(self.seed + idx) start = rng.randint(1, self.config.max_n) end = rng.randint(start, self.config.max_n) - primes = self.primes[start:end+1] + primes = self.primes[start : end + 1] answer = sum(primes) return { "question": QUESTION_TEMPLATE.format(start=start, end=end), @@ -57,4 +59,5 @@ class CountPrimesDataset(ProceduralDataset): "metadata": {"start": start, "end": end, "primes": primes, "solution": answer}, } + register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig) diff --git a/tests/test_count_primes.py b/tests/test_count_primes.py index d88981e9..f131647b 100644 --- a/tests/test_count_primes.py +++ b/tests/test_count_primes.py @@ -16,7 +16,6 @@ def test_count_primes_config_validation(): config.validate() - def test_count_primes_dataset_deterministic(): """Test that dataset generates same items with same seed""" config = CountPrimesConfig(seed=42, size=10) @@ -86,4 +85,4 @@ def test_count_primes_answer(): assert primes[7] == True assert primes[8] == False assert primes[9] == False - assert primes[10] == False \ No newline at end of file + assert primes[10] == False From 0a4799d99a649c9c6a07395acce3964ef44b0413 Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Tue, 11 Feb 2025 16:22:53 +0100 Subject: [PATCH 11/11] clarity --- reasoning_gym/arithmetic/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reasoning_gym/arithmetic/dice.py b/reasoning_gym/arithmetic/dice.py index 6fd464b9..f4ad97e9 100644 --- a/reasoning_gym/arithmetic/dice.py +++ b/reasoning_gym/arithmetic/dice.py @@ -116,7 +116,7 @@ class DiceDataset(ProceduralDataset): """ rng = Random(self.seed + idx) puzzle = generate_puzzle(self.config.num_dice, self.config.max_dice_size, rng) - puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? Please respond with a reduced fraction representing the probability [ex., 1/60]." + puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? (Assume that all dice are rolled at once, and that '1d6' represents one roll of a 6-sided dice.) Please respond with a reduced fraction representing the probability [ex., 1/60]." answer_str = f"{puzzle['num']}/{puzzle['den']}" return {