reimplemented prop logic

This commit is contained in:
joesharratt1229 2025-02-20 23:59:31 +00:00
parent ed10c5f9bc
commit 39ee099a86

View file

@ -1,13 +1,57 @@
"""Propositional logic task generator"""
import re
from dataclasses import dataclass
from enum import StrEnum
from random import Random
from typing import Any, List, Optional, Set
from typing import Any, Dict, List, Optional, Set
from ..factory import ProceduralDataset, register_dataset
def parse_expr(expr: str):
    """Parse a propositional-logic string into an ``Expression`` tree.

    The parser is recursive:

    1. Surrounding whitespace is dropped, and a pair of parentheses that
       encloses the *whole* string is peeled off.
    2. The string is split at the right-most binary operator found outside
       any parentheses, trying the loosest-binding operator first (IFF, then
       IMPLIES, then OR, then AND).  Splitting at the right-most occurrence
       makes a chain such as ``A op B op C`` group as ``(A op B) op C``.
    3. A leading NOT is applied to the remainder of the string.
    4. Whatever is left is treated as an atom (a variable name).

    Args:
        expr: Expression text written with the symbols defined by ``Operator``.

    Returns:
        The root ``Expression`` node of the parsed tree.

    Raises:
        ValueError: If ``expr`` or one of its operands is empty, or if an atom
            still contains a parenthesis or a NOT symbol.  The latter only
            happens for malformed input (e.g. unbalanced parentheses), which
            would otherwise be silently accepted as a bogus variable name.
    """
    expr = expr.strip()
    if not expr:
        raise ValueError("Empty expression")

    # Strip one outer pair of parentheses, but only when the opening
    # parenthesis at index 0 really matches the closing one at the end
    # (so "(A) op (B)" is left alone).
    if expr[0] == "(" and expr[-1] == ")":
        depth = 0
        for char in expr[1:-1]:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth < 0:
                    # The leading "(" was closed before the end of the string.
                    break
        if depth == 0:
            return parse_expr(expr[1:-1])

    # Loosest-binding operator first, so it ends up nearest the root.
    for operator in (Operator.IFF, Operator.IMPLIES, Operator.OR, Operator.AND):
        symbol = operator.value
        depth = 0
        # Scan right-to-left; because of the direction, ")" opens a nested
        # region and "(" closes it.
        for i in range(len(expr) - 1, -1, -1):
            char = expr[i]
            if char == ")":
                depth += 1
            elif char == "(":
                depth -= 1
            elif depth == 0 and expr.startswith(symbol, i):
                return Expression(
                    operator,
                    parse_expr(expr[:i]),
                    parse_expr(expr[i + len(symbol) :]),
                )

    not_symbol = Operator.NOT.value
    if expr.startswith(not_symbol):
        return Expression(Operator.NOT, parse_expr(expr[len(not_symbol) :]))

    # A well-formed expression never reaches this point with structural
    # characters left over; if it does, the input was malformed.
    if "(" in expr or ")" in expr or not_symbol in expr:
        raise ValueError(f"Malformed expression: {expr!r}")
    return Expression(None, expr)
class Operator(StrEnum):
"""Basic logical operators"""
@ -18,6 +62,24 @@ class Operator(StrEnum):
IFF = ""
# Fixed preamble placed in front of every generated question: it states the
# task and lists the operator symbols the answer is expected to use.
# The lines are separated by "\n"; there is no trailing newline.
QUESTION_FORMAT = (
    "The following question is a propositional logic reasoning question.\n"
    "In the question we provide a list of premises\n"
    "The task is to infer a correct conclusion from the premise.\n"
    "FORMAT INSTRUCTIONS:\n"
    "Return the conclusion logic statement, as your final answer.\n"
    "Use the following notation to denote symbols\n"
    "OR = \u2228\n"
    "AND = \u2227\n"
    "IMPLIES = \u2192\n"
    "IFF = \u2194\n"
    "NOT = \u00ac\n"
    "Here is the question:"
)
@dataclass
class PropositionalLogicConfig:
"""Configuration for propositional logic task generation"""
@ -63,6 +125,42 @@ class Expression:
return self.left.evaluate(assignments) == self.right.evaluate(assignments)
raise ValueError(f"Unknown operator: {self.operator}")
@classmethod
def from_string(cls, expr: str) -> "Expression":
    """Alternate constructor: build an expression from its textual form.

    The text is handed to ``parse_expr``; the operator and operands of the
    resulting root node are then used to construct a ``cls`` instance.

    Args:
        expr: Expression text to parse.

    Returns:
        A ``cls`` instance with the same operator, left and right operands as
        the parsed root node.
    """
    root = parse_expr(expr)
    operator, left, right = root.operator, root.left, root.right
    return cls(operator, left, right)
def simplify(self):
    """Return a simplified version of this expression.

    Operands that are ``Expression`` instances are simplified first, then
    one of the following rewrites is applied to this node:

    * atom (``operator is None``): returned unchanged;
    * NOT of a NOT node: the inner operand is returned (double negation);
    * AND / OR whose two simplified operands compare equal: that operand;
    * IMPLIES: rewritten as ``(NOT left) OR right`` and simplified again;
    * anything else: a new node with the simplified operands.

    AND / OR also contain identity-law checks against the ``True`` /
    ``False`` constants on the left operand.
    NOTE(review): operands built by the parser are ``Expression`` nodes, so
    these boolean branches look unreachable in practice -- confirm.
    """
    op = self.operator
    if op is None:
        return self

    lhs = self.left
    if isinstance(lhs, Expression):
        lhs = lhs.simplify()
    rhs = self.right
    if rhs and isinstance(rhs, Expression):
        rhs = rhs.simplify()

    if op == Operator.NOT:
        if isinstance(lhs, Expression) and lhs.operator == Operator.NOT:
            # Double negation: drop both NOTs.
            return lhs.left
        return Expression(Operator.NOT, lhs)

    if op == Operator.IMPLIES:
        # Material implication, then simplify the rewritten form.
        rewritten = Expression(Operator.OR, Expression(Operator.NOT, lhs), rhs)
        return rewritten.simplify()

    if op in {Operator.AND, Operator.OR}:
        if lhs is True or lhs is False:
            # Neutral element: False for OR, True for AND.  A neutral left
            # operand yields the right operand; otherwise the left operand
            # (the absorbing constant) decides the result.
            neutral = not (op == Operator.OR)
            return rhs if lhs is neutral else lhs
        if lhs == rhs:
            return lhs

    return Expression(op, lhs, rhs)
def __str__(self) -> str:
if self.operator is None:
return self.left
@ -103,23 +201,23 @@ class PropositionalLogicDataset(ProceduralDataset):
# Generate premises
num_statements = rng.randint(self.config.min_statements, self.config.max_statements)
premises = self._generate_premises(rng, variables, num_statements)
# Generate a valid conclusion
conclusion = self._find_valid_conclusion(rng, premises, variables)
# Format question
question = "Given:\n"
question = QUESTION_FORMAT
question += "Given:\n"
for i, premise in enumerate(premises, 1):
question += f"{i}. {premise}\n"
question += "What can we conclude?"
question += f"{i}. {premise}\n."
question += "What can we conclude from the above statements?"
return {
"question": question,
"answer": str(conclusion),
"answer": None,
"metadata": {
"premises": [str(p) for p in premises],
"variables": variables,
"complexity": self._measure_complexity(conclusion),
"example_answer": str(conclusion),
},
}
@ -135,7 +233,6 @@ class PropositionalLogicDataset(ProceduralDataset):
"""Generate a random logical expression"""
if depth <= 1:
return Expression(None, rng.choice(variables))
operator = rng.choice(list(Operator))
if operator == Operator.NOT:
return Expression(operator, self._generate_expression(rng, variables, depth - 1))
@ -146,10 +243,9 @@ class PropositionalLogicDataset(ProceduralDataset):
def _find_valid_conclusion(self, rng: Random, premises: List[Expression], variables: List[str]) -> Expression:
"""Find a valid conclusion that follows from the premises"""
# Try random conclusions until we find a valid one
for _ in range(100):
candidate = self._generate_expression(rng, variables, 2)
if self._is_valid_conclusion(premises, candidate):
candidate = self._generate_expression(rng, variables, 2).simplify()
if self._is_valid_conclusion(premises, candidate) and not (self._is_trivial(candidate)):
return candidate
# Fallback to a simple conclusion
@ -198,5 +294,41 @@ class PropositionalLogicDataset(ProceduralDataset):
else:
return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
def score_answer(self, answer: str | None, entry: Dict[str, Any]) -> float:
"""Robust scoring implementation for propositional logic answers"""
if not answer:
return 0.0
try:
cleaned_answer = answer
valid_vars = set(entry["metadata"]["variables"])
answer_vars = re.findall(r"([A-Z])", cleaned_answer)
if any(var not in valid_vars for var in answer_vars):
return 0.01
premises = [Expression.from_string(p) for p in entry["metadata"]["premises"]]
answer_expr = Expression.from_string(cleaned_answer)
if self._is_valid_conclusion(premises, answer_expr):
return 1.0
elif self._is_trivial(answer_expr):
return 0.25
return 0.05
except (ValueError, KeyError, AttributeError):
return 0.01
def _is_trivial(self, expr: Expression) -> bool:
    """Tell whether ``expr`` is a trivial conclusion.

    An expression is trivial when it is a bare atom (``operator is None``)
    or a tautology such as ``P ∨ ¬P``, i.e. it evaluates to a truthy value
    under every assignment of its own variables.
    """
    if expr.operator is None:
        return True

    names = self._collect_variables([expr])
    # all() stops at the first falsifying assignment.
    return all(expr.evaluate(row) for row in self._generate_assignments(names))
# Register the dataset class together with its config class under the name
# "propositional_logic", using register_dataset imported from the factory module.
register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig)