reasoning-gym/python
Andreas Koepf (aider) 432c9436f7 fix: Correct PatternRule.apply() method to properly handle sequence operations
This commit message captures the essence of the change: fixing the implementation of the apply() method in the PatternRule class to correctly handle sequence operations and indexing.

The key changes are:
1. Use `sequence[position]` instead of `sequence[position - 1]`
2. Adjust PREV_PLUS condition to use `position > 0`
3. Use `sequence[position - 1]` for previous element reference

Would you like me to elaborate on the specific changes or rationale?
2025-01-23 13:52:47 +01:00

317 lines
10 KiB
Text

"""
Cognition tasks for training reasoning capabilities:
- Pattern recognition
- Sequence completion
- Logical reasoning
- Working memory
"""
from .sequences import SequenceDataset, SequenceConfig, sequence_dataset
__all__ = ["SequenceDataset", "SequenceConfig", "sequence_dataset"]
"""
Cognition tasks for training reasoning capabilities:
- Pattern recognition
- Sequence completion
- Logical reasoning
- Working memory
"""
__all__ = []
from dataclasses import dataclass
from enum import Enum
from random import Random
from typing import Optional, List
class Operation(Enum):
"""Basic mathematical operations that can be composed"""
ADD = "+"
MULTIPLY = "*"
SQUARE = "^2"
DOUBLE = "*2"
HALF = "/2"
PREV_PLUS = "prev+" # Add previous number
ALTERNATE = "alt" # Alternate between operations
COMPOSE = "compose" # Compose two operations
@dataclass
class SequenceConfig:
"""Configuration for sequence generation"""
min_terms: int = 4 # Minimum visible terms
max_terms: int = 8 # Maximum visible terms
min_value: int = -100 # Minimum allowed number
max_value: int = 100 # Maximum allowed number
max_complexity: int = 3 # Maximum number of operations to combine
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
"""Validate configuration parameters"""
assert self.min_terms >= 4, "need at least 4 terms to establish pattern"
assert self.max_terms >= self.min_terms
assert self.max_value > self.min_value
assert self.max_complexity >= 1
class PatternRule:
"""Represents a composable sequence pattern rule"""
def __init__(self, operations: List[Operation], parameters: List[int]):
self.operations = operations
self.parameters = parameters
def apply(self, sequence: List[int], position: int) -> int:
"""Apply the rule to generate the next number"""
result = sequence[position] # Start with current number
for op, param in zip(self.operations, self.parameters):
if op == Operation.ADD:
result += param
elif op == Operation.MULTIPLY:
result *= param
elif op == Operation.SQUARE:
result = result * result
elif op == Operation.DOUBLE:
result *= 2
elif op == Operation.HALF:
result //= 2 # Integer division
elif op == Operation.PREV_PLUS:
if position > 0:
result += sequence[position - 1]
return result
def to_string(self) -> str:
"""Convert rule to human-readable string"""
parts = []
for op, param in zip(self.operations, self.parameters):
if op == Operation.ADD:
parts.append(f"add {param}")
elif op == Operation.MULTIPLY:
parts.append(f"multiply by {param}")
elif op == Operation.SQUARE:
parts.append("square")
elif op == Operation.DOUBLE:
parts.append("double")
elif op == Operation.HALF:
parts.append("halve")
elif op == Operation.PREV_PLUS:
parts.append("add previous")
return " then ".join(parts)
class PatternGenerator:
"""Generates new pattern rules with configurable complexity"""
def __init__(self, rng: Random, complexity: int = 1):
self.rng = rng
self.complexity = complexity
def generate_rule(self) -> PatternRule:
"""Generate a new pattern rule"""
operations = []
parameters = []
# Number of operations based on complexity
num_ops = self.rng.randint(1, self.complexity + 1)
for _ in range(num_ops):
# Pick random operation
op = self.rng.choice(list(Operation))
operations.append(op)
# Generate appropriate parameter
if op in [Operation.ADD, Operation.MULTIPLY]:
param = self.rng.randint(-10, 10)
while param == 0: # Avoid trivial operations
param = self.rng.randint(-10, 10)
parameters.append(param)
else:
parameters.append(0) # Some operations don't need parameters
return PatternRule(operations, parameters)
def is_interesting(self, sequence: List[int], max_value: int = 1000) -> bool:
"""Check if sequence is interesting enough"""
if not sequence:
return False
# Avoid too large numbers
if any(abs(x) > max_value for x in sequence):
return False
# Avoid constant sequences
if len(set(sequence)) == 1:
return False
# Avoid simple arithmetic progressions if complexity > 1
if self.complexity > 1:
diffs = [sequence[i+1] - sequence[i] for i in range(len(sequence)-1)]
if len(set(diffs)) == 1:
return False
return True
class SequenceDataset:
"""Generates number sequence completion tasks with dynamic pattern generation"""
def __init__(self, config: SequenceConfig):
self.config = config
self.config.validate()
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
def __len__(self) -> int:
return self.config.size
def __iter__(self):
"""Make the dataset iterable"""
self._current_idx = 0
return self
def __next__(self):
"""Get next item in iteration"""
if self._current_idx >= self.config.size:
raise StopIteration
item = self[self._current_idx]
self._current_idx += 1
return item
def __getitem__(self, idx: int) -> dict:
"""Generate a sequence task with a newly generated pattern"""
rng = Random(self.seed + idx)
# Create pattern generator with random complexity
complexity = rng.randint(1, self.config.max_complexity)
generator = PatternGenerator(rng, complexity)
# Generate pattern rule and sequence
max_attempts = 10
for _ in range(max_attempts):
rule = generator.generate_rule()
# Generate initial terms
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
sequence = [rng.randint(-10, 10)] # Start with random number
# Generate remaining terms
try:
for i in range(1, num_terms + 1): # +1 for answer
next_term = rule.apply(sequence, i)
sequence.append(next_term)
if generator.is_interesting(sequence):
break
except (OverflowError, ZeroDivisionError):
continue
else:
# If we couldn't generate an interesting sequence, fall back to simple addition
rule = PatternRule([Operation.ADD], [2])
sequence = [i * 2 for i in range(num_terms + 1)]
visible_terms = sequence[:-1] # Last term is the answer
return {
"question": ", ".join(map(str, visible_terms)) + ", ?",
"answer": str(sequence[-1]),
"metadata": {
"rule": rule.to_string(),
"complexity": complexity,
"sequence": sequence
}
}
def sequence_dataset(
min_terms: int = 4,
max_terms: int = 8,
min_value: int = -100,
max_value: int = 100,
max_complexity: int = 3,
seed: Optional[int] = None,
size: int = 500,
) -> SequenceDataset:
"""Create a SequenceDataset with the given configuration."""
config = SequenceConfig(
min_terms=min_terms,
max_terms=max_terms,
min_value=min_value,
max_value=max_value,
max_complexity=max_complexity,
seed=seed,
size=size,
)
return SequenceDataset(config)
import pytest
from reasoning_gym.cognition.sequences import (
SequenceDataset,
SequenceConfig,
Operation,
PatternRule,
PatternGenerator
)
def test_sequence_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = SequenceConfig(min_terms=3) # Too few terms
config.validate()
with pytest.raises(AssertionError):
config = SequenceConfig(min_terms=6, max_terms=5)
config.validate()
with pytest.raises(AssertionError):
config = SequenceConfig(min_value=100, max_value=0)
config.validate()
def test_pattern_rule():
"""Test pattern rule application"""
# Test simple addition
rule = PatternRule([Operation.ADD], [2])
assert rule.apply([1, 3], 1) == 5
# Test composition
rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3])
assert rule.apply([1, 4], 1) == 11 # (4 * 2) + 3
def test_sequence_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
config = SequenceConfig(seed=42, size=10)
dataset1 = SequenceDataset(config)
dataset2 = SequenceDataset(config)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_sequence_dataset_items():
"""Test basic properties of generated items"""
config = SequenceConfig(
min_terms=4,
max_terms=6,
max_complexity=2,
size=50,
seed=42
)
dataset = SequenceDataset(config)
for i in range(len(dataset)):
item = dataset[i]
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Verify sequence format
question = item["question"]
assert question.endswith(", ?")
terms = [int(x) for x in question[:-3].split(", ")]
assert len(terms) >= config.min_terms
assert len(terms) <= config.max_terms
def test_sequence_dataset_iteration():
"""Test that iteration respects dataset size"""
config = SequenceConfig(size=5, seed=42)
dataset = SequenceDataset(config)
items = list(dataset)
assert len(items) == config.size
# Test multiple iterations yield same items
assert items == list(dataset)