mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
reimplemented prop logic
This commit is contained in:
parent
ed10c5f9bc
commit
39ee099a86
1 changed files with 143 additions and 11 deletions
|
|
@ -1,13 +1,57 @@
|
|||
"""Propositional logic task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from random import Random
|
||||
from typing import Any, List, Optional, Set
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
def parse_expr(expr: str):
|
||||
expr = expr.strip()
|
||||
if not expr:
|
||||
raise ValueError("Empty expression")
|
||||
|
||||
if expr[0] == "(" and expr[-1] == ")":
|
||||
level = 0
|
||||
valid_enclosure = True
|
||||
for char in expr[1:-1]:
|
||||
if char == "(":
|
||||
level += 1
|
||||
elif char == ")":
|
||||
level -= 1
|
||||
if level < 0:
|
||||
valid_enclosure = False
|
||||
break
|
||||
if level == 0 and valid_enclosure:
|
||||
return parse_expr(expr[1:-1])
|
||||
|
||||
operators_by_precedence = [[Operator.IFF], [Operator.IMPLIES], [Operator.OR], [Operator.AND]]
|
||||
|
||||
for operator_level in operators_by_precedence:
|
||||
level = 0
|
||||
for i in range(len(expr) - 1, -1, -1):
|
||||
char = expr[i]
|
||||
if char == ")":
|
||||
level += 1
|
||||
elif char == "(":
|
||||
level -= 1
|
||||
elif level == 0:
|
||||
for operator in operator_level:
|
||||
if expr[i : i + len(operator.value)] == operator.value:
|
||||
left_expr = expr[:i]
|
||||
right_expr = expr[i + len(operator.value) :]
|
||||
return Expression(operator, parse_expr(left_expr), parse_expr(right_expr))
|
||||
|
||||
if expr.startswith(Operator.NOT.value):
|
||||
sub_expr = expr[len(Operator.NOT.value) :]
|
||||
return Expression(Operator.NOT, parse_expr(sub_expr))
|
||||
|
||||
return Expression(None, expr)
|
||||
|
||||
|
||||
class Operator(StrEnum):
|
||||
"""Basic logical operators"""
|
||||
|
||||
|
|
@ -18,6 +62,24 @@ class Operator(StrEnum):
|
|||
IFF = "↔"
|
||||
|
||||
|
||||
QUESTION_FORMAT = "\n".join(
|
||||
[
|
||||
"The following question is a propositional logic reasoning question.",
|
||||
"In the question we provide a list of premises",
|
||||
"The task is to infer a correct conclusion from the premise.",
|
||||
"FORMAT INSTRUCTIONS:",
|
||||
"Return the conclusion logic statement, as your final answer.",
|
||||
"Use the following notation to denote symbols",
|
||||
"OR = \u2228",
|
||||
"AND = \u2227",
|
||||
"IMPLIES = \u2192",
|
||||
"IFF = \u2194",
|
||||
"NOT = \u00ac",
|
||||
"Here is the question:",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PropositionalLogicConfig:
|
||||
"""Configuration for propositional logic task generation"""
|
||||
|
|
@ -63,6 +125,42 @@ class Expression:
|
|||
return self.left.evaluate(assignments) == self.right.evaluate(assignments)
|
||||
raise ValueError(f"Unknown operator: {self.operator}")
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, expr: str) -> "Expression":
|
||||
parsed_expr = parse_expr(expr)
|
||||
return cls(parsed_expr.operator, parsed_expr.left, parsed_expr.right)
|
||||
|
||||
def simplify(self):
|
||||
if self.operator is None:
|
||||
return self
|
||||
|
||||
simplified_left = self.left.simplify() if isinstance(self.left, Expression) else self.left
|
||||
simplified_right = self.right.simplify() if self.right and isinstance(self.right, Expression) else self.right
|
||||
|
||||
if self.operator == Operator.NOT:
|
||||
if isinstance(simplified_left, Expression) and simplified_left.operator == Operator.NOT:
|
||||
return simplified_left.left
|
||||
return Expression(Operator.NOT, simplified_left)
|
||||
|
||||
if self.operator in {Operator.AND, Operator.OR}:
|
||||
if simplified_left is False and self.operator == Operator.OR:
|
||||
return simplified_right
|
||||
if simplified_left is True and self.operator == Operator.AND:
|
||||
return simplified_right
|
||||
|
||||
if (simplified_left is True and self.operator == Operator.OR) or (
|
||||
simplified_left is False and self.operator == Operator.AND
|
||||
):
|
||||
return simplified_left
|
||||
|
||||
if simplified_left == simplified_right:
|
||||
return simplified_left
|
||||
|
||||
if self.operator == Operator.IMPLIES:
|
||||
return Expression(Operator.OR, Expression(Operator.NOT, simplified_left), simplified_right).simplify()
|
||||
|
||||
return Expression(self.operator, simplified_left, simplified_right)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.operator is None:
|
||||
return self.left
|
||||
|
|
@ -103,23 +201,23 @@ class PropositionalLogicDataset(ProceduralDataset):
|
|||
# Generate premises
|
||||
num_statements = rng.randint(self.config.min_statements, self.config.max_statements)
|
||||
premises = self._generate_premises(rng, variables, num_statements)
|
||||
|
||||
# Generate a valid conclusion
|
||||
conclusion = self._find_valid_conclusion(rng, premises, variables)
|
||||
|
||||
# Format question
|
||||
question = "Given:\n"
|
||||
question = QUESTION_FORMAT
|
||||
question += "Given:\n"
|
||||
for i, premise in enumerate(premises, 1):
|
||||
question += f"{i}. {premise}\n"
|
||||
question += "What can we conclude?"
|
||||
question += f"{i}. {premise}\n."
|
||||
question += "What can we conclude from the above statements?"
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": str(conclusion),
|
||||
"answer": None,
|
||||
"metadata": {
|
||||
"premises": [str(p) for p in premises],
|
||||
"variables": variables,
|
||||
"complexity": self._measure_complexity(conclusion),
|
||||
"example_answer": str(conclusion),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -135,7 +233,6 @@ class PropositionalLogicDataset(ProceduralDataset):
|
|||
"""Generate a random logical expression"""
|
||||
if depth <= 1:
|
||||
return Expression(None, rng.choice(variables))
|
||||
|
||||
operator = rng.choice(list(Operator))
|
||||
if operator == Operator.NOT:
|
||||
return Expression(operator, self._generate_expression(rng, variables, depth - 1))
|
||||
|
|
@ -146,10 +243,9 @@ class PropositionalLogicDataset(ProceduralDataset):
|
|||
|
||||
def _find_valid_conclusion(self, rng: Random, premises: List[Expression], variables: List[str]) -> Expression:
|
||||
"""Find a valid conclusion that follows from the premises"""
|
||||
# Try random conclusions until we find a valid one
|
||||
for _ in range(100):
|
||||
candidate = self._generate_expression(rng, variables, 2)
|
||||
if self._is_valid_conclusion(premises, candidate):
|
||||
candidate = self._generate_expression(rng, variables, 2).simplify()
|
||||
if self._is_valid_conclusion(premises, candidate) and not (self._is_trivial(candidate)):
|
||||
return candidate
|
||||
|
||||
# Fallback to a simple conclusion
|
||||
|
|
@ -198,5 +294,41 @@ class PropositionalLogicDataset(ProceduralDataset):
|
|||
else:
|
||||
return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
|
||||
|
||||
def score_answer(self, answer: str | None, entry: Dict[str, Any]) -> float:
|
||||
"""Robust scoring implementation for propositional logic answers"""
|
||||
if not answer:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
cleaned_answer = answer
|
||||
|
||||
valid_vars = set(entry["metadata"]["variables"])
|
||||
answer_vars = re.findall(r"([A-Z])", cleaned_answer)
|
||||
if any(var not in valid_vars for var in answer_vars):
|
||||
return 0.01
|
||||
|
||||
premises = [Expression.from_string(p) for p in entry["metadata"]["premises"]]
|
||||
answer_expr = Expression.from_string(cleaned_answer)
|
||||
|
||||
if self._is_valid_conclusion(premises, answer_expr):
|
||||
return 1.0
|
||||
|
||||
elif self._is_trivial(answer_expr):
|
||||
return 0.25
|
||||
|
||||
return 0.05
|
||||
except (ValueError, KeyError, AttributeError):
|
||||
return 0.01
|
||||
|
||||
def _is_trivial(self, expr: Expression) -> bool:
|
||||
"""Check for trivial tautologies like P ∨ ¬P"""
|
||||
if expr.operator is None:
|
||||
return True
|
||||
variables = self._collect_variables([expr])
|
||||
for assignment in self._generate_assignments(variables):
|
||||
if not expr.evaluate(assignment):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue