diff --git a/reasoning_gym/cognition/sequences.py b/reasoning_gym/cognition/sequences.py index 8dfb4649..c1d6d4f0 100644 --- a/reasoning_gym/cognition/sequences.py +++ b/reasoning_gym/cognition/sequences.py @@ -40,9 +40,10 @@ class SequenceConfig: class PatternRule: """Represents a composable sequence pattern rule""" - def __init__(self, operations: List[Operation], parameters: List[int]): + def __init__(self, operations: List[Operation], parameters: List[int], subrules: List['PatternRule'] = None): self.operations = operations self.parameters = parameters + self.subrules = subrules or [] def apply(self, sequence: List[int], position: int) -> int: """Apply the rule to generate the next number""" @@ -62,9 +63,20 @@ class PatternRule: elif op == Operation.PREV_PLUS: if position > 0: result += sequence[position - 1] + elif op == Operation.COMPOSE: + # Apply each subrule in sequence + temp_sequence = sequence[:position + 1] + temp_sequence[-1] = result # Use current result as input + for subrule in self.subrules: + result = subrule.apply(temp_sequence, position) return result + @classmethod + def compose(cls, rules: List['PatternRule']) -> 'PatternRule': + """Create a new rule that composes multiple rules together""" + return cls([Operation.COMPOSE], [0], subrules=rules) + def to_string(self) -> str: """Convert rule to human-readable string""" parts = [] diff --git a/tests/test_sequences.py b/tests/test_sequences.py index 1d6c3a97..6d383284 100644 --- a/tests/test_sequences.py +++ b/tests/test_sequences.py @@ -23,11 +23,17 @@ def test_pattern_rule(): # Test simple addition rule = PatternRule([Operation.ADD], [2]) assert rule.apply([1, 3], 1) == 5 - + # Test composition rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3]) assert rule.apply([1, 4], 1) == 11 # (4 * 2) + 3 + # Test rule composition + rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number + rule2 = PatternRule([Operation.ADD], [3]) # Add 3 + composed = PatternRule.compose([rule1, rule2]) + assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3 + def test_sequence_dataset_deterministic(): """Test that dataset generates same items with same seed"""