mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat(env): String Synthesis Curriculum (#308)
* string synthesis curriculum * difficulty metadata
This commit is contained in:
parent
037905667e
commit
a1dc28aa73
3 changed files with 63 additions and 5 deletions
|
|
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """There are nine different blocks [A] [B] [C] {{A}} {{B}} {{C}} (A) (B) (C)
|
||||
|
|
@@ -24,6 +25,7 @@ The output should be the count of each block type after the rules have been appl
|
|||
For example 1 0 3 0 2 0 0 0 1 means that you have 1 [A] 0 [B] 3 [C] 0 {{A}} 2 {{B}} 0 {{C}} 0 (A) 0 (B) 1 (C).
|
||||
|
||||
Now, you have {A_square} [A], {B_square} [B], and {C_square} [C] blocks. Provide the count of each block type after applying the above rules.
|
||||
Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts.
|
||||
"""
|
||||
|
||||
|
||||
|
|
@@ -120,10 +122,40 @@ class StringSynthesisDataset(ProceduralDataset):
|
|||
answer_str = " ".join(str(x) for x in answer)
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(A_square=A_square, B_square=B_square, C_square=C_square),
|
||||
"question": QUESTION_TEMPLATE.format(
|
||||
A_square=A_square,
|
||||
B_square=B_square,
|
||||
C_square=C_square,
|
||||
max_iterations=self.config.max_iterations,
|
||||
),
|
||||
"answer": answer_str,
|
||||
"metadata": {"states": states, "solution": answer},
|
||||
"metadata": {
|
||||
"states": states,
|
||||
"solution": answer,
|
||||
"difficulty": {
|
||||
"initial_blocks": (A_square, B_square, C_square),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig)
|
||||
class StringSynthesisCurriculum(BaseCurriculum):
    """Curriculum for the string synthesis dataset.

    Declares a single range attribute, ``initial_blocks``, whose lower and
    upper bounds are tied to the ``min_initial_blocks`` and
    ``max_initial_blocks`` field names.
    """

    def __init__(self):
        """Initialise the base curriculum and register ``initial_blocks``."""
        super().__init__(StringSynthesisCurriculum.__name__, StringSynthesisConfig)

        # Describe the one tunable attribute first, then hand it to the base class.
        initial_blocks = RangeAttributeDefinition(
            name="initial_blocks",
            levels=[10, 50, 100, 500],
            default_level=1,
            description="Number of initial blocks",
            attr_type=AttributeType.APPEND,
            min_value=0,
            lower_field_name="min_initial_blocks",
            upper_field_name="max_initial_blocks",
        )
        self._define_attributes(initial_blocks)
|
||||
|
||||
|
||||
# Register the dataset under the name "string_synthesis", passing its dataset, config and curriculum classes.
register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue