string splitting curriculum (#307)

This commit is contained in:
Zafir Stojanovski 2025-03-10 00:25:56 +01:00 committed by GitHub
parent 83cd34e21b
commit 037905667e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 64 additions and 6 deletions

View file

@ -7,6 +7,7 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """There is a dismantling engineer who has old machines A, B, and C.
@ -24,6 +25,7 @@ The output should be the count of each machine and part type after the rules hav
For example 1 0 1 5 4 3 means that you have 1 machine A, 0 machine B, 1 machine C, 5 part X, 4 part Y, and 3 part Z.
Now, you have {A_machine} machine A, {B_machine} machine B, and {C_machine} machine C. Provide the count of each machine and part type after applying the above rules.
Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts of each machine and part type.
"""
@ -115,10 +117,40 @@ class StringSplittingDataset(ProceduralDataset):
answer_str = " ".join(str(x) for x in answer)
return {
"question": QUESTION_TEMPLATE.format(A_machine=A_machine, B_machine=B_machine, C_machine=C_machine),
"question": QUESTION_TEMPLATE.format(
A_machine=A_machine,
B_machine=B_machine,
C_machine=C_machine,
max_iterations=self.config.max_iterations,
),
"answer": answer_str,
"metadata": {"states": states, "solution": answer},
"metadata": {
"states": states,
"solution": answer,
"difficulty": {
"initial_machines": (A_machine, B_machine, C_machine),
},
},
}
register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig)
class StringSplittingCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(StringSplittingCurriculum.__name__, StringSplittingConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="initial_machines",
levels=[10, 50, 100, 500],
default_level=1,
description="Number of initial machines",
attr_type=AttributeType.APPEND,
min_value=0,
lower_field_name="min_initial_machines",
upper_field_name="max_initial_machines",
)
)
register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum)