mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-01 17:45:24 +00:00
circuit logic curriculum (#368)
This commit is contained in:
parent
3e0254afea
commit
843a277248
3 changed files with 87 additions and 25 deletions
|
|
@ -2,6 +2,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
VERT = "│"
|
||||
|
|
@ -60,7 +61,8 @@ class CircuitLogicConfig:
|
|||
:param seed: Random seed
|
||||
"""
|
||||
|
||||
num_terms: int = 5
|
||||
min_terms: int = 3
|
||||
max_terms: int = 5
|
||||
min_inputs: int = 2
|
||||
max_inputs: int = 4
|
||||
neg_prob: float = 0.3
|
||||
|
|
@ -70,7 +72,7 @@ class CircuitLogicConfig:
|
|||
|
||||
def validate(self):
    """Check the configuration values; a failed check raises AssertionError."""
    # NOTE(review): this listing is a commit diff rendered without +/- markers.
    # The `num_terms` check and the `min_terms`/`max_terms` check share one
    # message and are presumably the before/after versions of the same line —
    # confirm which one belongs in the current file.

    # Input count bounds must be at least 1 and ordered.
    assert 1 <= self.min_inputs <= self.max_inputs, "Invalid input range"
    # At least one term is required.
    assert 1 <= self.num_terms, "Invalid number of terms"
    # Term count bounds must be at least 1 and ordered.
    assert 1 <= self.min_terms <= self.max_terms, "Invalid number of terms"
    # neg_prob must lie in the closed interval [0, 1].
    assert 0.0 <= self.neg_prob <= 1.0, "neg_prob must be between 0 and 1"
|
||||
|
||||
|
||||
|
|
@ -112,28 +114,15 @@ class CircuitLogicDataset(ProceduralDataset):
|
|||
("AND", "&"),
|
||||
]
|
||||
|
||||
def __len__(self) -> int:
    """Return the dataset length, taken from ``config.size``."""
    size = self.config.size
    return size
|
||||
|
||||
def __iter__(self):
    """Rewind the iteration cursor to 0 and return this object as its own iterator."""
    # The cursor lives on the instance, so starting a new iteration resets it.
    setattr(self, "_current_idx", 0)
    return self
|
||||
|
||||
def __next__(self) -> dict[str, Any]:
    """Return the item at the current cursor and advance the cursor by one.

    Raises:
        StopIteration: once the cursor has reached ``config.size``.
    """
    if self._current_idx < self.config.size:
        # Fetch first, then advance: if indexing raises, the cursor stays put.
        current = self[self._current_idx]
        self._current_idx += 1
        return current
    raise StopIteration
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
"""
|
||||
Generate one random circuit logic item using ASCII drawing.
|
||||
"""
|
||||
rng = Random(self.seed + idx if self.seed is not None else None)
|
||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
return self._generate_circuit(
|
||||
rng=rng,
|
||||
num_terms=self.config.num_terms,
|
||||
num_terms=num_terms,
|
||||
min_inputs=self.config.min_inputs,
|
||||
max_inputs=self.config.max_inputs,
|
||||
neg_prob=self.config.neg_prob,
|
||||
|
|
@ -397,6 +386,10 @@ class CircuitLogicDataset(ProceduralDataset):
|
|||
"term_strings": term_strings,
|
||||
"final_gate": final_gate_name,
|
||||
"inputs": inputs_list,
|
||||
"difficulty": {
|
||||
"terms": num_terms,
|
||||
"inputs": (self.config.min_inputs, self.config.max_inputs),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -411,4 +404,33 @@ class CircuitLogicDataset(ProceduralDataset):
|
|||
return 0.0
|
||||
|
||||
|
||||
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)
|
||||
class CircuitLogicCurriculum(BaseCurriculum):
    """Curriculum for the circuit logic dataset.

    Declares two range attributes against ``CircuitLogicConfig``:

    * ``terms``  -- levels 3, 5, 10, 20, 30; bound to ``min_terms`` / ``max_terms``
    * ``inputs`` -- levels 2, 4, 6, 8, 10; bound to ``min_inputs`` / ``max_inputs``
    """

    def __init__(self):
        super().__init__(CircuitLogicCurriculum.__name__, CircuitLogicConfig)

        # Number of terms in the generated expression.
        terms_attribute = RangeAttributeDefinition(
            name="terms",
            description="Number of terms in the expression",
            levels=[3, 5, 10, 20, 30],
            default_level=1,
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_terms",
            upper_field_name="max_terms",
        )

        # Number of inputs feeding each term.
        inputs_attribute = RangeAttributeDefinition(
            name="inputs",
            description="Number of inputs per term",
            levels=[2, 4, 6, 8, 10],
            default_level=1,
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_inputs",
            upper_field_name="max_inputs",
        )

        # Register in the same order as before: terms first, then inputs.
        self._define_attributes(terms_attribute, inputs_attribute)
|
||||
|
||||
|
||||
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig, CircuitLogicCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue