"""Number sorting task generator.

Builds question/answer items that ask for a list of randomly generated
numbers to be sorted in ascending or descending order.

This listing gathers one change set that spans three files; each part is
introduced by a banner comment naming the file it belongs to:

* ``reasoning_gym/algorithmic/__init__.py``  -- package exports (comment only)
* ``reasoning_gym/algorithmic/number_sorting.py``  -- the task generator
* ``tests/test_number_sorting.py``  -- its test suite
"""
import ast
import random
from dataclasses import dataclass
from random import Random
from typing import List, Optional, Tuple

import pytest

# ---------------------------------------------------------------------------
# reasoning_gym/algorithmic/__init__.py  (modified)
# ---------------------------------------------------------------------------
# The package re-exports the public names of this task, next to the existing
# base_conversion / letter_counting / number_filtering / word_reversal tasks:
#
#     from .number_sorting import NumberSortingConfig, NumberSortingDataset, number_sorting_dataset
#
# and adds "NumberSortingConfig", "NumberSortingDataset" and
# "number_sorting_dataset" to the package-level ``__all__``.

# ---------------------------------------------------------------------------
# reasoning_gym/algorithmic/number_sorting.py  (new file)
# ---------------------------------------------------------------------------

__all__ = [
    "NumberSortingConfig",
    "NumberSortingDataset",
    "number_sorting_dataset",
]


@dataclass
class NumberSortingConfig:
    """Configuration for number sorting task generation"""

    min_numbers: int = 3  # Minimum numbers to sort
    max_numbers: int = 10  # Maximum numbers to sort
    min_decimals: int = 0  # Minimum decimal places
    max_decimals: int = 2  # Maximum decimal places
    min_value: float = -100.0  # Minimum value
    max_value: float = 100.0  # Maximum value
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: If ``min_numbers`` is not positive, if a ``min_*``
                bound exceeds its ``max_*`` counterpart, if ``min_decimals``
                is negative, or if ``min_value`` is not strictly below
                ``max_value``.
        """
        assert self.min_numbers > 0, "min_numbers must be positive"
        assert self.min_numbers <= self.max_numbers, "max_numbers must be >= min_numbers"
        assert self.min_decimals >= 0, "min_decimals must be non-negative"
        assert self.min_decimals <= self.max_decimals, "max_decimals must be >= min_decimals"
        assert self.min_value < self.max_value, "max_value must be > min_value"


class NumberSortingDataset:
    """Generates number sorting tasks.

    Items are produced on demand: item ``idx`` is derived from a private
    ``Random(seed + idx)``, so the same seed and index always give the same
    item.
    """

    def __init__(self, config: NumberSortingConfig):
        """Validate ``config`` and fix the base seed.

        Args:
            config: Generation parameters; ``config.validate()`` is called
                here, so an inconsistent config raises ``AssertionError``.
        """
        self.config = config
        self.config.validate()
        # Without an explicit seed, draw one once so that repeated access to
        # the same index of this instance stays consistent.
        self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
        # Created up front so next(dataset) also works before iter(dataset).
        self._current_idx = 0

    def __len__(self) -> int:
        """Return the configured (virtual) dataset size."""
        return self.config.size

    def __iter__(self):
        """Restart iteration from the first item and return ``self``."""
        self._current_idx = 0
        return self

    def __next__(self) -> dict:
        """Return the next item, raising ``StopIteration`` after ``size`` items."""
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def _format_number(self, num: float, decimals: int) -> str:
        """Format ``num`` with exactly ``decimals`` decimal places.

        Values that round to zero from below are rendered without a sign
        ("0.00" rather than "-0.00"), so a task never shows a negative zero.

        Args:
            num: Value to render.
            decimals: Number of digits after the decimal point.

        Returns:
            The fixed-point string, e.g. ``"3.14"`` for ``(3.14159, 2)``.
        """
        # Round through the string form so the value matches what is printed.
        rounded = float(f"{num:.{decimals}f}")
        if rounded == 0:
            # -0.0 compares equal to 0 but would print as "-0"; drop the sign.
            rounded = 0.0
        return f"{rounded:.{decimals}f}"

    def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
        """Generate a list of numbers and their string representations.

        Args:
            rng: Random source; consumed as one ``randint`` for the count, one
                ``randint`` for the decimal places, then one ``uniform`` per
                number.

        Returns:
            ``(numbers, number_strs)`` where ``numbers[i] == float(number_strs[i])``
            and every string uses the same number of decimal places.
        """
        count = rng.randint(self.config.min_numbers, self.config.max_numbers)
        decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)

        number_strs = [
            self._format_number(rng.uniform(self.config.min_value, self.config.max_value), decimals)
            for _ in range(count)
        ]
        # Reparse so each numeric value is exactly the one shown in the text.
        numbers = [float(text) for text in number_strs]
        return numbers, number_strs

    def __getitem__(self, idx: int) -> dict:
        """Generate a single sorting task.

        Args:
            idx: Item index; combined with the base seed to seed the item's
                private random generator.

        Returns:
            A dict with ``"question"`` (the prompt), ``"answer"`` (``str()`` of
            the sorted list of number strings) and ``"metadata"`` holding
            ``"original_numbers"``, ``"direction"`` and ``"sorted_numbers"``.
        """
        rng = Random(self.seed + idx)

        numbers, number_strs = self._generate_numbers(rng)

        # Randomly choose ascending or descending
        is_ascending = rng.choice([True, False])
        direction = "ascending" if is_ascending else "descending"

        # Sort each value together with its own text, so the answer reuses the
        # exact strings from the question instead of re-deriving the format.
        ordered = sorted(
            zip(numbers, number_strs),
            key=lambda pair: pair[0],
            reverse=not is_ascending,
        )
        answer = [text for _, text in ordered]

        return {
            "question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
            "answer": str(answer),
            "metadata": {
                "original_numbers": number_strs,
                "direction": direction,
                "sorted_numbers": answer,
            },
        }


def number_sorting_dataset(
    min_numbers: int = 3,
    max_numbers: int = 10,
    min_decimals: int = 0,
    max_decimals: int = 2,
    min_value: float = -100.0,
    max_value: float = 100.0,
    seed: Optional[int] = None,
    size: int = 500,
) -> NumberSortingDataset:
    """Create a NumberSortingDataset with the given configuration.

    Every argument is passed to the ``NumberSortingConfig`` field of the same
    name; an inconsistent combination raises ``AssertionError`` when the
    dataset validates the config.
    """
    config = NumberSortingConfig(
        min_numbers=min_numbers,
        max_numbers=max_numbers,
        min_decimals=min_decimals,
        max_decimals=max_decimals,
        min_value=min_value,
        max_value=max_value,
        seed=seed,
        size=size,
    )
    return NumberSortingDataset(config)


# ---------------------------------------------------------------------------
# tests/test_number_sorting.py  (new file)
# ---------------------------------------------------------------------------
# As a separate file these tests start with:
#
#     from reasoning_gym.algorithmic.number_sorting import (
#         NumberSortingConfig,
#         NumberSortingDataset,
#     )
#
# In this listing the two classes are already defined above.


@pytest.mark.parametrize(
    "overrides",
    [
        {"min_numbers": 0},
        {"min_numbers": 10, "max_numbers": 5},
        {"min_decimals": -1},
        {"min_value": 100, "max_value": 0},
    ],
)
def test_number_sorting_config_validation(overrides):
    """Test that invalid configs raise appropriate errors"""
    with pytest.raises(AssertionError):
        NumberSortingConfig(**overrides).validate()


def test_number_sorting_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = NumberSortingConfig(seed=42, size=10)
    dataset1 = NumberSortingDataset(config)
    dataset2 = NumberSortingDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_number_sorting_dataset_items():
    """Test basic properties of generated items"""
    config = NumberSortingConfig(
        min_numbers=3,
        max_numbers=6,
        min_decimals=1,
        max_decimals=3,
        min_value=-10.0,
        max_value=10.0,
        size=10,
        seed=42,
    )
    dataset = NumberSortingDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        # Check item structure
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata
        assert "original_numbers" in item["metadata"]
        assert "direction" in item["metadata"]
        assert "sorted_numbers" in item["metadata"]

        # Verify number count constraints
        numbers = item["metadata"]["original_numbers"]
        assert config.min_numbers <= len(numbers) <= config.max_numbers

        for num in numbers:
            # Verify decimal places
            decimal_places = len(num.split(".")[-1]) if "." in num else 0
            assert config.min_decimals <= decimal_places <= config.max_decimals
            # Verify value range
            assert config.min_value <= float(num) <= config.max_value

        # Verify sorting. literal_eval only accepts Python literals, so a
        # malformed answer cannot execute code the way eval() would.
        direction = item["metadata"]["direction"]
        answer = ast.literal_eval(item["answer"])
        assert answer == item["metadata"]["sorted_numbers"]
        sorted_numbers = [float(x) for x in answer]
        assert sorted_numbers == sorted(sorted_numbers, reverse=(direction == "descending"))

        # The answer must be a reordering of the numbers in the question.
        assert sorted(sorted_numbers) == sorted(float(x) for x in numbers)


def test_number_sorting_no_negative_zero():
    """Test that values rounding to zero are never shown with a minus sign"""
    dataset = NumberSortingDataset(NumberSortingConfig(seed=42))

    assert dataset._format_number(-0.001, 2) == "0.00"
    assert dataset._format_number(-0.4, 0) == "0"
    assert dataset._format_number(0.0, 1) == "0.0"
    # Genuinely negative values keep their sign
    assert dataset._format_number(-1.26, 1) == "-1.3"


def test_number_sorting_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = NumberSortingConfig(size=5, seed=42)
    dataset = NumberSortingDataset(config)

    items = list(dataset)
    assert len(items) == config.size

    # Test multiple iterations yield same items
    assert items == list(dataset)