feat(env): String Synthesis Curriculum (#308)

* string synthesis curriculum

* difficulty metadata
This commit is contained in:
Zafir Stojanovski 2025-03-10 00:27:03 +01:00 committed by GitHub
parent 037905667e
commit a1dc28aa73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 63 additions and 5 deletions

View file

@ -7,6 +7,7 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """There are nine different blocks [A] [B] [C] {{A}} {{B}} {{C}} (A) (B) (C)
@ -24,6 +25,7 @@ The output should be the count of each block type after the rules have been appl
For example 1 0 3 0 2 0 0 0 1 means that you have 1 [A] 0 [B] 3 [C] 0 {{A}} 2 {{B}} 0 {{C}} 0 (A) 0 (B) 1 (C).
Now, you have {A_square} [A], {B_square} [B], and {C_square} [C] blocks. Provide the count of each block type after applying the above rules.
Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts.
"""
@ -120,10 +122,40 @@ class StringSynthesisDataset(ProceduralDataset):
answer_str = " ".join(str(x) for x in answer)
return {
"question": QUESTION_TEMPLATE.format(A_square=A_square, B_square=B_square, C_square=C_square),
"question": QUESTION_TEMPLATE.format(
A_square=A_square,
B_square=B_square,
C_square=C_square,
max_iterations=self.config.max_iterations,
),
"answer": answer_str,
"metadata": {"states": states, "solution": answer},
"metadata": {
"states": states,
"solution": answer,
"difficulty": {
"initial_blocks": (A_square, B_square, C_square),
},
},
}
register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig)
class StringSynthesisCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(StringSynthesisCurriculum.__name__, StringSynthesisConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="initial_blocks",
levels=[10, 50, 100, 500],
default_level=1,
description="Number of initial blocks",
attr_type=AttributeType.APPEND,
min_value=0,
lower_field_name="min_initial_blocks",
upper_field_name="max_initial_blocks",
)
)
register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum)