reasoning-gym/reasoning_gym/logic/propositional_logic.py
2025-01-26 16:55:17 +01:00

202 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Propositional logic task generator"""
from dataclasses import dataclass
from enum import Enum
from random import Random
from typing import Any, List, Optional, Set
from ..factory import ProceduralDataset, register_dataset
class Operator(Enum):
    """Basic logical operators.

    Each member's value is the symbol used when an expression is rendered as
    text. The values MUST be pairwise distinct: ``Enum`` treats members that
    share a value as aliases of the first one, so if several symbols collapse
    to the same string (e.g. the empty string after a lossy copy of this
    file), ``OR``, ``IMPLIES`` and ``IFF`` silently become ``AND``.
    """

    AND = "∧"
    OR = "∨"
    NOT = "¬"
    IMPLIES = "→"
    IFF = "↔"
@dataclass
class PropositionalLogicConfig:
    """Configuration for propositional logic task generation."""

    min_vars: int = 2  # Minimum number of variables
    max_vars: int = 4  # Maximum number of variables
    min_statements: int = 2  # Minimum number of given statements
    max_statements: int = 4  # Maximum number of statements
    max_complexity: int = 3  # Maximum operator depth
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if a bound is non-positive or a ``max_*`` value is
                smaller than its ``min_*`` counterpart.

        The checks are plain ``assert`` statements, so they are skipped when
        Python runs with ``-O``.
        """
        assert self.min_vars > 0, "min_vars must be positive"
        assert self.max_vars >= self.min_vars, "max_vars must be >= min_vars"
        assert self.min_statements > 0, "min_statements must be positive"
        # This was the only check without a message; keep failures self-explanatory.
        assert self.max_statements >= self.min_statements, "max_statements must be >= min_statements"
        assert self.max_complexity > 0, "max_complexity must be positive"
class Expression:
    """A node in a propositional formula tree.

    A node with ``operator is None`` is a variable whose name is stored in
    ``left``. A ``NOT`` node keeps its single operand in ``left``; every other
    operator uses both ``left`` and ``right`` as sub-expressions.
    """

    def __init__(self, operator: Optional[Operator], left: Any, right: Optional[Any] = None):
        self.operator = operator
        self.left = left
        self.right = right

    def evaluate(self, assignments: dict[str, bool]) -> bool:
        """Return the truth value of this expression under ``assignments``.

        Raises:
            KeyError: if a variable is missing from ``assignments``.
            ValueError: if the node carries an operator that is not handled.
        """
        op = self.operator

        # Leaf: look the variable up directly.
        if op is None:
            return assignments[self.left]

        if op == Operator.NOT:
            return not self.left.evaluate(assignments)

        # Binary connectives; the right operand is only evaluated when needed.
        if op == Operator.AND:
            return self.left.evaluate(assignments) and self.right.evaluate(assignments)
        if op == Operator.OR:
            return self.left.evaluate(assignments) or self.right.evaluate(assignments)
        if op == Operator.IMPLIES:
            # a -> b is equivalent to (not a) or b
            return (not self.left.evaluate(assignments)) or self.right.evaluate(assignments)
        if op == Operator.IFF:
            return self.left.evaluate(assignments) == self.right.evaluate(assignments)

        raise ValueError(f"Unknown operator: {self.operator}")

    def __str__(self) -> str:
        """Render the expression; binary nodes are always parenthesised."""
        if self.operator is None:
            return self.left

        symbol = self.operator.value
        if self.operator == Operator.NOT:
            return f"{symbol}{self.left}"
        return f"({self.left} {symbol} {self.right})"
class PropositionalLogicDataset(ProceduralDataset):
    """Generates propositional logic reasoning tasks.

    Every item is derived from ``Random(self.seed + idx)``, so the same index
    always yields the same task. An item consists of a few random premises
    over variables named ``P``, ``Q``, ... and a conclusion that is entailed
    by those premises (verified by exhaustive truth-table evaluation).

    NOTE(review): ``self.config`` and ``self.seed`` are read here but assigned
    elsewhere, presumably by ``ProceduralDataset.__init__`` - confirm there.
    """

    def __init__(self, config: PropositionalLogicConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __len__(self) -> int:
        """Return the configured (virtual) dataset size."""
        return self.config.size

    def __iter__(self):
        """Restart iteration from the first item and return ``self``."""
        self._current_idx = 0
        return self

    def __next__(self):
        """Return the next item, raising ``StopIteration`` after ``size`` items."""
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Generate a single propositional logic task.

        Returns:
            A dict with the keys ``"question"`` (numbered premises followed by
            the prompt), ``"answer"`` (the conclusion as text) and
            ``"metadata"`` (premise strings, variable names and the
            conclusion's complexity).
        """
        rng = Random(self.seed + idx)

        # Variable names start at "P" and follow code-point order; only the
        # first 11 ("P".."Z") are letters.
        num_vars = rng.randint(self.config.min_vars, self.config.max_vars)
        variables = [chr(ord("P") + i) for i in range(num_vars)]

        num_statements = rng.randint(self.config.min_statements, self.config.max_statements)
        premises = self._generate_premises(rng, variables, num_statements)

        conclusion = self._find_valid_conclusion(rng, premises, variables)

        numbered = "".join(f"{i}. {premise}\n" for i, premise in enumerate(premises, 1))
        question = f"Given:\n{numbered}What can we conclude?"

        return {
            "question": question,
            "answer": str(conclusion),
            "metadata": {
                "premises": [str(p) for p in premises],
                "variables": variables,
                "complexity": self._measure_complexity(conclusion),
            },
        }

    def _generate_premises(self, rng: Random, variables: List[str], num_statements: int) -> List[Expression]:
        """Generate ``num_statements`` random premises.

        Each premise gets its own depth drawn from ``1..max_complexity``.
        """
        premises = []
        for _ in range(num_statements):
            # Draw the depth first, then build: keeps the RNG call order fixed.
            depth = rng.randint(1, self.config.max_complexity)
            premises.append(self._generate_expression(rng, variables, depth))
        return premises

    def _generate_expression(self, rng: Random, variables: List[str], depth: int) -> Expression:
        """Generate a random expression tree with ``depth`` levels.

        At ``depth <= 1`` a bare variable is returned; otherwise a random
        operator is chosen and its operand(s) are generated one level lower.
        """
        if depth <= 1:
            return Expression(None, rng.choice(variables))

        operator = rng.choice(list(Operator))
        if operator == Operator.NOT:
            return Expression(operator, self._generate_expression(rng, variables, depth - 1))

        # Left is generated before right so results stay reproducible per seed.
        left = self._generate_expression(rng, variables, depth - 1)
        right = self._generate_expression(rng, variables, depth - 1)
        return Expression(operator, left, right)

    def _find_valid_conclusion(self, rng: Random, premises: List[Expression], variables: List[str]) -> Expression:
        """Find a conclusion that logically follows from the premises.

        Up to 100 random depth-2 candidates are tried. If none is entailed,
        the fallback must still be a correct answer, so a conclusion that is
        valid by construction is returned instead of an unchecked variable.
        """
        for _ in range(100):
            candidate = self._generate_expression(rng, variables, 2)
            if self._is_valid_conclusion(premises, candidate):
                return candidate

        # Any premise is trivially entailed by the full premise set.
        if premises:
            return premises[0]

        # With no premises only a tautology is valid: (V or not V).
        first = Expression(None, variables[0])
        return Expression(Operator.OR, first, Expression(Operator.NOT, first))

    def _is_valid_conclusion(self, premises: List[Expression], conclusion: Expression) -> bool:
        """Check entailment via truth tables.

        Returns ``False`` as soon as an assignment makes every premise true
        and the conclusion false; ``True`` if no such assignment exists
        (including the case where the premises are never jointly true).
        """
        variables = self._collect_variables(premises + [conclusion])
        for assignment in self._generate_assignments(variables):
            premises_hold = all(p.evaluate(assignment) for p in premises)
            if premises_hold and not conclusion.evaluate(assignment):
                return False
        return True

    def _collect_variables(self, expressions: List[Expression]) -> Set[str]:
        """Collect the names of all variables used in ``expressions``."""
        variables = set()
        for expr in expressions:
            if expr.operator is None:
                # Leaf node: ``left`` holds the variable name.
                variables.add(expr.left)
                continue
            # ``right`` is None for NOT nodes, which the isinstance check skips.
            children = [side for side in (expr.left, expr.right) if isinstance(side, Expression)]
            variables.update(self._collect_variables(children))
        return variables

    def _generate_assignments(self, variables: Set[str]) -> List[dict[str, bool]]:
        """Generate all ``2 ** len(variables)`` truth value assignments.

        Bit ``j`` of the row index gives the value of the ``j``-th variable in
        sorted order.
        """
        ordered = sorted(variables)  # sort once instead of once per row
        return [
            {var: bool((row >> bit) & 1) for bit, var in enumerate(ordered)}
            for row in range(2 ** len(ordered))
        ]

    def _measure_complexity(self, expression: Expression) -> int:
        """Return the number of nodes (variables and operators) in the tree."""
        if expression.operator is None:
            return 1
        if expression.operator == Operator.NOT:
            return 1 + self._measure_complexity(expression.left)
        return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
# Pass the dataset class and its config class to the factory's
# ``register_dataset`` under the name "propositional_logic".
# NOTE(review): presumably this makes the dataset creatable by name - confirm in ..factory.
register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig)