reasoning-gym/python

648 lines
22 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Cognition tasks for training reasoning capabilities:
- Pattern recognition
- Sequence completion
- Logical reasoning
- Working memory
"""
from .sequences import SequenceDataset, SequenceConfig, sequence_dataset
# Public names re-exported from .sequences for `from <package> import *`.
__all__ = ["SequenceDataset", "SequenceConfig", "sequence_dataset"]
"""
Cognition tasks for training reasoning capabilities:
- Pattern recognition
- Sequence completion
- Logical reasoning
- Working memory
"""
# NOTE(review): this rebinds __all__ to an empty list, so `from ... import *`
# exports nothing from here; if it shares a module with the earlier __all__
# assignment it would discard those exports — confirm which is intended.
__all__ = []
from dataclasses import dataclass
from enum import Enum
from random import Random
from typing import Optional, List
class Operation(Enum):
    """Basic mathematical operations that can be composed.

    Each member's value is a short symbolic tag; the arithmetic itself lives in
    ``PatternRule.apply``. Member declaration order is the order of
    ``list(Operation)``.
    """
    ADD = "+"  # Add the rule's integer parameter
    MULTIPLY = "*"  # Multiply by the rule's integer parameter
    SQUARE = "^2"  # Square the current value (parameter unused)
    DOUBLE = "*2"  # Multiply by 2 (parameter unused)
    HALF = "/2"  # Halve with floor division (//), so values stay integers
    PREV_PLUS = "prev+" # Add previous number
    # NOTE(review): PatternRule.apply/to_string have no branch for ALTERNATE or
    # COMPOSE, so they currently act as no-op steps — confirm intent.
    ALTERNATE = "alt" # Alternate between operations
    COMPOSE = "compose" # Compose two operations
@dataclass
class SequenceConfig:
    """Configuration for sequence generation.

    Call :meth:`validate` to check the parameters; invalid settings raise
    ``AssertionError``.
    """

    min_terms: int = 4  # Minimum visible terms
    max_terms: int = 8  # Maximum visible terms
    min_value: int = -100  # Minimum allowed number
    max_value: int = 100  # Maximum allowed number
    max_complexity: int = 3  # Maximum number of operations to combine
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if any parameter is out of range. The error is
                raised explicitly instead of via ``assert`` so validation still
                runs when Python is started with ``-O``.
        """
        checks = (
            (self.min_terms >= 4, "need at least 4 terms to establish pattern"),
            (self.max_terms >= self.min_terms, "max_terms must be >= min_terms"),
            (self.max_value > self.min_value, "max_value must be > min_value"),
            (self.max_complexity >= 1, "max_complexity must be >= 1"),
        )
        # Report the first failing check, in declaration order.
        for passed, message in checks:
            if not passed:
                raise AssertionError(message)
class PatternRule:
    """A sequence rule: a chain of operations applied to the current term.

    ``operations`` and ``parameters`` are parallel lists; ``parameters[k]`` is
    the argument for ``operations[k]`` and is ignored by operations that take
    no argument.
    """

    # Transformations that depend only on the running value and the parameter.
    _STEPS = {
        Operation.ADD: lambda value, arg: value + arg,
        Operation.MULTIPLY: lambda value, arg: value * arg,
        Operation.SQUARE: lambda value, _arg: value * value,
        Operation.DOUBLE: lambda value, _arg: value * 2,
        Operation.HALF: lambda value, _arg: value // 2,  # Integer division
    }

    # Human-readable fragments used by to_string().
    _LABELS = {
        Operation.ADD: lambda arg: f"add {arg}",
        Operation.MULTIPLY: lambda arg: f"multiply by {arg}",
        Operation.SQUARE: lambda _arg: "square",
        Operation.DOUBLE: lambda _arg: "double",
        Operation.HALF: lambda _arg: "halve",
        Operation.PREV_PLUS: lambda _arg: "add previous",
    }

    def __init__(self, operations: List[Operation], parameters: List[int]):
        self.operations = operations
        self.parameters = parameters

    def apply(self, sequence: List[int], position: int) -> int:
        """Compute the term that follows ``sequence[position]``.

        The operations run left to right on a running value that starts at
        ``sequence[position]``. PREV_PLUS adds ``sequence[position - 1]`` when
        ``position > 0``. Operations with no entry here (ALTERNATE, COMPOSE)
        leave the value unchanged.
        """
        value = sequence[position]
        for operation, arg in zip(self.operations, self.parameters):
            if operation == Operation.PREV_PLUS:
                # Needs the preceding term, not just the running value.
                if position > 0:
                    value += sequence[position - 1]
                continue
            step = self._STEPS.get(operation)
            if step is not None:
                value = step(value, arg)
        return value

    def to_string(self) -> str:
        """Describe the rule in words, e.g. ``"double then add 3"``."""
        fragments = (
            self._LABELS[operation](arg)
            for operation, arg in zip(self.operations, self.parameters)
            if operation in self._LABELS
        )
        return " then ".join(fragments)
class PatternGenerator:
    """Generates new pattern rules with configurable complexity.

    ``complexity`` is the maximum number of operations combined into one rule,
    matching ``SequenceConfig.max_complexity`` ("Maximum number of operations
    to combine").
    """

    # Operations that PatternRule.apply/to_string actually implement, in enum
    # declaration order. ALTERNATE and COMPOSE have no implementation there,
    # so drawing them would only add silent no-op steps to a rule.
    _SUPPORTED_OPERATIONS = (
        Operation.ADD,
        Operation.MULTIPLY,
        Operation.SQUARE,
        Operation.DOUBLE,
        Operation.HALF,
        Operation.PREV_PLUS,
    )

    def __init__(self, rng: Random, complexity: int = 1):
        self.rng = rng
        self.complexity = complexity

    def generate_rule(self) -> PatternRule:
        """Generate a new rule made of between 1 and ``complexity`` operations."""
        operations = []
        parameters = []
        # max(1, ...) keeps randint's range valid for a complexity below 1.
        num_ops = self.rng.randint(1, max(1, self.complexity))
        for _ in range(num_ops):
            op = self.rng.choice(self._SUPPORTED_OPERATIONS)
            operations.append(op)
            if op in (Operation.ADD, Operation.MULTIPLY):
                # Reject 0 to avoid trivial steps (x + 0, x * 0).
                param = 0
                while param == 0:
                    param = self.rng.randint(-10, 10)
                parameters.append(param)
            else:
                parameters.append(0)  # Parameter unused by these operations
        return PatternRule(operations, parameters)

    def is_interesting(self, sequence: List[int], max_value: int = 1000) -> bool:
        """Check if sequence is interesting enough.

        Rejects empty sequences, any term whose magnitude exceeds
        ``max_value``, and constant sequences. When complexity > 1 it also
        rejects sequences with a single constant difference (plain arithmetic
        progressions).
        """
        if not sequence:
            return False
        if any(abs(x) > max_value for x in sequence):
            return False
        if len(set(sequence)) == 1:
            return False
        if self.complexity > 1:
            diffs = {b - a for a, b in zip(sequence, sequence[1:])}
            if len(diffs) == 1:
                return False
        return True
