Add jugs curriculum (#369)

This commit is contained in:
Adefioye 2025-03-14 12:04:33 -05:00 committed by GitHub
parent 8c12fe86e2
commit 8a0cacc054
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 87 additions and 4 deletions

View file

@ -6,6 +6,7 @@ from functools import reduce
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -280,7 +281,13 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
return {
"question": question,
"answer": json.dumps(solution), # one possible solution
"metadata": {"puzzle": puzzle},
"metadata": {
"puzzle": puzzle,
"difficulty": {
"num_jugs": self.config.num_jugs,
"difficulty": self.config.difficulty,
},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -310,4 +317,33 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
return 0.0
register_dataset("jugs", JugsDataset, JugsConfig)
class JugsCurriculum(BaseCurriculum):
"""Curriculum for Jugs puzzles"""
def __init__(self):
super().__init__(JugsCurriculum.__name__, JugsConfig)
# Define attributes
self._define_attributes(
ScalarAttributeDefinition(
name="num_jugs",
field_name="num_jugs",
levels=[3, 4, 5, 7],
default_level=0,
description="Number of jugs in the puzzle",
attr_type=AttributeType.STATIC,
min_value=3,
),
ScalarAttributeDefinition(
name="difficulty",
field_name="difficulty",
levels=[2, 4, 6, 8],
default_level=0,
description="Minimum required moves to solve the puzzle",
attr_type=AttributeType.STATIC,
min_value=10,
),
)
register_dataset("jugs", JugsDataset, JugsConfig, JugsCurriculum)