feat(env): Propositional Logic Curriculum (#365)

* propositional logic curriculum

* lint

* difficulty meta
This commit is contained in:
Zafir Stojanovski 2025-03-14 16:12:39 +01:00 committed by GitHub
parent 1b6b13566a
commit 75dfd8ffed
3 changed files with 85 additions and 5 deletions

View file

@ -6,6 +6,7 @@ from enum import StrEnum
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -87,6 +88,7 @@ class PropositionalLogicConfig:
max_vars: int = 4 # Maximum number of variables
min_statements: int = 2 # Minimum number of given statements
max_statements: int = 4 # Maximum number of statements
min_complexity: int = 1 # Minimum operator depth
max_complexity: int = 3 # Maximum operator depth
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
@ -96,8 +98,9 @@ class PropositionalLogicConfig:
assert self.min_vars > 0, "min_vars must be positive"
assert self.max_vars >= self.min_vars, "max_vars must be >= min_vars"
assert self.min_statements > 0, "min_statements must be positive"
assert self.max_statements >= self.min_statements
assert self.max_complexity > 0, "max_complexity must be positive"
assert self.max_statements >= self.min_statements, "max_statements must be >= min_statements"
assert self.min_complexity > 0, "min_complexity must be positive"
assert self.max_complexity >= self.min_complexity, "max_complexity must be >= min_complexity"
class Expression:
@ -217,6 +220,11 @@ class PropositionalLogicDataset(ProceduralDataset):
"variables": variables,
"complexity": self._measure_complexity(conclusion),
"example_answer": str(conclusion),
"difficulty": {
"vars": num_vars,
"statements": num_statements,
"complexity": (self.config.min_complexity, self.config.max_complexity),
},
},
}
@ -224,7 +232,7 @@ class PropositionalLogicDataset(ProceduralDataset):
"""Generate a list of premise statements"""
premises = []
for _ in range(num_statements):
depth = rng.randint(1, self.config.max_complexity)
depth = rng.randint(self.config.min_complexity, self.config.max_complexity)
premises.append(self._generate_expression(rng, variables, depth))
return premises
@ -329,4 +337,45 @@ class PropositionalLogicDataset(ProceduralDataset):
return True
register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig)
class PropositionalLogicCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(PropositionalLogicCurriculum.__name__, PropositionalLogicConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="vars",
levels=[2, 4, 6, 8, 10],
default_level=0,
description="Number of variables in the logical expressions",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_vars",
upper_field_name="max_vars",
),
RangeAttributeDefinition(
name="statements",
levels=[2, 4, 6, 8, 10],
default_level=0,
description="Number of premises in the logical expressions",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_statements",
upper_field_name="max_statements",
),
RangeAttributeDefinition(
name="complexity",
levels=[1, 2, 3, 4, 5],
default_level=0,
description="Complexity of the logical expressions",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_complexity",
upper_field_name="max_complexity",
),
)
register_dataset(
"propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig, PropositionalLogicCurriculum
)