diff --git a/reasoning_gym/logic/propositional_logic.py b/reasoning_gym/logic/propositional_logic.py index 395c919f..b289f142 100644 --- a/reasoning_gym/logic/propositional_logic.py +++ b/reasoning_gym/logic/propositional_logic.py @@ -1,13 +1,57 @@ """Propositional logic task generator""" +import re from dataclasses import dataclass from enum import StrEnum from random import Random -from typing import Any, List, Optional, Set +from typing import Any, Dict, List, Optional, Set from ..factory import ProceduralDataset, register_dataset +def parse_expr(expr: str): + expr = expr.strip() + if not expr: + raise ValueError("Empty expression") + + if expr[0] == "(" and expr[-1] == ")": + level = 0 + valid_enclosure = True + for char in expr[1:-1]: + if char == "(": + level += 1 + elif char == ")": + level -= 1 + if level < 0: + valid_enclosure = False + break + if level == 0 and valid_enclosure: + return parse_expr(expr[1:-1]) + + operators_by_precedence = [[Operator.IFF], [Operator.IMPLIES], [Operator.OR], [Operator.AND]] + + for operator_level in operators_by_precedence: + level = 0 + for i in range(len(expr) - 1, -1, -1): + char = expr[i] + if char == ")": + level += 1 + elif char == "(": + level -= 1 + elif level == 0: + for operator in operator_level: + if expr[i : i + len(operator.value)] == operator.value: + left_expr = expr[:i] + right_expr = expr[i + len(operator.value) :] + return Expression(operator, parse_expr(left_expr), parse_expr(right_expr)) + + if expr.startswith(Operator.NOT.value): + sub_expr = expr[len(Operator.NOT.value) :] + return Expression(Operator.NOT, parse_expr(sub_expr)) + + return Expression(None, expr) + + class Operator(StrEnum): """Basic logical operators""" @@ -18,6 +62,24 @@ class Operator(StrEnum): IFF = "↔" +QUESTION_FORMAT = "\n".join( + [ + "The following question is a propositional logic reasoning question.", + "In the question we provide a list of premises", + "The task is to infer a correct conclusion from the premise.", + "FORMAT INSTRUCTIONS:", + "Return the conclusion logic statement, as your final answer.", + "Use the following notation to denote symbols", + "OR = \u2228", + "AND = \u2227", + "IMPLIES = \u2192", + "IFF = \u2194", + "NOT = \u00ac", + "Here is the question:", + ] +) + + @dataclass class PropositionalLogicConfig: """Configuration for propositional logic task generation""" @@ -63,6 +125,42 @@ class Expression: return self.left.evaluate(assignments) == self.right.evaluate(assignments) raise ValueError(f"Unknown operator: {self.operator}") + @classmethod + def from_string(cls, expr: str) -> "Expression": + parsed_expr = parse_expr(expr) + return cls(parsed_expr.operator, parsed_expr.left, parsed_expr.right) + + def simplify(self): + if self.operator is None: + return self + + simplified_left = self.left.simplify() if isinstance(self.left, Expression) else self.left + simplified_right = self.right.simplify() if self.right and isinstance(self.right, Expression) else self.right + + if self.operator == Operator.NOT: + if isinstance(simplified_left, Expression) and simplified_left.operator == Operator.NOT: + return simplified_left.left + return Expression(Operator.NOT, simplified_left) + + if self.operator in {Operator.AND, Operator.OR}: + if simplified_left is False and self.operator == Operator.OR: + return simplified_right + if simplified_left is True and self.operator == Operator.AND: + return simplified_right + + if (simplified_left is True and self.operator == Operator.OR) or ( + simplified_left is False and self.operator == Operator.AND + ): + return simplified_left + + if simplified_left == simplified_right: + return simplified_left + + if self.operator == Operator.IMPLIES: + return Expression(Operator.OR, Expression(Operator.NOT, simplified_left), simplified_right).simplify() + + return Expression(self.operator, simplified_left, simplified_right) + def __str__(self) -> str: if self.operator is None: return self.left @@ -103,23 +201,23 @@ class PropositionalLogicDataset(ProceduralDataset): # Generate premises num_statements = rng.randint(self.config.min_statements, self.config.max_statements) premises = self._generate_premises(rng, variables, num_statements) - - # Generate a valid conclusion conclusion = self._find_valid_conclusion(rng, premises, variables) # Format question - question = "Given:\n" + question = QUESTION_FORMAT + question += "Given:\n" for i, premise in enumerate(premises, 1): - question += f"{i}. {premise}\n" - question += "What can we conclude?" + question += f"{i}. {premise}\n." + question += "What can we conclude from the above statements?" return { "question": question, - "answer": str(conclusion), + "answer": None, "metadata": { "premises": [str(p) for p in premises], "variables": variables, "complexity": self._measure_complexity(conclusion), + "example_answer": str(conclusion), }, } @@ -135,7 +233,6 @@ class PropositionalLogicDataset(ProceduralDataset): """Generate a random logical expression""" if depth <= 1: return Expression(None, rng.choice(variables)) - operator = rng.choice(list(Operator)) if operator == Operator.NOT: return Expression(operator, self._generate_expression(rng, variables, depth - 1)) @@ -146,10 +243,9 @@ class PropositionalLogicDataset(ProceduralDataset): def _find_valid_conclusion(self, rng: Random, premises: List[Expression], variables: List[str]) -> Expression: """Find a valid conclusion that follows from the premises""" - # Try random conclusions until we find a valid one for _ in range(100): - candidate = self._generate_expression(rng, variables, 2) - if self._is_valid_conclusion(premises, candidate): + candidate = self._generate_expression(rng, variables, 2).simplify() + if self._is_valid_conclusion(premises, candidate) and not (self._is_trivial(candidate)): return candidate # Fallback to a simple conclusion @@ -198,5 +294,41 @@ class PropositionalLogicDataset(ProceduralDataset): else: return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right) + def score_answer(self, answer: str | None, entry: Dict[str, Any]) -> float: + """Robust scoring implementation for propositional logic answers""" + if not answer: + return 0.0 + + try: + cleaned_answer = answer + + valid_vars = set(entry["metadata"]["variables"]) + answer_vars = re.findall(r"([A-Z])", cleaned_answer) + if any(var not in valid_vars for var in answer_vars): + return 0.01 + + premises = [Expression.from_string(p) for p in entry["metadata"]["premises"]] + answer_expr = Expression.from_string(cleaned_answer) + + if self._is_valid_conclusion(premises, answer_expr): + return 1.0 + + elif self._is_trivial(answer_expr): + return 0.25 + + return 0.05 + except (ValueError, KeyError, AttributeError): + return 0.01 + + def _is_trivial(self, expr: Expression) -> bool: + """Check for trivial tautologies like P ∨ ¬P""" + if expr.operator is None: + return True + variables = self._collect_variables([expr]) + for assignment in self._generate_assignments(variables): + if not expr.evaluate(assignment): + return False + return True + register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig)