mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
number filtering curriculum (#333)
This commit is contained in:
parent
f14662e213
commit
aa6ccf1946
3 changed files with 81 additions and 3 deletions
|
|
@ -22,7 +22,7 @@ from .jugs import JugsConfig, JugsDataset
|
|||
from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset
|
||||
from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset
|
||||
from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset
|
||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||
from .number_filtering import NumberFilteringConfig, NumberFilteringCurriculum, NumberFilteringDataset
|
||||
from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset
|
||||
from .palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset
|
||||
from .palindrome_partitioning import (
|
||||
|
|
@ -73,6 +73,7 @@ __all__ = [
|
|||
"LetterJumbleCurriculum",
|
||||
"NumberFilteringConfig",
|
||||
"NumberFilteringDataset",
|
||||
"NumberFilteringCurriculum",
|
||||
"NumberSortingConfig",
|
||||
"NumberSortingDataset",
|
||||
"NumberSortingCurriculum",
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -94,8 +95,52 @@ class NumberFilteringDataset(ProceduralDataset):
|
|||
"filter_value": filter_str,
|
||||
"operation": f"{keep_remove}_{larger_smaller}",
|
||||
"result": result_strs,
|
||||
"difficulty": {
|
||||
"numbers": len(numbers),
|
||||
"decimals": (self.config.min_decimals, self.config.max_decimals),
|
||||
"value": (self.config.min_value, self.config.max_value),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig)
|
||||
class NumberFilteringCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(NumberFilteringCurriculum.__name__, NumberFilteringConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="numbers",
|
||||
levels=[10, 100, 500, 1000],
|
||||
default_level=1,
|
||||
description="How many numbers to sort",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_numbers",
|
||||
upper_field_name="max_numbers",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="decimals",
|
||||
levels=[0, 2, 4, 6],
|
||||
default_level=1,
|
||||
description="Number of decimal places",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=0,
|
||||
lower_field_name="min_decimals",
|
||||
upper_field_name="max_decimals",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="value",
|
||||
levels=[-10_000, 10_000],
|
||||
default_level=1,
|
||||
description="Range of numbers to sort",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=-10_000,
|
||||
lower_field_name="min_value",
|
||||
upper_field_name="max_value",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue