"""
Cognition tasks for training reasoning capabilities:
- Pattern recognition
- Sequence completion
- Logical reasoning
- Working memory
"""

from .sequences import SequenceDataset, SequenceConfig, sequence_dataset

__all__ = ["SequenceDataset", "SequenceConfig", "sequence_dataset"]

# NOTE(review): everything below looks like further modules concatenated into
# this file (a second package __init__, the sequences implementation and its
# test suite) -- confirm the intended file boundaries before relying on the
# rebinding of ``__all__`` and of the imported names.
"""
Cognition tasks for training reasoning capabilities:
- Pattern recognition
- Sequence completion
- Logical reasoning
- Working memory
"""

__all__ = []

from dataclasses import dataclass
from enum import Enum
from random import Random
from typing import Optional, List


class Operation(Enum):
    """Basic mathematical operations that can be composed"""

    ADD = "+"
    MULTIPLY = "*"
    SQUARE = "^2"
    DOUBLE = "*2"
    HALF = "/2"
    PREV_PLUS = "prev+"  # Add previous number
    # NOTE(review): ALTERNATE and COMPOSE are not handled by
    # PatternRule.apply / PatternRule.to_string; a rule containing them treats
    # them as no-ops and omits them from the description.
    ALTERNATE = "alt"  # Alternate between operations
    COMPOSE = "compose"  # Compose two operations


@dataclass
class SequenceConfig:
    """Configuration for sequence generation"""

    min_terms: int = 4  # Minimum visible terms
    max_terms: int = 8  # Maximum visible terms
    # NOTE(review): min_value / max_value are validated below but are not
    # consulted anywhere in the generation code visible in this file.
    min_value: int = -100  # Minimum allowed number
    max_value: int = 100  # Maximum allowed number
    max_complexity: int = 3  # Maximum number of operations to combine
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self):
        """Validate configuration parameters

        Raises:
            AssertionError: if any parameter is out of range. Assertions are
                kept (rather than ValueError) because callers and the tests in
                this file expect AssertionError.
        """
        assert self.min_terms >= 4, "need at least 4 terms to establish pattern"
        assert self.max_terms >= self.min_terms
        assert self.max_value > self.min_value
        assert self.max_complexity >= 1


class PatternRule:
    """Represents a composable sequence pattern rule"""

    def __init__(self, operations: List[Operation], parameters: List[int]):
        """Store the operations and their parameters.

        ``operations`` and ``parameters`` are paired positionally; operations
        that take no parameter still occupy a slot in ``parameters``.
        """
        self.operations = operations
        self.parameters = parameters

    def apply(self, sequence: List[int], position: int) -> int:
        """Apply the rule to generate the next number

        Args:
            sequence: The terms generated so far.
            position: Index into ``sequence`` of the *current* term, i.e. the
                term the next one is derived from.

        Returns:
            The term that follows ``sequence[position]`` under this rule.

        Raises:
            IndexError: if ``position`` is not a valid index of ``sequence``.
        """
        result = sequence[position]  # Start with current number

        for op, param in zip(self.operations, self.parameters):
            if op == Operation.ADD:
                result += param
            elif op == Operation.MULTIPLY:
                result *= param
            elif op == Operation.SQUARE:
                result = result * result
            elif op == Operation.DOUBLE:
                result *= 2
            elif op == Operation.HALF:
                result //= 2  # Integer division
            elif op == Operation.PREV_PLUS:
                # The first term has no predecessor, so nothing is added.
                if position > 0:
                    result += sequence[position - 1]

        return result

    def to_string(self) -> str:
        """Convert rule to human-readable string"""
        parts = []
        for op, param in zip(self.operations, self.parameters):
            if op == Operation.ADD:
                parts.append(f"add {param}")
            elif op == Operation.MULTIPLY:
                parts.append(f"multiply by {param}")
            elif op == Operation.SQUARE:
                parts.append("square")
            elif op == Operation.DOUBLE:
                parts.append("double")
            elif op == Operation.HALF:
                parts.append("halve")
            elif op == Operation.PREV_PLUS:
                parts.append("add previous")
        return " then ".join(parts)


class PatternGenerator:
    """Generates new pattern rules with configurable complexity"""

    def __init__(self, rng: Random, complexity: int = 1):
        """Keep the random source and the maximum number of operations per rule."""
        self.rng = rng
        self.complexity = complexity

    def generate_rule(self) -> PatternRule:
        """Generate a new pattern rule

        Returns:
            A rule made of between 1 and ``self.complexity`` randomly chosen
            operations.
        """
        operations = []
        parameters = []

        # Number of operations based on complexity. ``randint`` is inclusive
        # on both ends, so the upper bound is ``complexity`` itself; using
        # ``complexity + 1`` would exceed the configured maximum number of
        # operations.
        num_ops = self.rng.randint(1, self.complexity)

        for _ in range(num_ops):
            # Pick random operation
            op = self.rng.choice(list(Operation))
            operations.append(op)

            # Generate appropriate parameter
            if op in [Operation.ADD, Operation.MULTIPLY]:
                param = self.rng.randint(-10, 10)
                while param == 0:  # Avoid trivial operations
                    param = self.rng.randint(-10, 10)
                parameters.append(param)
            else:
                parameters.append(0)  # Some operations don't need parameters

        return PatternRule(operations, parameters)

    def is_interesting(self, sequence: List[int], max_value: int = 1000) -> bool:
        """Check if sequence is interesting enough

        Args:
            sequence: Candidate sequence, including the answer term.
            max_value: Largest absolute value any term may have.

        Returns:
            False for empty sequences, sequences containing a term whose
            magnitude exceeds ``max_value``, constant sequences and (when
            ``self.complexity > 1``) plain arithmetic progressions; True
            otherwise.
        """
        if not sequence:
            return False

        # Avoid too large numbers
        if any(abs(x) > max_value for x in sequence):
            return False

        # Avoid constant sequences
        if len(set(sequence)) == 1:
            return False

        # Avoid simple arithmetic progressions if complexity > 1
        if self.complexity > 1:
            diffs = {nxt - cur for cur, nxt in zip(sequence, sequence[1:])}
            if len(diffs) == 1:
                return False

        return True


class SequenceDataset:
    """Generates number sequence completion tasks with dynamic pattern generation"""

    def __init__(self, config: SequenceConfig):
        """Validate ``config`` and fix the base seed.

        When ``config.seed`` is None a random base seed is drawn once, so items
        are stable for the lifetime of this instance.
        """
        self.config = config
        self.config.validate()
        self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)

    def __len__(self) -> int:
        """Return the virtual dataset size."""
        return self.config.size

    def __iter__(self):
        """Make the dataset iterable"""
        # The cursor lives on the instance, so each call restarts iteration.
        self._current_idx = 0
        return self

    def __next__(self):
        """Get next item in iteration"""
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def __getitem__(self, idx: int) -> dict:
        """Generate a sequence task with a newly generated pattern

        Args:
            idx: Item index; combined with the dataset seed so that the same
                index always yields the same item.

        Returns:
            A dict with ``"question"`` (visible terms followed by ``?``),
            ``"answer"`` (the next term as a string) and ``"metadata"``
            (rule description, complexity and the full sequence).
        """
        rng = Random(self.seed + idx)

        # Create pattern generator with random complexity
        complexity = rng.randint(1, self.config.max_complexity)
        generator = PatternGenerator(rng, complexity)

        # Generate pattern rule and sequence
        max_attempts = 10
        for _ in range(max_attempts):
            rule = generator.generate_rule()

            # Generate initial terms
            num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
            sequence = [rng.randint(-10, 10)]  # Start with random number

            # Generate remaining terms
            try:
                for i in range(1, num_terms + 1):  # +1 for answer
                    # ``apply`` derives the next term from sequence[position],
                    # so pass the index of the latest existing term (i - 1).
                    # Passing ``i`` would index one past the end of the list.
                    next_term = rule.apply(sequence, i - 1)
                    sequence.append(next_term)

                if generator.is_interesting(sequence):
                    break
            except (OverflowError, ZeroDivisionError):
                continue
        else:
            # If we couldn't generate an interesting sequence, fall back to simple addition
            rule = PatternRule([Operation.ADD], [2])
            sequence = [i * 2 for i in range(num_terms + 1)]

        visible_terms = sequence[:-1]  # Last term is the answer

        return {
            "question": ", ".join(map(str, visible_terms)) + ", ?",
            "answer": str(sequence[-1]),
            "metadata": {
                "rule": rule.to_string(),
                "complexity": complexity,
                "sequence": sequence
            }
        }


def sequence_dataset(
    min_terms: int = 4,
    max_terms: int = 8,
    min_value: int = -100,
    max_value: int = 100,
    max_complexity: int = 3,
    seed: Optional[int] = None,
    size: int = 500,
) -> SequenceDataset:
    """Create a SequenceDataset with the given configuration."""
    config = SequenceConfig(
        min_terms=min_terms,
        max_terms=max_terms,
        min_value=min_value,
        max_value=max_value,
        max_complexity=max_complexity,
        seed=seed,
        size=size,
    )
    return SequenceDataset(config)


import pytest
from reasoning_gym.cognition.sequences import (
    SequenceDataset, SequenceConfig, Operation, PatternRule, PatternGenerator
)


def test_sequence_config_validation():
    """Test that invalid configs raise appropriate errors"""
    with pytest.raises(AssertionError):
        config = SequenceConfig(min_terms=3)  # Too few terms
        config.validate()

    with pytest.raises(AssertionError):
        config = SequenceConfig(min_terms=6, max_terms=5)
        config.validate()

    with pytest.raises(AssertionError):
        config = SequenceConfig(min_value=100, max_value=0)
        config.validate()


def test_pattern_rule():
    """Test pattern rule application"""
    # Test simple addition
    rule = PatternRule([Operation.ADD], [2])
    assert rule.apply([1, 3], 1) == 5

    # Test composition
    rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3])
    assert rule.apply([1, 4], 1) == 11  # (4 * 2) + 3


def test_sequence_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = SequenceConfig(seed=42, size=10)
    dataset1 = SequenceDataset(config)
    dataset2 = SequenceDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_sequence_dataset_items():
    """Test basic properties of generated items"""
    config = SequenceConfig(
        min_terms=4,
        max_terms=6,
        max_complexity=2,
        size=50,
        seed=42
    )
    dataset = SequenceDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Verify sequence format
        question = item["question"]
        assert question.endswith(", ?")
        terms = [int(x) for x in question[:-3].split(", ")]
        assert len(terms) >= config.min_terms
        assert len(terms) <= config.max_terms


def test_sequence_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = SequenceConfig(size=5, seed=42)
    dataset = SequenceDataset(config)

    items = list(dataset)
    assert len(items) == config.size

    # Test multiple iterations yield same items
    assert items == list(dataset)