class SequenceDataset:
    """Generates number sequence completion tasks with dynamic pattern generation.

    Items are produced on demand: item ``idx`` is generated from
    ``Random(seed + idx)``, so a fixed seed reproduces the same items.
    """

    def __init__(self, config: SequenceConfig):
        self.config = config
        self.config.validate()
        # Without an explicit seed, draw one once so items stay stable for the
        # lifetime of this instance.
        self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)

    def __len__(self) -> int:
        return self.config.size

    def __iter__(self):
        """Make the dataset iterable; each call restarts from index 0."""
        self._current_idx = 0
        return self

    def __next__(self):
        """Get next item in iteration, stopping after ``config.size`` items."""
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def _in_bounds(self, value: int) -> bool:
        """Return True if *value* lies within [config.min_value, config.max_value]."""
        return self.config.min_value <= value <= self.config.max_value

    def _extend_sequence(self, rule: PatternRule, sequence: List[int], count: int) -> bool:
        """Append *count* terms produced by *rule* to *sequence* in place.

        Returns False as soon as a term cannot be computed or falls outside
        the configured value range. Checking each term as it is produced
        stops rules such as repeated squaring from building enormous integers
        before they would be rejected anyway.
        """
        for _ in range(count):
            try:
                # apply() treats sequence[position] as the current term, so
                # position must index the last term generated so far.
                next_term = rule.apply(sequence, len(sequence) - 1)
            except (OverflowError, ZeroDivisionError):
                return False
            if not self._in_bounds(next_term):
                return False
            sequence.append(next_term)
        return True

    def __getitem__(self, idx: int) -> dict:
        """Generate a sequence task with a newly generated pattern.

        Returns:
            dict with ``question`` (visible terms followed by ``", ?"``),
            ``answer`` (the hidden next term as a string) and ``metadata``
            (``rule`` description, ``complexity`` and full ``sequence``).
        """
        rng = Random(self.seed + idx)
        # Create pattern generator with random complexity
        complexity = rng.randint(1, self.config.max_complexity)
        generator = PatternGenerator(rng, complexity)
        # is_interesting() has its own symmetric magnitude cap (default 1000);
        # align it with the config so min_value/max_value are authoritative.
        magnitude_cap = max(abs(self.config.min_value), abs(self.config.max_value))

        max_attempts = 10
        for _ in range(max_attempts):
            rule = generator.generate_rule()
            num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
            sequence = [rng.randint(-10, 10)]  # Start with random number
            # Generate num_terms more terms: together with the seed term that
            # gives num_terms visible terms followed by the hidden answer.
            if (
                self._in_bounds(sequence[0])
                and self._extend_sequence(rule, sequence, num_terms)
                and generator.is_interesting(sequence, max_value=magnitude_cap)
            ):
                break
        else:
            # If we couldn't generate an interesting sequence, fall back to simple addition.
            # NOTE: the fallback values span 0..2*num_terms regardless of min_value/max_value.
            rule = PatternRule([Operation.ADD], [2])
            sequence = [i * 2 for i in range(num_terms + 1)]

        visible_terms = sequence[:-1]  # Last term is the answer
        return {
            "question": ", ".join(map(str, visible_terms)) + ", ?",
            "answer": str(sequence[-1]),
            "metadata": {
                "rule": rule.to_string(),
                "complexity": complexity,
                "sequence": sequence
            }
        }
def sequence_dataset(
    min_terms: int = 4,
    max_terms: int = 8,
    min_value: int = -100,
    max_value: int = 100,
    max_complexity: int = 3,
    seed: Optional[int] = None,
    size: int = 500,
) -> SequenceDataset:
    """Build a SequenceDataset from keyword arguments.

    Every argument is passed straight through to SequenceConfig; the dataset
    validates the resulting config when it is constructed.
    """
    settings = {
        "min_terms": min_terms,
        "max_terms": max_terms,
        "min_value": min_value,
        "max_value": max_value,
        "max_complexity": max_complexity,
        "seed": seed,
        "size": size,
    }
    return SequenceDataset(SequenceConfig(**settings))
import pytest
from reasoning_gym.cognition.sequences import (
SequenceDataset,
SequenceConfig,
Operation,
PatternRule,
PatternGenerator
)
def test_sequence_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # The defaults must be accepted.
    SequenceConfig().validate()
    invalid_settings = [
        {"min_terms": 3},  # Too few terms
        {"min_terms": 6, "max_terms": 5},  # Inverted term range
        {"min_value": 100, "max_value": 0},  # Inverted value range
        {"max_complexity": 0},  # Nothing to combine
    ]
    for settings in invalid_settings:
        with pytest.raises(AssertionError):
            SequenceConfig(**settings).validate()
def test_pattern_rule():
    """Test pattern rule application and rendering"""
    # Simple addition acts on the term at `position`
    rule = PatternRule([Operation.ADD], [2])
    assert rule.apply([1, 3], 1) == 5
    # Composition applies operations left to right
    rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3])
    assert rule.apply([1, 4], 1) == 11  # (4 * 2) + 3
    assert rule.to_string() == "double then add 3"
    # PREV_PLUS adds the preceding term and is a no-op at position 0
    rule = PatternRule([Operation.PREV_PLUS], [0])
    assert rule.apply([2, 5], 1) == 7
    assert rule.apply([2], 0) == 2
    # HALF uses floor division, so negative odd values round down
    assert PatternRule([Operation.HALF], [0]).apply([0, -3], 1) == -2
def test_sequence_dataset_deterministic():
    """Two datasets built from one seeded config must yield identical items."""
    config = SequenceConfig(seed=42, size=10)
    first, second = SequenceDataset(config), SequenceDataset(config)
    for idx in range(len(first)):
        assert first[idx] == second[idx]
def test_sequence_dataset_items():
    """Test basic properties of generated items"""
    config = SequenceConfig(
        min_terms=4,
        max_terms=6,
        max_complexity=2,
        size=50,
        seed=42
    )
    dataset = SequenceDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item
        # Verify sequence format
        question = item["question"]
        assert question.endswith(", ?")
        terms = [int(x) for x in question[:-3].split(", ")]
        assert config.min_terms <= len(terms) <= config.max_terms
        # Question and answer must agree with the stored full sequence
        sequence = item["metadata"]["sequence"]
        assert terms == sequence[:-1]
        assert item["answer"] == str(sequence[-1])
def test_sequence_dataset_iteration():
    """Iteration yields exactly `size` items, identically on every pass."""
    config = SequenceConfig(size=5, seed=42)
    dataset = SequenceDataset(config)
    first_pass = list(dataset)
    assert len(first_pass) == config.size
    second_pass = list(dataset)
    assert first_pass == second_pass
"""Propositional logic task generator"""
from dataclasses import dataclass
from enum import Enum
from random import Random
from typing import Any, List, Optional, Set, Tuple
class Operator(Enum):
    """Basic logical operators; each value is the symbol used when rendering."""

    AND = "∧"
    # U+2228 LOGICAL OR. An empty value would render disjunctions as "(P  Q)".
    OR = "∨"
    NOT = "¬"
    IMPLIES = "→"
    IFF = "↔"
@dataclass
class PropositionalLogicConfig:
    """Configuration for propositional logic task generation.

    Call :meth:`validate` to check the parameters; invalid settings raise
    ``AssertionError``.
    """

    min_vars: int = 2  # Minimum number of variables
    max_vars: int = 4  # Maximum number of variables
    min_statements: int = 2  # Minimum number of given statements
    max_statements: int = 4  # Maximum number of statements
    max_complexity: int = 3  # Maximum operator depth
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if any parameter is out of range. The error is
                raised explicitly instead of via ``assert`` so validation still
                runs when Python is started with ``-O``.
        """
        checks = (
            (self.min_vars > 0, "min_vars must be positive"),
            (self.max_vars >= self.min_vars, "max_vars must be >= min_vars"),
            (self.min_statements > 0, "min_statements must be positive"),
            (self.max_statements >= self.min_statements, "max_statements must be >= min_statements"),
            (self.max_complexity > 0, "max_complexity must be positive"),
        )
        # Report the first failing check, in declaration order.
        for passed, message in checks:
            if not passed:
                raise AssertionError(message)
class Expression:
    """A node in a propositional-logic expression tree.

    Leaves have ``operator`` set to None and hold a variable name in ``left``.
    ``Operator.NOT`` nodes use only ``left``; every other operator combines the
    sub-expressions ``left`` and ``right``.
    """

    def __init__(self, operator: Optional[Operator], left: Any, right: Optional[Any] = None):
        self.operator = operator
        self.left = left
        self.right = right

    def evaluate(self, assignments: dict[str, bool]) -> bool:
        """Return the truth value of this expression under *assignments*.

        ``assignments`` maps variable names to truth values. AND, OR and
        IMPLIES short-circuit, so the right operand is evaluated only when
        needed.

        Raises:
            KeyError: if a variable that must be evaluated is not assigned.
            ValueError: if the node carries an unrecognised operator.
        """
        op = self.operator
        if op is None:
            return assignments[self.left]  # Leaf: look the variable up

        def lhs():
            return self.left.evaluate(assignments)

        def rhs():
            return self.right.evaluate(assignments)

        if op == Operator.NOT:
            return not lhs()
        if op == Operator.AND:
            return lhs() and rhs()
        if op == Operator.OR:
            return lhs() or rhs()
        if op == Operator.IMPLIES:
            return not lhs() or rhs()
        if op == Operator.IFF:
            return lhs() == rhs()
        raise ValueError(f"Unknown operator: {self.operator}")

    def __str__(self) -> str:
        """Render using operator symbols, parenthesising every binary node."""
        if self.operator is None:
            return self.left
        symbol = self.operator.value
        if self.operator == Operator.NOT:
            return f"{symbol}{self.left}"
        return f"({self.left} {symbol} {self.right})"
class PropositionalLogicDataset:
    """Generates propositional logic reasoning tasks.

    Item ``idx`` is generated from ``Random(seed + idx)``, so a fixed seed
    reproduces the same premises and conclusion for every index.
    """

    def __init__(self, config: PropositionalLogicConfig):
        self.config = config
        self.config.validate()
        # Without an explicit seed, draw one once so items stay stable for the
        # lifetime of this instance.
        self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)

    def __len__(self) -> int:
        return self.config.size

    def __iter__(self):
        """Make the dataset iterable; each call restarts from index 0."""
        self._current_idx = 0
        return self

    def __next__(self):
        """Return the next item, stopping after ``config.size`` items."""
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Generate a single propositional logic task.

        Returns:
            dict with ``question`` (numbered premises plus a prompt),
            ``answer`` (a conclusion entailed by the premises) and
            ``metadata`` (``premises``, ``variables`` and the conclusion's
            ``complexity``).
        """
        rng = Random(self.seed + idx)
        # Variables are consecutive letters starting at 'P'.
        # NOTE(review): past 11 variables this runs beyond 'Z' into
        # punctuation ('[', '\\', ...) — consider capping max_vars.
        num_vars = rng.randint(self.config.min_vars, self.config.max_vars)
        variables = [chr(ord('P') + i) for i in range(num_vars)]
        num_statements = rng.randint(self.config.min_statements, self.config.max_statements)
        premises = self._generate_premises(rng, variables, num_statements)
        conclusion = self._find_valid_conclusion(rng, premises, variables)

        lines = ["Given:"]
        lines.extend(f"{i}. {premise}" for i, premise in enumerate(premises, 1))
        lines.append("What can we conclude?")
        return {
            "question": "\n".join(lines),
            "answer": str(conclusion),
            "metadata": {
                "premises": [str(p) for p in premises],
                "variables": variables,
                "complexity": self._measure_complexity(conclusion)
            }
        }

    def _generate_premises(self, rng: Random, variables: List[str], num_statements: int) -> List[Expression]:
        """Generate *num_statements* premises, each of random depth 1..max_complexity."""
        premises = []
        for _ in range(num_statements):
            depth = rng.randint(1, self.config.max_complexity)
            premises.append(self._generate_expression(rng, variables, depth))
        return premises

    def _generate_expression(self, rng: Random, variables: List[str], depth: int) -> Expression:
        """Generate a random expression; depth <= 1 yields a bare variable."""
        if depth <= 1:
            return Expression(None, rng.choice(variables))
        operator = rng.choice(list(Operator))
        if operator == Operator.NOT:
            return Expression(operator, self._generate_expression(rng, variables, depth - 1))
        left = self._generate_expression(rng, variables, depth - 1)
        right = self._generate_expression(rng, variables, depth - 1)
        return Expression(operator, left, right)

    def _find_valid_conclusion(self, rng: Random, premises: List[Expression], variables: List[str]) -> Expression:
        """Find a non-trivial conclusion that follows from the premises.

        Up to 100 random depth-2 candidates are tried. Tautologies such as
        ``(P → P)`` are skipped because they hold regardless of the premises.
        If no candidate qualifies, a fallback that is guaranteed to be
        entailed is returned.
        """
        for _ in range(100):
            candidate = self._generate_expression(rng, variables, 2)
            if self._is_tautology(candidate):
                continue
            if self._is_valid_conclusion(premises, candidate):
                return candidate
        # The fallback must be checked too, otherwise the "answer" could be
        # something the premises do not entail.
        fallback = Expression(None, variables[0])
        if self._is_valid_conclusion(premises, fallback):
            return fallback
        if premises:
            # Any single premise is entailed by the full set of premises.
            return premises[0]
        # With no premises, only a tautology is entailed.
        return Expression(Operator.IMPLIES, fallback, Expression(None, variables[0]))

    def _is_tautology(self, expression: Expression) -> bool:
        """Return True if *expression* is true under every assignment."""
        return self._is_valid_conclusion([], expression)

    def _is_valid_conclusion(self, premises: List[Expression], conclusion: Expression) -> bool:
        """Check via truth table that *conclusion* holds whenever all premises do."""
        variables = self._collect_variables(premises + [conclusion])
        for assignment in self._generate_assignments(variables):
            # A counterexample: premises all true but conclusion false.
            if all(p.evaluate(assignment) for p in premises) and not conclusion.evaluate(assignment):
                return False
        return True

    def _collect_variables(self, expressions: List[Expression]) -> Set[str]:
        """Collect the names of all variables used in *expressions*."""
        variables = set()
        for expr in expressions:
            if expr.operator is None:
                variables.add(expr.left)
                continue
            if isinstance(expr.left, Expression):
                variables.update(self._collect_variables([expr.left]))
            if isinstance(expr.right, Expression):
                variables.update(self._collect_variables([expr.right]))
        return variables

    def _generate_assignments(self, variables: Set[str]) -> List[dict[str, bool]]:
        """Generate all 2**len(variables) truth-value assignments."""
        assignments = []
        ordered = sorted(variables)
        for i in range(2 ** len(ordered)):
            # Bit j of i gives the truth value of the j-th variable.
            assignments.append({var: bool((i >> j) & 1) for j, var in enumerate(ordered)})
        return assignments

    def _measure_complexity(self, expression: Expression) -> int:
        """Count the nodes (variables and operators) in *expression*."""
        if expression.operator is None:
            return 1
        if expression.operator == Operator.NOT:
            return 1 + self._measure_complexity(expression.left)
        return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
def propositional_logic_dataset(
    min_vars: int = 2,
    max_vars: int = 4,
    min_statements: int = 2,
    max_statements: int = 4,
    max_complexity: int = 3,
    seed: Optional[int] = None,
    size: int = 500,
) -> PropositionalLogicDataset:
    """Build a PropositionalLogicDataset from keyword arguments.

    Every argument is passed straight through to PropositionalLogicConfig;
    the dataset validates the resulting config when it is constructed.
    """
    settings = {
        "min_vars": min_vars,
        "max_vars": max_vars,
        "min_statements": min_statements,
        "max_statements": max_statements,
        "max_complexity": max_complexity,
        "seed": seed,
        "size": size,
    }
    return PropositionalLogicDataset(PropositionalLogicConfig(**settings))
"""Tests for propositional logic task generation"""
import pytest
from reasoning_gym.logic.propositional_logic import (
Expression,
Operator,
PropositionalLogicConfig,
PropositionalLogicDataset,
)
def test_propositional_logic_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # The defaults must be accepted.
    PropositionalLogicConfig().validate()
    invalid_settings = [
        {"min_vars": 0},  # No variables
        {"min_vars": 4, "max_vars": 3},  # Inverted variable range
        {"min_statements": 0},  # No premises
        {"min_statements": 3, "max_statements": 2},  # Inverted statement range
        {"max_complexity": 0},  # Zero operator depth
    ]
    for settings in invalid_settings:
        with pytest.raises(AssertionError):
            PropositionalLogicConfig(**settings).validate()
def test_expression_evaluation():
    """Test logical expression evaluation for every operator"""
    p = Expression(None, "P")
    q = Expression(None, "Q")
    # Test simple variable
    assert p.evaluate({"P": True}) is True
    assert p.evaluate({"P": False}) is False
    # Test NOT
    expr = Expression(Operator.NOT, Expression(None, "P"))
    assert expr.evaluate({"P": True}) is False
    assert expr.evaluate({"P": False}) is True
    # Test AND
    expr = Expression(Operator.AND, p, q)
    assert expr.evaluate({"P": True, "Q": True}) is True
    assert expr.evaluate({"P": True, "Q": False}) is False
    # Test OR
    expr = Expression(Operator.OR, p, q)
    assert expr.evaluate({"P": False, "Q": True}) is True
    assert expr.evaluate({"P": False, "Q": False}) is False
    # Test IMPLIES
    expr = Expression(Operator.IMPLIES, p, q)
    assert expr.evaluate({"P": True, "Q": False}) is False
    assert expr.evaluate({"P": True, "Q": True}) is True
    assert expr.evaluate({"P": False, "Q": False}) is True
    # Test IFF
    expr = Expression(Operator.IFF, p, q)
    assert expr.evaluate({"P": True, "Q": True}) is True
    assert expr.evaluate({"P": False, "Q": False}) is True
    assert expr.evaluate({"P": True, "Q": False}) is False
def test_propositional_logic_dataset_deterministic():
    """Two datasets built from one seeded config must yield identical items."""
    config = PropositionalLogicConfig(seed=42, size=10)
    first, second = PropositionalLogicDataset(config), PropositionalLogicDataset(config)
    for idx in range(len(first)):
        assert first[idx] == second[idx]
def test_propositional_logic_dataset_items():
    """Every item is a dict with the expected keys and metadata types."""
    dataset = PropositionalLogicDataset(
        PropositionalLogicConfig(
            min_vars=2,
            max_vars=3,
            min_statements=2,
            max_statements=3,
            max_complexity=2,
            size=10,
            seed=42,
        )
    )
    for idx in range(len(dataset)):
        item = dataset[idx]
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        meta = item["metadata"]
        assert isinstance(meta["premises"], list)
        assert isinstance(meta["variables"], list)
        assert isinstance(meta["complexity"], int)
def test_propositional_logic_dataset_iteration():
    """Iteration yields exactly `size` items, identically on every pass."""
    config = PropositionalLogicConfig(size=5, seed=42)
    dataset = PropositionalLogicDataset(config)
    first_pass = list(dataset)
    assert len(first_pass) == config.size
    second_pass = list(dataset)
    assert first_pass == second_pass
"""
Logic tasks for training reasoning capabilities:
- Propositional logic
- Predicate logic
- Set theory
- Syllogisms
"""
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset, propositional_logic_dataset
# Public names re-exported from .propositional_logic for `from <package> import *`.
__all__ = ["PropositionalLogicConfig", "PropositionalLogicDataset", "propositional_logic_dataset"]