BF Curricula and More (#309)

* bf curricula
* modulo grid curricula
* minor changes to how difficulty is stored

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
Rich Jones 2025-03-09 18:22:22 +01:00 committed by GitHub
parent 7c7c783883
commit 46013e4640
5 changed files with 122 additions and 7 deletions

View file

@ -3,6 +3,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -134,9 +135,72 @@ class ModuloGridDataset(ProceduralDataset):
return {
"question": question,
"answer": flatten_grid(grid),
"metadata": {"divisor": divisor, "target": target, "operation": operation},
"metadata": {
"divisor": divisor,
"target": target,
"operation": operation,
"difficulty": {
"holes": self.config.max_holes,
"size_x": self.config.size_x,
"size_y": self.config.size_y,
},
},
}
class ModuloGridCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(ModuloGridCurriculum.__name__, ModuloGridConfig)
# Define attributes
self._define_attributes(
ScalarAttributeDefinition(
name="size_x",
field_name="size_x",
levels=[20, 30, 50, 75],
default_level=0,
description="Size x",
attr_type=AttributeType.STATIC,
min_value=20,
),
ScalarAttributeDefinition(
name="size_y",
field_name="size_y",
levels=[20, 30, 50, 75],
default_level=0,
description="Size y",
attr_type=AttributeType.STATIC,
min_value=20,
),
ScalarAttributeDefinition(
name="max_holes",
field_name="max_holes",
levels=[1, 2, 3, 5],
default_level=0,
description="Max holes",
attr_type=AttributeType.STATIC,
min_value=1,
),
ScalarAttributeDefinition(
name="max_divisor",
field_name="max_divisor",
levels=[9, 10, 11, 48],
default_level=0,
description="Max divisor",
attr_type=AttributeType.STATIC,
min_value=1,
),
ScalarAttributeDefinition(
name="max_target",
field_name="max_target",
levels=[7, 14, 21, 49],
default_level=0,
description="Max target",
attr_type=AttributeType.STATIC,
min_value=1,
),
)
# Register the dataset
register_dataset("modulo_grid", ModuloGridDataset, ModuloGridConfig)
register_dataset("modulo_grid", ModuloGridDataset, ModuloGridConfig, ModuloGridCurriculum)