mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
added Decimal curriculum (#280)
* added decimal curricula * added chain sum decimal curriculum * register DecimalArithmeticCurriculum & DecimalChainSumCurriculum --------- Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
parent
dc657b5ed4
commit
e304b20e24
5 changed files with 178 additions and 16 deletions
|
|
@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -160,4 +161,43 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
return 0.0
|
||||
|
||||
|
||||
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig)
|
||||
class DecimalChainSumCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(DecimalChainSumCurriculum.__name__, DecimalChainSumConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 5],
|
||||
default_level=0,
|
||||
description="Maximum number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_terms",
|
||||
upper_field_name="max_terms",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="num_digits",
|
||||
levels=[1, 2, 4, 10],
|
||||
default_level=0, # Start with 1-digit numbers
|
||||
description="Number of digits in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_digits",
|
||||
upper_field_name="max_digits",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="decimal_places",
|
||||
levels=[1, 2, 3, 4],
|
||||
default_level=0,
|
||||
description="Number of decimal places in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_decimal_places",
|
||||
upper_field_name="max_decimal_places",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue