feat(env): Number Sorting Curriculum (#321)

* number sorting curriculum

* metadata
This commit is contained in:
Zafir Stojanovski 2025-03-11 00:18:20 +01:00 committed by GitHub
parent 1f6de829bd
commit f9fa667d82
3 changed files with 86 additions and 9 deletions

View file

@@ -4,6 +4,7 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@@ -46,16 +47,14 @@ Please follow the instruction below:
# Reparse to ensure exact decimal representation
return f"{float(formatted):.{decimals}f}"
def _generate_numbers(self, rng: Random) -> tuple[list[float], list[str]]:
def _generate_numbers(self, rng: Random, count: int) -> tuple[list[float], list[str]]:
"""Generate list of numbers and their string representations"""
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
numbers = []
number_strs = []
for _ in range(count):
num = rng.uniform(self.config.min_value, self.config.max_value)
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
num_str = self._format_number(num, decimals)
# Reparse to ensure exact value
num = float(num_str)
@@ -68,7 +67,8 @@ Please follow the instruction below:
"""Generate a single sorting task"""
rng = Random(self.seed + idx)
numbers, number_strs = self._generate_numbers(rng)
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers, number_strs = self._generate_numbers(rng, count)
# Generate both ascending and descending answers
asc_numbers = sorted(numbers)
@@ -88,8 +88,56 @@ Please follow the instruction below:
return {
"question": question,
"answer": str(answer),
"metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
"metadata": {
"original_numbers": number_strs,
"direction": direction,
"sorted_numbers": answer,
"difficulty": {
"numbers": count,
"decimals": (self.config.min_decimals, self.config.max_decimals),
"value": (self.config.min_value, self.config.max_value),
},
},
}
register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig)
class NumberSortingCurriculum(BaseCurriculum):
    """Curriculum describing the difficulty knobs of the number sorting task.

    Three range attributes are declared -- ``numbers``, ``decimals`` and
    ``value`` -- and each one names a lower/upper field pair
    (``min_*`` / ``max_*``) alongside its list of levels.
    """

    def __init__(self):
        super().__init__(NumberSortingCurriculum.__name__, NumberSortingConfig)

        # Size of the list that has to be sorted.
        count_attr = RangeAttributeDefinition(
            name="numbers",
            description="How many numbers to sort",
            levels=[10, 100, 500, 1000],
            default_level=1,
            attr_type=AttributeType.APPEND,
            min_value=2,
            lower_field_name="min_numbers",
            upper_field_name="max_numbers",
        )

        # Digits kept after the decimal point.
        precision_attr = RangeAttributeDefinition(
            name="decimals",
            description="Number of decimal places",
            levels=[0, 2, 4, 6],
            default_level=1,
            attr_type=AttributeType.APPEND,
            min_value=0,
            lower_field_name="min_decimals",
            upper_field_name="max_decimals",
        )

        # Interval the numbers are taken from.
        magnitude_attr = RangeAttributeDefinition(
            name="value",
            description="Range of numbers to sort",
            levels=[-10_000, 10_000],
            default_level=1,
            attr_type=AttributeType.APPEND,
            min_value=-10_000,
            lower_field_name="min_value",
            upper_field_name="max_value",
        )

        # Hand over the definitions in the same order as they are declared above.
        self._define_attributes(count_attr, precision_attr, magnitude_attr)
# Pass the dataset class, its config class and its curriculum class to
# register_dataset under the name "number_sorting".
register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum)