feat: Add number filtering task to algorithmic package

This commit is contained in:
Andreas Koepf (aider) 2025-01-23 20:39:50 +01:00
parent b6161eb20e
commit fc7eb1d5bc
2 changed files with 136 additions and 0 deletions

View file

@ -7,12 +7,16 @@ Algorithmic tasks for training reasoning capabilities:
"""
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
from .word_reversal import WordReversalConfig, WordReversalDataset, word_reversal_dataset
__all__ = [
"LetterCountingConfig",
"LetterCountingDataset",
"letter_counting_dataset",
"NumberFilteringConfig",
"NumberFilteringDataset",
"number_filtering_dataset",
"WordReversalConfig",
"WordReversalDataset",
"word_reversal_dataset"

View file

@ -0,0 +1,132 @@
"""Number filtering task generator"""
from dataclasses import dataclass
import random
from random import Random
from typing import List, Optional, Tuple
@dataclass
class NumberFilteringConfig:
    """Settings that control how number filtering tasks are generated."""

    # Bounds (inclusive) on how many numbers each generated list holds.
    min_numbers: int = 3
    max_numbers: int = 10
    # Bounds (inclusive) on the decimal places used when formatting a number.
    min_decimals: int = 0
    max_decimals: int = 4
    # Range from which the raw number values are drawn.
    min_value: float = -100.0
    max_value: float = 100.0
    # Base seed for the dataset; None leaves the choice to the dataset.
    seed: Optional[int] = None
    # Number of items the (virtual) dataset reports and iterates over.
    size: int = 500

    def validate(self) -> None:
        """Check that every configured range is well-formed.

        Raises:
            AssertionError: on the first violated constraint, with a message
                naming the offending field.
        """
        assert 0 < self.min_numbers, "min_numbers must be positive"
        assert self.min_numbers <= self.max_numbers, "max_numbers must be >= min_numbers"
        assert 0 <= self.min_decimals, "min_decimals must be non-negative"
        assert self.min_decimals <= self.max_decimals, "max_decimals must be >= min_decimals"
        assert self.min_value < self.max_value, "max_value must be > min_value"
class NumberFilteringDataset:
    """Generates number filtering tasks.

    Items are produced on demand: item ``idx`` is derived from a private RNG
    seeded with ``seed + idx``, so indexing the same dataset twice yields the
    same task.
    """

    def __init__(self, config: "NumberFilteringConfig"):
        """Validate *config* and fix the base seed.

        Args:
            config: Generation settings; ``config.validate()`` is called and
                may raise ``AssertionError``.
        """
        self.config = config
        self.config.validate()
        # Without an explicit seed, draw one so the instance is still
        # internally reproducible.
        self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
        # Initialised here so next() also works before iter() has been called.
        self._current_idx = 0

    def __len__(self) -> int:
        """Return the configured (virtual) dataset size."""
        return self.config.size

    def __iter__(self):
        """Restart iteration from the first item and return self."""
        self._current_idx = 0
        return self

    def __next__(self):
        """Return the next item, or raise StopIteration after ``size`` items."""
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def _format_number(self, num: float, decimals: int) -> str:
        """Format *num* in fixed-point notation with *decimals* places."""
        return f"{num:.{decimals}f}"

    def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
        """Generate a list of numbers and their string representations.

        Args:
            rng: Random source; all draws for the list come from it.

        Returns:
            ``(numbers, str_numbers)`` of equal length, where
            ``numbers[i] == float(str_numbers[i])``.
        """
        count = rng.randint(self.config.min_numbers, self.config.max_numbers)
        numbers = []
        str_numbers = []
        for _ in range(count):
            num = rng.uniform(self.config.min_value, self.config.max_value)
            decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
            str_num = self._format_number(num, decimals)
            # Parse the formatted text back so the numeric value matches
            # exactly what is shown in the question.
            numbers.append(float(str_num))
            str_numbers.append(str_num)
        return numbers, str_numbers

    def __getitem__(self, idx: int) -> dict:
        """Generate a single number filtering task.

        Args:
            idx: Item index; combined with the base seed to seed the RNG.

        Returns:
            A dict with ``"question"``, ``"answer"`` (string form of the
            list of surviving number strings) and ``"metadata"``.
        """
        rng = Random(self.seed + idx)

        numbers, str_numbers = self._generate_numbers(rng)

        # Threshold lies between the smallest and largest generated number,
        # then is rounded to the precision shown in the question.
        filter_value = rng.uniform(min(numbers), max(numbers))
        decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
        filter_str = self._format_number(filter_value, decimals)
        filter_value = float(filter_str)

        keep_larger = rng.choice([True, False])
        larger_smaller = "larger" if keep_larger else "smaller"
        keep_remove = "keep" if rng.choice([True, False]) else "remove"

        def survives(n: float) -> bool:
            """Tell whether *n* remains in the list after the operation."""
            if keep_remove == "keep":
                return n > filter_value if keep_larger else n < filter_value
            # "remove" drops the strictly larger/smaller values, so values
            # equal to the threshold stay.
            return n <= filter_value if keep_larger else n >= filter_value

        # Filter values and strings in lockstep. Looking strings up by value
        # would map equal values with different formatting (e.g. "5" and
        # "5.00") onto the first occurrence only.
        result_strs = [s for n, s in zip(numbers, str_numbers) if survives(n)]

        return {
            "question": (f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
                         f"in this list: {str_numbers}"),
            "answer": str(result_strs),
            "metadata": {
                "original_numbers": str_numbers,
                "filter_value": filter_str,
                "operation": f"{keep_remove}_{larger_smaller}",
                "result": result_strs,
            },
        }
def number_filtering_dataset(
    min_numbers: int = 3,
    max_numbers: int = 10,
    min_decimals: int = 0,
    max_decimals: int = 4,
    min_value: float = -100.0,
    max_value: float = 100.0,
    seed: Optional[int] = None,
    size: int = 500,
) -> NumberFilteringDataset:
    """Build a NumberFilteringDataset from individual settings.

    Each argument is forwarded unchanged to the NumberFilteringConfig field
    of the same name, and the resulting config is handed to the dataset.
    """
    settings = {
        "min_numbers": min_numbers,
        "max_numbers": max_numbers,
        "min_decimals": min_decimals,
        "max_decimals": max_decimals,
        "min_value": min_value,
        "max_value": max_value,
        "seed": seed,
        "size": size,
    }
    return NumberFilteringDataset(NumberFilteringConfig(**settings))