mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
string splitting curriculum (#307)
This commit is contained in:
parent
83cd34e21b
commit
037905667e
3 changed files with 64 additions and 6 deletions
|
|
@ -7,6 +7,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """There is a dismantling engineer who has old machines A, B, and C.
|
||||
|
|
@ -24,6 +25,7 @@ The output should be the count of each machine and part type after the rules hav
|
|||
For example 1 0 1 5 4 3 means that you have 1 machine A, 0 machine B, 1 machine C, 5 part X, 4 part Y, and 3 part Z.
|
||||
|
||||
Now, you have {A_machine} machine A, {B_machine} machine B, and {C_machine} machine C. Provide the count of each machine and part type after applying the above rules.
|
||||
Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts of each machine and part type.
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -115,10 +117,40 @@ class StringSplittingDataset(ProceduralDataset):
|
|||
answer_str = " ".join(str(x) for x in answer)
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(A_machine=A_machine, B_machine=B_machine, C_machine=C_machine),
|
||||
"question": QUESTION_TEMPLATE.format(
|
||||
A_machine=A_machine,
|
||||
B_machine=B_machine,
|
||||
C_machine=C_machine,
|
||||
max_iterations=self.config.max_iterations,
|
||||
),
|
||||
"answer": answer_str,
|
||||
"metadata": {"states": states, "solution": answer},
|
||||
"metadata": {
|
||||
"states": states,
|
||||
"solution": answer,
|
||||
"difficulty": {
|
||||
"initial_machines": (A_machine, B_machine, C_machine),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig)
|
||||
class StringSplittingCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(StringSplittingCurriculum.__name__, StringSplittingConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="initial_machines",
|
||||
levels=[10, 50, 100, 500],
|
||||
default_level=1,
|
||||
description="Number of initial machines",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=0,
|
||||
lower_field_name="min_initial_machines",
|
||||
upper_field_name="max_initial_machines",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue