mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
feat(env): Number Sorting Curriculum (#321)
* number sorting curriculum
* metadata
This commit is contained in:
parent
1f6de829bd
commit
f9fa667d82
3 changed files with 86 additions and 9 deletions
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -46,16 +47,14 @@ Please follow the instruction below:
|
|||
# Reparse to ensure exact decimal representation
|
||||
return f"{float(formatted):.{decimals}f}"
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> tuple[list[float], list[str]]:
|
||||
def _generate_numbers(self, rng: Random, count: int) -> tuple[list[float], list[str]]:
|
||||
"""Generate list of numbers and their string representations"""
|
||||
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
||||
|
||||
numbers = []
|
||||
number_strs = []
|
||||
|
||||
for _ in range(count):
|
||||
num = rng.uniform(self.config.min_value, self.config.max_value)
|
||||
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
||||
num_str = self._format_number(num, decimals)
|
||||
# Reparse to ensure exact value
|
||||
num = float(num_str)
|
||||
|
|
@ -68,7 +67,8 @@ Please follow the instruction below:
|
|||
"""Generate a single sorting task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
numbers, number_strs = self._generate_numbers(rng)
|
||||
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers, number_strs = self._generate_numbers(rng, count)
|
||||
|
||||
# Generate both ascending and descending answers
|
||||
asc_numbers = sorted(numbers)
|
||||
|
|
@ -88,8 +88,56 @@ Please follow the instruction below:
|
|||
return {
|
||||
"question": question,
|
||||
"answer": str(answer),
|
||||
"metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
|
||||
"metadata": {
|
||||
"original_numbers": number_strs,
|
||||
"direction": direction,
|
||||
"sorted_numbers": answer,
|
||||
"difficulty": {
|
||||
"numbers": count,
|
||||
"decimals": (self.config.min_decimals, self.config.max_decimals),
|
||||
"value": (self.config.min_value, self.config.max_value),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig)
|
||||
class NumberSortingCurriculum(BaseCurriculum):
    """Curriculum declaring the adjustable attributes of the number sorting task.

    Three range attributes are defined. Each one names the lower/upper
    config fields it is tied to:

    * ``numbers``  -> ``min_numbers`` / ``max_numbers``
    * ``decimals`` -> ``min_decimals`` / ``max_decimals``
    * ``value``    -> ``min_value`` / ``max_value``
    """

    def __init__(self):
        super().__init__(NumberSortingCurriculum.__name__, NumberSortingConfig)

        # Every attribute below uses the same default level and attribute type.
        shared = {"default_level": 1, "attr_type": AttributeType.APPEND}

        # How many numbers appear in a single sorting task.
        count_attr = RangeAttributeDefinition(
            name="numbers",
            description="How many numbers to sort",
            levels=[10, 100, 500, 1000],
            min_value=2,
            lower_field_name="min_numbers",
            upper_field_name="max_numbers",
            **shared,
        )

        # How many decimal places the numbers are formatted with.
        precision_attr = RangeAttributeDefinition(
            name="decimals",
            description="Number of decimal places",
            levels=[0, 2, 4, 6],
            min_value=0,
            lower_field_name="min_decimals",
            upper_field_name="max_decimals",
            **shared,
        )

        # Interval the numbers are drawn from.
        magnitude_attr = RangeAttributeDefinition(
            name="value",
            description="Range of numbers to sort",
            levels=[-10_000, 10_000],
            min_value=-10_000,
            lower_field_name="min_value",
            upper_field_name="max_value",
            **shared,
        )

        # Register in the same order as the definitions above.
        self._define_attributes(count_attr, precision_attr, magnitude_attr)
|
||||
|
||||
|
||||
# Register the dataset under the name "number_sorting", passing its dataset,
# config and curriculum classes.
register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue