mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-29 17:35:16 +00:00
bump version, remove accidentially checked in file
This commit is contained in:
parent
3917990153
commit
4112f57ea2
2 changed files with 1 additions and 649 deletions
|
|
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "reasoning_gym"
|
name = "reasoning_gym"
|
||||||
version = "0.1.0"
|
version = "0.1.1"
|
||||||
authors = [
|
authors = [
|
||||||
{ name="Open-Thought community", email="andreas.koepf@xamla.com" },
|
{ name="Open-Thought community", email="andreas.koepf@xamla.com" },
|
||||||
]
|
]
|
||||||
|
|
|
||||||
648
python
648
python
|
|
@ -1,648 +0,0 @@
|
||||||
"""
|
|
||||||
Cognition tasks for training reasoning capabilities:
|
|
||||||
- Pattern recognition
|
|
||||||
- Sequence completion
|
|
||||||
- Logical reasoning
|
|
||||||
- Working memory
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .sequences import SequenceDataset, SequenceConfig, sequence_dataset
|
|
||||||
|
|
||||||
__all__ = ["SequenceDataset", "SequenceConfig", "sequence_dataset"]
|
|
||||||
"""
|
|
||||||
Cognition tasks for training reasoning capabilities:
|
|
||||||
- Pattern recognition
|
|
||||||
- Sequence completion
|
|
||||||
- Logical reasoning
|
|
||||||
- Working memory
|
|
||||||
"""
|
|
||||||
|
|
||||||
__all__ = []
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from random import Random
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
class Operation(Enum):
|
|
||||||
"""Basic mathematical operations that can be composed"""
|
|
||||||
ADD = "+"
|
|
||||||
MULTIPLY = "*"
|
|
||||||
SQUARE = "^2"
|
|
||||||
DOUBLE = "*2"
|
|
||||||
HALF = "/2"
|
|
||||||
PREV_PLUS = "prev+" # Add previous number
|
|
||||||
ALTERNATE = "alt" # Alternate between operations
|
|
||||||
COMPOSE = "compose" # Compose two operations
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SequenceConfig:
|
|
||||||
"""Configuration for sequence generation"""
|
|
||||||
min_terms: int = 4 # Minimum visible terms
|
|
||||||
max_terms: int = 8 # Maximum visible terms
|
|
||||||
min_value: int = -100 # Minimum allowed number
|
|
||||||
max_value: int = 100 # Maximum allowed number
|
|
||||||
max_complexity: int = 3 # Maximum number of operations to combine
|
|
||||||
seed: Optional[int] = None
|
|
||||||
size: int = 500 # Virtual dataset size
|
|
||||||
|
|
||||||
def validate(self):
|
|
||||||
"""Validate configuration parameters"""
|
|
||||||
assert self.min_terms >= 4, "need at least 4 terms to establish pattern"
|
|
||||||
assert self.max_terms >= self.min_terms
|
|
||||||
assert self.max_value > self.min_value
|
|
||||||
assert self.max_complexity >= 1
|
|
||||||
|
|
||||||
class PatternRule:
|
|
||||||
"""Represents a composable sequence pattern rule"""
|
|
||||||
|
|
||||||
def __init__(self, operations: List[Operation], parameters: List[int]):
|
|
||||||
self.operations = operations
|
|
||||||
self.parameters = parameters
|
|
||||||
|
|
||||||
def apply(self, sequence: List[int], position: int) -> int:
|
|
||||||
"""Apply the rule to generate the next number"""
|
|
||||||
result = sequence[position] # Start with current number
|
|
||||||
|
|
||||||
for op, param in zip(self.operations, self.parameters):
|
|
||||||
if op == Operation.ADD:
|
|
||||||
result += param
|
|
||||||
elif op == Operation.MULTIPLY:
|
|
||||||
result *= param
|
|
||||||
elif op == Operation.SQUARE:
|
|
||||||
result = result * result
|
|
||||||
elif op == Operation.DOUBLE:
|
|
||||||
result *= 2
|
|
||||||
elif op == Operation.HALF:
|
|
||||||
result //= 2 # Integer division
|
|
||||||
elif op == Operation.PREV_PLUS:
|
|
||||||
if position > 0:
|
|
||||||
result += sequence[position - 1]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def to_string(self) -> str:
|
|
||||||
"""Convert rule to human-readable string"""
|
|
||||||
parts = []
|
|
||||||
for op, param in zip(self.operations, self.parameters):
|
|
||||||
if op == Operation.ADD:
|
|
||||||
parts.append(f"add {param}")
|
|
||||||
elif op == Operation.MULTIPLY:
|
|
||||||
parts.append(f"multiply by {param}")
|
|
||||||
elif op == Operation.SQUARE:
|
|
||||||
parts.append("square")
|
|
||||||
elif op == Operation.DOUBLE:
|
|
||||||
parts.append("double")
|
|
||||||
elif op == Operation.HALF:
|
|
||||||
parts.append("halve")
|
|
||||||
elif op == Operation.PREV_PLUS:
|
|
||||||
parts.append("add previous")
|
|
||||||
return " then ".join(parts)
|
|
||||||
|
|
||||||
class PatternGenerator:
|
|
||||||
"""Generates new pattern rules with configurable complexity"""
|
|
||||||
|
|
||||||
def __init__(self, rng: Random, complexity: int = 1):
|
|
||||||
self.rng = rng
|
|
||||||
self.complexity = complexity
|
|
||||||
|
|
||||||
def generate_rule(self) -> PatternRule:
|
|
||||||
"""Generate a new pattern rule"""
|
|
||||||
operations = []
|
|
||||||
parameters = []
|
|
||||||
|
|
||||||
# Number of operations based on complexity
|
|
||||||
num_ops = self.rng.randint(1, self.complexity + 1)
|
|
||||||
|
|
||||||
for _ in range(num_ops):
|
|
||||||
# Pick random operation
|
|
||||||
op = self.rng.choice(list(Operation))
|
|
||||||
operations.append(op)
|
|
||||||
|
|
||||||
# Generate appropriate parameter
|
|
||||||
if op in [Operation.ADD, Operation.MULTIPLY]:
|
|
||||||
param = self.rng.randint(-10, 10)
|
|
||||||
while param == 0: # Avoid trivial operations
|
|
||||||
param = self.rng.randint(-10, 10)
|
|
||||||
parameters.append(param)
|
|
||||||
else:
|
|
||||||
parameters.append(0) # Some operations don't need parameters
|
|
||||||
|
|
||||||
return PatternRule(operations, parameters)
|
|
||||||
|
|
||||||
def is_interesting(self, sequence: List[int], max_value: int = 1000) -> bool:
|
|
||||||
"""Check if sequence is interesting enough"""
|
|
||||||
if not sequence:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Avoid too large numbers
|
|
||||||
if any(abs(x) > max_value for x in sequence):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Avoid constant sequences
|
|
||||||
if len(set(sequence)) == 1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Avoid simple arithmetic progressions if complexity > 1
|
|
||||||
if self.complexity > 1:
|
|
||||||
diffs = [sequence[i+1] - sequence[i] for i in range(len(sequence)-1)]
|
|
||||||
if len(set(diffs)) == 1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
class SequenceDataset:
|
|
||||||
"""Generates number sequence completion tasks with dynamic pattern generation"""
|
|
||||||
|
|
||||||
def __init__(self, config: SequenceConfig):
|
|
||||||
self.config = config
|
|
||||||
self.config.validate()
|
|
||||||
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self.config.size
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
"""Make the dataset iterable"""
|
|
||||||
self._current_idx = 0
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
"""Get next item in iteration"""
|
|
||||||
if self._current_idx >= self.config.size:
|
|
||||||
raise StopIteration
|
|
||||||
item = self[self._current_idx]
|
|
||||||
self._current_idx += 1
|
|
||||||
return item
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
|
||||||
"""Generate a sequence task with a newly generated pattern"""
|
|
||||||
rng = Random(self.seed + idx)
|
|
||||||
|
|
||||||
# Create pattern generator with random complexity
|
|
||||||
complexity = rng.randint(1, self.config.max_complexity)
|
|
||||||
generator = PatternGenerator(rng, complexity)
|
|
||||||
|
|
||||||
# Generate pattern rule and sequence
|
|
||||||
max_attempts = 10
|
|
||||||
for _ in range(max_attempts):
|
|
||||||
rule = generator.generate_rule()
|
|
||||||
|
|
||||||
# Generate initial terms
|
|
||||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
|
||||||
sequence = [rng.randint(-10, 10)] # Start with random number
|
|
||||||
|
|
||||||
# Generate remaining terms
|
|
||||||
try:
|
|
||||||
for i in range(1, num_terms + 1): # +1 for answer
|
|
||||||
next_term = rule.apply(sequence, i)
|
|
||||||
sequence.append(next_term)
|
|
||||||
|
|
||||||
if generator.is_interesting(sequence):
|
|
||||||
break
|
|
||||||
except (OverflowError, ZeroDivisionError):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# If we couldn't generate an interesting sequence, fall back to simple addition
|
|
||||||
rule = PatternRule([Operation.ADD], [2])
|
|
||||||
sequence = [i * 2 for i in range(num_terms + 1)]
|
|
||||||
|
|
||||||
visible_terms = sequence[:-1] # Last term is the answer
|
|
||||||
|
|
||||||
return {
|
|
||||||
"question": ", ".join(map(str, visible_terms)) + ", ?",
|
|
||||||
"answer": str(sequence[-1]),
|
|
||||||
"metadata": {
|
|
||||||
"rule": rule.to_string(),
|
|
||||||
"complexity": complexity,
|
|
||||||
"sequence": sequence
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def sequence_dataset(
|
|
||||||
min_terms: int = 4,
|
|
||||||
max_terms: int = 8,
|
|
||||||
min_value: int = -100,
|
|
||||||
max_value: int = 100,
|
|
||||||
max_complexity: int = 3,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> SequenceDataset:
|
|
||||||
"""Create a SequenceDataset with the given configuration."""
|
|
||||||
config = SequenceConfig(
|
|
||||||
min_terms=min_terms,
|
|
||||||
max_terms=max_terms,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
max_complexity=max_complexity,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return SequenceDataset(config)
|
|
||||||
import pytest
|
|
||||||
from reasoning_gym.cognition.sequences import (
|
|
||||||
SequenceDataset,
|
|
||||||
SequenceConfig,
|
|
||||||
Operation,
|
|
||||||
PatternRule,
|
|
||||||
PatternGenerator
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_sequence_config_validation():
|
|
||||||
"""Test that invalid configs raise appropriate errors"""
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = SequenceConfig(min_terms=3) # Too few terms
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = SequenceConfig(min_terms=6, max_terms=5)
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = SequenceConfig(min_value=100, max_value=0)
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
def test_pattern_rule():
|
|
||||||
"""Test pattern rule application"""
|
|
||||||
# Test simple addition
|
|
||||||
rule = PatternRule([Operation.ADD], [2])
|
|
||||||
assert rule.apply([1, 3], 1) == 5
|
|
||||||
|
|
||||||
# Test composition
|
|
||||||
rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3])
|
|
||||||
assert rule.apply([1, 4], 1) == 11 # (4 * 2) + 3
|
|
||||||
|
|
||||||
def test_sequence_dataset_deterministic():
|
|
||||||
"""Test that dataset generates same items with same seed"""
|
|
||||||
config = SequenceConfig(seed=42, size=10)
|
|
||||||
dataset1 = SequenceDataset(config)
|
|
||||||
dataset2 = SequenceDataset(config)
|
|
||||||
|
|
||||||
for i in range(len(dataset1)):
|
|
||||||
assert dataset1[i] == dataset2[i]
|
|
||||||
|
|
||||||
def test_sequence_dataset_items():
|
|
||||||
"""Test basic properties of generated items"""
|
|
||||||
config = SequenceConfig(
|
|
||||||
min_terms=4,
|
|
||||||
max_terms=6,
|
|
||||||
max_complexity=2,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = SequenceDataset(config)
|
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
item = dataset[i]
|
|
||||||
assert isinstance(item, dict)
|
|
||||||
assert "question" in item
|
|
||||||
assert "answer" in item
|
|
||||||
assert "metadata" in item
|
|
||||||
|
|
||||||
# Verify sequence format
|
|
||||||
question = item["question"]
|
|
||||||
assert question.endswith(", ?")
|
|
||||||
terms = [int(x) for x in question[:-3].split(", ")]
|
|
||||||
assert len(terms) >= config.min_terms
|
|
||||||
assert len(terms) <= config.max_terms
|
|
||||||
|
|
||||||
def test_sequence_dataset_iteration():
|
|
||||||
"""Test that iteration respects dataset size"""
|
|
||||||
config = SequenceConfig(size=5, seed=42)
|
|
||||||
dataset = SequenceDataset(config)
|
|
||||||
|
|
||||||
items = list(dataset)
|
|
||||||
assert len(items) == config.size
|
|
||||||
|
|
||||||
# Test multiple iterations yield same items
|
|
||||||
assert items == list(dataset)
|
|
||||||
"""Propositional logic task generator"""
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from random import Random
|
|
||||||
from typing import Any, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
class Operator(Enum):
|
|
||||||
"""Basic logical operators"""
|
|
||||||
AND = "∧"
|
|
||||||
OR = "∨"
|
|
||||||
NOT = "¬"
|
|
||||||
IMPLIES = "→"
|
|
||||||
IFF = "↔"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PropositionalLogicConfig:
|
|
||||||
"""Configuration for propositional logic task generation"""
|
|
||||||
min_vars: int = 2 # Minimum number of variables
|
|
||||||
max_vars: int = 4 # Maximum number of variables
|
|
||||||
min_statements: int = 2 # Minimum number of given statements
|
|
||||||
max_statements: int = 4 # Maximum number of statements
|
|
||||||
max_complexity: int = 3 # Maximum operator depth
|
|
||||||
seed: Optional[int] = None
|
|
||||||
size: int = 500 # Virtual dataset size
|
|
||||||
|
|
||||||
def validate(self):
|
|
||||||
"""Validate configuration parameters"""
|
|
||||||
assert self.min_vars > 0, "min_vars must be positive"
|
|
||||||
assert self.max_vars >= self.min_vars, "max_vars must be >= min_vars"
|
|
||||||
assert self.min_statements > 0, "min_statements must be positive"
|
|
||||||
assert self.max_statements >= self.min_statements
|
|
||||||
assert self.max_complexity > 0, "max_complexity must be positive"
|
|
||||||
|
|
||||||
|
|
||||||
class Expression:
|
|
||||||
"""Represents a logical expression that can be evaluated"""
|
|
||||||
|
|
||||||
def __init__(self, operator: Optional[Operator], left: Any, right: Optional[Any] = None):
|
|
||||||
self.operator = operator
|
|
||||||
self.left = left
|
|
||||||
self.right = right
|
|
||||||
|
|
||||||
def evaluate(self, assignments: dict[str, bool]) -> bool:
|
|
||||||
"""Evaluate expression with given variable assignments"""
|
|
||||||
if self.operator is None:
|
|
||||||
return assignments[self.left] # Variable
|
|
||||||
elif self.operator == Operator.NOT:
|
|
||||||
return not self.left.evaluate(assignments)
|
|
||||||
elif self.operator == Operator.AND:
|
|
||||||
return self.left.evaluate(assignments) and self.right.evaluate(assignments)
|
|
||||||
elif self.operator == Operator.OR:
|
|
||||||
return self.left.evaluate(assignments) or self.right.evaluate(assignments)
|
|
||||||
elif self.operator == Operator.IMPLIES:
|
|
||||||
return (not self.left.evaluate(assignments)) or self.right.evaluate(assignments)
|
|
||||||
elif self.operator == Operator.IFF:
|
|
||||||
return self.left.evaluate(assignments) == self.right.evaluate(assignments)
|
|
||||||
raise ValueError(f"Unknown operator: {self.operator}")
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
if self.operator is None:
|
|
||||||
return self.left
|
|
||||||
elif self.operator == Operator.NOT:
|
|
||||||
return f"{self.operator.value}{self.left}"
|
|
||||||
else:
|
|
||||||
return f"({self.left} {self.operator.value} {self.right})"
|
|
||||||
|
|
||||||
|
|
||||||
class PropositionalLogicDataset:
|
|
||||||
"""Generates propositional logic reasoning tasks"""
|
|
||||||
|
|
||||||
def __init__(self, config: PropositionalLogicConfig):
|
|
||||||
self.config = config
|
|
||||||
self.config.validate()
|
|
||||||
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self.config.size
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
self._current_idx = 0
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
if self._current_idx >= self.config.size:
|
|
||||||
raise StopIteration
|
|
||||||
item = self[self._current_idx]
|
|
||||||
self._current_idx += 1
|
|
||||||
return item
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
|
||||||
"""Generate a single propositional logic task"""
|
|
||||||
rng = Random(self.seed + idx)
|
|
||||||
|
|
||||||
# Generate random variables
|
|
||||||
num_vars = rng.randint(self.config.min_vars, self.config.max_vars)
|
|
||||||
variables = [chr(ord('P') + i) for i in range(num_vars)]
|
|
||||||
|
|
||||||
# Generate premises
|
|
||||||
num_statements = rng.randint(self.config.min_statements, self.config.max_statements)
|
|
||||||
premises = self._generate_premises(rng, variables, num_statements)
|
|
||||||
|
|
||||||
# Generate a valid conclusion
|
|
||||||
conclusion = self._find_valid_conclusion(rng, premises, variables)
|
|
||||||
|
|
||||||
# Format question
|
|
||||||
question = "Given:\n"
|
|
||||||
for i, premise in enumerate(premises, 1):
|
|
||||||
question += f"{i}. {premise}\n"
|
|
||||||
question += "What can we conclude?"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"question": question,
|
|
||||||
"answer": str(conclusion),
|
|
||||||
"metadata": {
|
|
||||||
"premises": [str(p) for p in premises],
|
|
||||||
"variables": variables,
|
|
||||||
"complexity": self._measure_complexity(conclusion)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def _generate_premises(self, rng: Random, variables: List[str], num_statements: int) -> List[Expression]:
|
|
||||||
"""Generate a list of premise statements"""
|
|
||||||
premises = []
|
|
||||||
for _ in range(num_statements):
|
|
||||||
depth = rng.randint(1, self.config.max_complexity)
|
|
||||||
premises.append(self._generate_expression(rng, variables, depth))
|
|
||||||
return premises
|
|
||||||
|
|
||||||
def _generate_expression(self, rng: Random, variables: List[str], depth: int) -> Expression:
|
|
||||||
"""Generate a random logical expression"""
|
|
||||||
if depth <= 1:
|
|
||||||
return Expression(None, rng.choice(variables))
|
|
||||||
|
|
||||||
operator = rng.choice(list(Operator))
|
|
||||||
if operator == Operator.NOT:
|
|
||||||
return Expression(operator, self._generate_expression(rng, variables, depth - 1))
|
|
||||||
else:
|
|
||||||
left = self._generate_expression(rng, variables, depth - 1)
|
|
||||||
right = self._generate_expression(rng, variables, depth - 1)
|
|
||||||
return Expression(operator, left, right)
|
|
||||||
|
|
||||||
def _find_valid_conclusion(self, rng: Random, premises: List[Expression], variables: List[str]) -> Expression:
|
|
||||||
"""Find a valid conclusion that follows from the premises"""
|
|
||||||
# Try random conclusions until we find a valid one
|
|
||||||
for _ in range(100):
|
|
||||||
candidate = self._generate_expression(rng, variables, 2)
|
|
||||||
if self._is_valid_conclusion(premises, candidate):
|
|
||||||
return candidate
|
|
||||||
|
|
||||||
# Fallback to a simple conclusion
|
|
||||||
return Expression(None, variables[0])
|
|
||||||
|
|
||||||
def _is_valid_conclusion(self, premises: List[Expression], conclusion: Expression) -> bool:
|
|
||||||
"""Check if conclusion follows from premises using truth tables"""
|
|
||||||
variables = self._collect_variables(premises + [conclusion])
|
|
||||||
|
|
||||||
# Check all possible assignments
|
|
||||||
for assignment in self._generate_assignments(variables):
|
|
||||||
# If premises are true but conclusion is false, invalid
|
|
||||||
if all(p.evaluate(assignment) for p in premises) and not conclusion.evaluate(assignment):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _collect_variables(self, expressions: List[Expression]) -> Set[str]:
|
|
||||||
"""Collect all variables used in expressions"""
|
|
||||||
variables = set()
|
|
||||||
for expr in expressions:
|
|
||||||
if expr.operator is None:
|
|
||||||
variables.add(expr.left)
|
|
||||||
else:
|
|
||||||
if isinstance(expr.left, Expression):
|
|
||||||
variables.update(self._collect_variables([expr.left]))
|
|
||||||
if expr.right and isinstance(expr.right, Expression):
|
|
||||||
variables.update(self._collect_variables([expr.right]))
|
|
||||||
return variables
|
|
||||||
|
|
||||||
def _generate_assignments(self, variables: Set[str]) -> List[dict[str, bool]]:
|
|
||||||
"""Generate all possible truth value assignments"""
|
|
||||||
assignments = []
|
|
||||||
for i in range(2 ** len(variables)):
|
|
||||||
assignment = {}
|
|
||||||
for j, var in enumerate(sorted(variables)):
|
|
||||||
assignment[var] = bool((i >> j) & 1)
|
|
||||||
assignments.append(assignment)
|
|
||||||
return assignments
|
|
||||||
|
|
||||||
def _measure_complexity(self, expression: Expression) -> int:
|
|
||||||
"""Measure the complexity of an expression"""
|
|
||||||
if expression.operator is None:
|
|
||||||
return 1
|
|
||||||
elif expression.operator == Operator.NOT:
|
|
||||||
return 1 + self._measure_complexity(expression.left)
|
|
||||||
else:
|
|
||||||
return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
|
|
||||||
|
|
||||||
|
|
||||||
def propositional_logic_dataset(
|
|
||||||
min_vars: int = 2,
|
|
||||||
max_vars: int = 4,
|
|
||||||
min_statements: int = 2,
|
|
||||||
max_statements: int = 4,
|
|
||||||
max_complexity: int = 3,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> PropositionalLogicDataset:
|
|
||||||
"""Create a PropositionalLogicDataset with the given configuration."""
|
|
||||||
config = PropositionalLogicConfig(
|
|
||||||
min_vars=min_vars,
|
|
||||||
max_vars=max_vars,
|
|
||||||
min_statements=min_statements,
|
|
||||||
max_statements=max_statements,
|
|
||||||
max_complexity=max_complexity,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return PropositionalLogicDataset(config)
|
|
||||||
"""Tests for propositional logic task generation"""
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from reasoning_gym.logic.propositional_logic import (
|
|
||||||
Expression,
|
|
||||||
Operator,
|
|
||||||
PropositionalLogicConfig,
|
|
||||||
PropositionalLogicDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_propositional_logic_config_validation():
|
|
||||||
"""Test that invalid configs raise appropriate errors"""
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = PropositionalLogicConfig(min_vars=0)
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = PropositionalLogicConfig(min_vars=4, max_vars=3)
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = PropositionalLogicConfig(min_statements=0)
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
|
|
||||||
def test_expression_evaluation():
|
|
||||||
"""Test logical expression evaluation"""
|
|
||||||
# Test simple variable
|
|
||||||
expr = Expression(None, "P")
|
|
||||||
assert expr.evaluate({"P": True}) is True
|
|
||||||
assert expr.evaluate({"P": False}) is False
|
|
||||||
|
|
||||||
# Test NOT
|
|
||||||
expr = Expression(Operator.NOT, Expression(None, "P"))
|
|
||||||
assert expr.evaluate({"P": True}) is False
|
|
||||||
assert expr.evaluate({"P": False}) is True
|
|
||||||
|
|
||||||
# Test AND
|
|
||||||
expr = Expression(
|
|
||||||
Operator.AND,
|
|
||||||
Expression(None, "P"),
|
|
||||||
Expression(None, "Q")
|
|
||||||
)
|
|
||||||
assert expr.evaluate({"P": True, "Q": True}) is True
|
|
||||||
assert expr.evaluate({"P": True, "Q": False}) is False
|
|
||||||
|
|
||||||
# Test IMPLIES
|
|
||||||
expr = Expression(
|
|
||||||
Operator.IMPLIES,
|
|
||||||
Expression(None, "P"),
|
|
||||||
Expression(None, "Q")
|
|
||||||
)
|
|
||||||
assert expr.evaluate({"P": True, "Q": False}) is False
|
|
||||||
assert expr.evaluate({"P": True, "Q": True}) is True
|
|
||||||
assert expr.evaluate({"P": False, "Q": False}) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_propositional_logic_dataset_deterministic():
|
|
||||||
"""Test that dataset generates same items with same seed"""
|
|
||||||
config = PropositionalLogicConfig(seed=42, size=10)
|
|
||||||
dataset1 = PropositionalLogicDataset(config)
|
|
||||||
dataset2 = PropositionalLogicDataset(config)
|
|
||||||
|
|
||||||
for i in range(len(dataset1)):
|
|
||||||
assert dataset1[i] == dataset2[i]
|
|
||||||
|
|
||||||
|
|
||||||
def test_propositional_logic_dataset_items():
|
|
||||||
"""Test basic properties of generated items"""
|
|
||||||
config = PropositionalLogicConfig(
|
|
||||||
min_vars=2,
|
|
||||||
max_vars=3,
|
|
||||||
min_statements=2,
|
|
||||||
max_statements=3,
|
|
||||||
max_complexity=2,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = PropositionalLogicDataset(config)
|
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
item = dataset[i]
|
|
||||||
assert isinstance(item, dict)
|
|
||||||
assert "question" in item
|
|
||||||
assert "answer" in item
|
|
||||||
assert "metadata" in item
|
|
||||||
assert isinstance(item["metadata"]["premises"], list)
|
|
||||||
assert isinstance(item["metadata"]["variables"], list)
|
|
||||||
assert isinstance(item["metadata"]["complexity"], int)
|
|
||||||
|
|
||||||
|
|
||||||
def test_propositional_logic_dataset_iteration():
|
|
||||||
"""Test that iteration respects dataset size"""
|
|
||||||
config = PropositionalLogicConfig(size=5, seed=42)
|
|
||||||
dataset = PropositionalLogicDataset(config)
|
|
||||||
|
|
||||||
items = list(dataset)
|
|
||||||
assert len(items) == config.size
|
|
||||||
|
|
||||||
# Test multiple iterations yield same items
|
|
||||||
assert items == list(dataset)
|
|
||||||
"""
|
|
||||||
Logic tasks for training reasoning capabilities:
|
|
||||||
- Propositional logic
|
|
||||||
- Predicate logic
|
|
||||||
- Set theory
|
|
||||||
- Syllogisms
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset, propositional_logic_dataset
|
|
||||||
|
|
||||||
__all__ = ["PropositionalLogicConfig", "PropositionalLogicDataset", "propositional_logic_dataset"]
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue