Template Shared Attr

This commit is contained in:
EduardDurech 2025-02-07 22:50:55 +00:00
parent c2e3dbc826
commit 9a5841c6c2
4 changed files with 307 additions and 391 deletions

View file

@ -5,7 +5,7 @@ Curriculum definition for the ChainSum exercise.
from typing import Dict, Any
from reasoning_gym.core.base_curriculum import BaseCurriculum
from reasoning_gym.core.attributes import AttributeDefinition, AttributeType
from reasoning_gym.core.template import Template, Placeholder
from reasoning_gym.core.template import Template
# TODO: Brackets
class ChainSumCurriculum(BaseCurriculum):
@ -64,24 +64,23 @@ class ChainSumCurriculum(BaseCurriculum):
)
}
# Define templates with symbolic placeholders
expression = Placeholder("expression", "symbolic_expression")
# Define templates using the new system
self._templates = [
Template(
question="What is {expression} ?",
placeholders={"expression": expression}
template="What is {expression} ?",
parts={"expression": "symbolic_expression"}
),
Template(
question="Calculate the following: {expression}",
placeholders={"expression": expression}
template="Calculate the following: {expression}",
parts={"expression": "symbolic_expression"}
),
Template(
question="Solve {expression}",
placeholders={"expression": expression}
template="Solve {expression}",
parts={"expression": "symbolic_expression"}
),
Template(
question="{expression} = ?",
placeholders={"expression": expression}
template="{expression} = ?",
parts={"expression": "symbolic_expression"}
)
]
@ -89,7 +88,7 @@ class ChainSumCurriculum(BaseCurriculum):
self._symbolic = {
# Define composition templates
"templates": {
# Expression structure - this key matches the generator name in Placeholder
# Expression structure
"symbolic_expression": lambda refs: (
n_terms := refs["num_terms"](),
{
@ -97,7 +96,7 @@ class ChainSumCurriculum(BaseCurriculum):
for i in range(n_terms - 1)),
"parts": {
**{f"term_{i}": "term" for i in range(n_terms)},
**{f"op_{i}": lambda refs=refs: refs["operators"]()
**{f"op_{i}": lambda refs=refs: refs["operator"](refs)()
for i in range(n_terms - 1)}
}
}
@ -107,25 +106,14 @@ class ChainSumCurriculum(BaseCurriculum):
"term": lambda refs: {
"template": "{sign}{value}",
"parts": {
"sign": lambda refs=refs: refs["sign"](),
"value": "notation"
"sign": lambda refs=refs: refs["sign_term"](refs)(),
"value": lambda refs=refs: refs["format_number"](refs)(refs["number"](refs)())
}
},
# Notation structure
"notation": lambda refs: {
"template": {
"regular": str(refs["number"](refs)()),
"scientific": f"{float(refs['number'](refs)()):e}",
"base2": f"0b{int(refs['number'](refs)()):b}",
"base16": f"0x{int(refs['number'](refs)()):X}"
}[refs["notation"]()],
"parts": {}
}
},
# Define value generators
"generators": {
# Generate a number based on current settings
"number": lambda refs: (
lambda: (
max_val := (10 ** refs["num_digits"]()) - 1,
@ -134,6 +122,17 @@ class ChainSumCurriculum(BaseCurriculum):
if refs["num_decimals"]() > 0 and refs["notation"]() in ["regular", "scientific"]
else base_num
)[-1]
)
}
),
# Generate an operator from available options
"operator": lambda refs: lambda: refs["operators"](),
# Generate a sign based on current settings
"sign_term": lambda refs: lambda: refs["sign"](),
# Format a number according to notation
"format_number": lambda refs: lambda value: {
"regular": str(value),
"scientific": f"{float(value):e}",
"base2": f"0b{int(value):b}",
"base16": f"0x{int(value):X}"
}[refs["notation"]()]
},
}