rotten oranges curriculum (#297)

This commit is contained in:
Zafir Stojanovski 2025-03-08 20:56:46 +01:00 committed by GitHub
parent e4e516a949
commit cbd223d7ca
3 changed files with 54 additions and 8 deletions

View file

@ -29,7 +29,7 @@ from .palindrome_partitioning import PalindromePartitioningConfig, PalindromePar
from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset
from .ransom_note import RansomNoteConfig, RansomNoteDataset
from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset
from .rotten_oranges import RottenOrangesConfig, RottenOrangesDataset
from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
@ -116,6 +116,7 @@ __all__ = [
"StringSynthesisDataset",
"RottenOrangesConfig",
"RottenOrangesDataset",
"RottenOrangesCurriculum",
"JugsConfig",
"JugsDataset",
"BinaryAlternationConfig",

View file

@ -9,6 +9,7 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """You are given an n x n grid where each cell can have one of three values:
@ -40,7 +41,7 @@ class RottenOrangesConfig:
def validate(self):
"""Validate configuration parameters"""
assert 1 <= self.min_n, "min_n must be at least 1"
assert 2 <= self.min_n, "min_n must be at least 2"
assert self.min_n <= self.max_n, "min_n must be less than or equal to max_n"
assert 0 < self.p_oranges <= 1, "p_oranges must be between 0 and 1"
assert 0 < self.p_rotten <= 1, "p_rotten must be between 0 and 1"
@ -56,9 +57,8 @@ class RottenOrangesDataset(ProceduralDataset):
"""Get a string representation of the matrix"""
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
def _get_initial_matrix(self, rng: Random) -> list[list[int]]:
def _get_initial_matrix(self, rng: Random, n: int) -> list[list[int]]:
"""Generate a random matrix with oranges"""
n = rng.randint(self.config.min_n, self.config.max_n)
matrix = [[0] * n for _ in range(n)]
for i in range(n):
for j in range(n):
@ -111,15 +111,39 @@ class RottenOrangesDataset(ProceduralDataset):
"""Generate a single Rotten Oranges question"""
rng = Random(self.seed + idx)
matrix = self._get_initial_matrix(rng)
n = rng.randint(self.config.min_n, self.config.max_n)
matrix = self._get_initial_matrix(rng, n)
matrix_str = self._matrix_to_str(matrix)
answer = self._get_answer(matrix)
return {
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
"answer": str(answer),
"metadata": {"matrix": matrix, "solution": answer},
"metadata": {
"matrix": matrix,
"solution": answer,
"difficulty": {"n": n},
},
}
register_dataset("rotten_oranges", RottenOrangesDataset, RottenOrangesConfig)
class RottenOrangesCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(RottenOrangesCurriculum.__name__, RottenOrangesConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="n",
levels=[10, 25, 50, 100],
default_level=0,
description="Size of the square matrix",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_n",
upper_field_name="max_n",
)
)
register_dataset("rotten_oranges", RottenOrangesDataset, RottenOrangesConfig, RottenOrangesCurriculum)