mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Add number filtering task to algorithmic package
This commit is contained in:
parent
b6161eb20e
commit
fc7eb1d5bc
2 changed files with 136 additions and 0 deletions
|
|
@ -7,12 +7,16 @@ Algorithmic tasks for training reasoning capabilities:
|
|||
"""
|
||||
|
||||
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
|
||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
|
||||
from .word_reversal import WordReversalConfig, WordReversalDataset, word_reversal_dataset
|
||||
|
||||
__all__ = [
|
||||
"LetterCountingConfig",
|
||||
"LetterCountingDataset",
|
||||
"letter_counting_dataset",
|
||||
"NumberFilteringConfig",
|
||||
"NumberFilteringDataset",
|
||||
"number_filtering_dataset",
|
||||
"WordReversalConfig",
|
||||
"WordReversalDataset",
|
||||
"word_reversal_dataset"
|
||||
|
|
|
|||
132
reasoning_gym/algorithmic/number_filtering.py
Normal file
132
reasoning_gym/algorithmic/number_filtering.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""Number filtering task generator"""
|
||||
from dataclasses import dataclass
|
||||
import random
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@dataclass
class NumberFilteringConfig:
    """Settings that control how number filtering tasks are generated.

    The count of numbers per task is drawn between ``min_numbers`` and
    ``max_numbers``; each number is drawn between ``min_value`` and
    ``max_value`` and printed with a decimal-place count between
    ``min_decimals`` and ``max_decimals``.
    """

    # Bounds on how many numbers appear in one task's list.
    min_numbers: int = 3
    max_numbers: int = 10
    # Bounds on the decimal places used when formatting a number.
    min_decimals: int = 0
    max_decimals: int = 4
    # Bounds on the numeric values themselves.
    min_value: float = -100.0
    max_value: float = 100.0
    # Base seed; None lets the dataset pick a random one.
    seed: Optional[int] = None
    # Virtual dataset size (number of items reported by the dataset).
    size: int = 500

    def validate(self):
        """Check that every bound is sane and each range is correctly ordered.

        Raises:
            AssertionError: on the first violated constraint, carrying a
                message that names the offending parameter.
        """
        # Counts: at least one number, and a non-empty range.
        assert 0 < self.min_numbers, "min_numbers must be positive"
        assert self.min_numbers <= self.max_numbers, "max_numbers must be >= min_numbers"
        # Decimal places: non-negative, and a non-empty range.
        assert 0 <= self.min_decimals, "min_decimals must be non-negative"
        assert self.min_decimals <= self.max_decimals, "max_decimals must be >= min_decimals"
        # Values: the range must be strictly increasing.
        assert self.min_value < self.max_value, "max_value must be > min_value"
|
||||
|
||||
|
||||
class NumberFilteringDataset:
    """Generates number filtering tasks.

    Each item presents a list of formatted numbers plus an instruction to keep
    or remove those larger or smaller than a threshold. Item ``idx`` is built
    from ``Random(seed + idx)``, so a given seed and index always produce the
    same task.
    """

    def __init__(self, config: "NumberFilteringConfig"):
        """Store and validate ``config`` and fix the base seed.

        Raises:
            AssertionError: propagated from ``config.validate()``.
        """
        self.config = config
        self.config.validate()
        self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
        # Initialised here so next() on a fresh dataset (before iter()) yields
        # the first item instead of raising AttributeError.
        self._current_idx = 0

    def __len__(self) -> int:
        """Return the configured virtual dataset size."""
        return self.config.size

    def __iter__(self):
        """Reset the iteration cursor and return the dataset as its own iterator."""
        self._current_idx = 0
        return self

    def __next__(self):
        """Return the item at the cursor and advance it.

        Raises:
            StopIteration: once ``config.size`` items have been produced.
        """
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def _format_number(self, num: float, decimals: int) -> str:
        """Format ``num`` in fixed-point notation with ``decimals`` decimal places."""
        return f"{num:.{decimals}f}"

    def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
        """Generate a list of numbers and their string representations.

        Returns:
            ``(numbers, str_numbers)`` of equal length, where ``numbers[i]`` is
            ``float(str_numbers[i])``.
        """
        count = rng.randint(self.config.min_numbers, self.config.max_numbers)
        numbers = []
        str_numbers = []

        for _ in range(count):
            num = rng.uniform(self.config.min_value, self.config.max_value)
            decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
            str_num = self._format_number(num, decimals)
            # Parse the formatted text back so comparisons use exactly the
            # value that is shown in the question.
            numbers.append(float(str_num))
            str_numbers.append(str_num)

        return numbers, str_numbers

    def __getitem__(self, idx: int) -> dict:
        """Generate the number filtering task for index ``idx``.

        Returns:
            A dict with ``"question"``, ``"answer"`` (``str()`` of the list of
            result strings) and ``"metadata"`` (original number strings, filter
            value string, operation name and result strings).
        """
        rng = Random(self.seed + idx)

        # Generate numbers and their string representations
        numbers, str_numbers = self._generate_numbers(rng)

        # Threshold lies between the smallest and largest generated number and
        # is rounded to its printed form before any comparison.
        filter_value = rng.uniform(min(numbers), max(numbers))
        decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
        filter_str = self._format_number(filter_value, decimals)
        filter_value = float(filter_str)

        # Randomly choose the operation (RNG call order is part of the
        # deterministic output, so it must not change).
        keep_larger = rng.choice([True, False])
        larger_smaller = "larger" if keep_larger else "smaller"
        keep_remove = "keep" if rng.choice([True, False]) else "remove"
        keeping = keep_remove == "keep"

        # Filter (value, text) pairs together. Looking the text up afterwards
        # with numbers.index(value) returns the first match, which yields the
        # wrong string when two entries are numerically equal but formatted
        # differently (e.g. "1" and "1.0").
        result_strs = []
        for value, text in zip(numbers, str_numbers):
            matches = value > filter_value if keep_larger else value < filter_value
            # "keep" retains matching entries; "remove" retains the rest
            # (entries equal to the threshold are therefore retained).
            if matches == keeping:
                result_strs.append(text)

        return {
            "question": (
                f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
                f"in this list: {str_numbers}"
            ),
            "answer": str(result_strs),
            "metadata": {
                "original_numbers": str_numbers,
                "filter_value": filter_str,
                "operation": f"{keep_remove}_{larger_smaller}",
                "result": result_strs,
            },
        }
|
||||
|
||||
|
||||
def number_filtering_dataset(
    min_numbers: int = 3,
    max_numbers: int = 10,
    min_decimals: int = 0,
    max_decimals: int = 4,
    min_value: float = -100.0,
    max_value: float = 100.0,
    seed: Optional[int] = None,
    size: int = 500,
) -> NumberFilteringDataset:
    """Build a NumberFilteringDataset from individual settings.

    Every argument is forwarded by name to the NumberFilteringConfig field of
    the same name; the resulting config is then handed to the dataset.
    """
    settings = dict(
        min_numbers=min_numbers,
        max_numbers=max_numbers,
        min_decimals=min_decimals,
        max_decimals=max_decimals,
        min_value=min_value,
        max_value=max_value,
        seed=seed,
        size=size,
    )
    return NumberFilteringDataset(NumberFilteringConfig(**settings))
|
||||
Loading…
Add table
Add a link
Reference in a new issue