mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-01 17:45:24 +00:00
Template Shared Attr
This commit is contained in:
parent
c2e3dbc826
commit
9a5841c6c2
4 changed files with 307 additions and 391 deletions
|
|
@ -5,7 +5,7 @@ Curriculum definition for the ChainSum exercise.
|
|||
from typing import Dict, Any
|
||||
from reasoning_gym.core.base_curriculum import BaseCurriculum
|
||||
from reasoning_gym.core.attributes import AttributeDefinition, AttributeType
|
||||
from reasoning_gym.core.template import Template, Placeholder
|
||||
from reasoning_gym.core.template import Template
|
||||
|
||||
# TODO: Brackets
|
||||
class ChainSumCurriculum(BaseCurriculum):
|
||||
|
|
@ -64,24 +64,23 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
)
|
||||
}
|
||||
|
||||
# Define templates with symbolic placeholders
|
||||
expression = Placeholder("expression", "symbolic_expression")
|
||||
# Define templates using the new system
|
||||
self._templates = [
|
||||
Template(
|
||||
question="What is {expression} ?",
|
||||
placeholders={"expression": expression}
|
||||
template="What is {expression} ?",
|
||||
parts={"expression": "symbolic_expression"}
|
||||
),
|
||||
Template(
|
||||
question="Calculate the following: {expression}",
|
||||
placeholders={"expression": expression}
|
||||
template="Calculate the following: {expression}",
|
||||
parts={"expression": "symbolic_expression"}
|
||||
),
|
||||
Template(
|
||||
question="Solve {expression}",
|
||||
placeholders={"expression": expression}
|
||||
template="Solve {expression}",
|
||||
parts={"expression": "symbolic_expression"}
|
||||
),
|
||||
Template(
|
||||
question="{expression} = ?",
|
||||
placeholders={"expression": expression}
|
||||
template="{expression} = ?",
|
||||
parts={"expression": "symbolic_expression"}
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -89,7 +88,7 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
self._symbolic = {
|
||||
# Define composition templates
|
||||
"templates": {
|
||||
# Expression structure - this key matches the generator name in Placeholder
|
||||
# Expression structure
|
||||
"symbolic_expression": lambda refs: (
|
||||
n_terms := refs["num_terms"](),
|
||||
{
|
||||
|
|
@ -97,7 +96,7 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
for i in range(n_terms - 1)),
|
||||
"parts": {
|
||||
**{f"term_{i}": "term" for i in range(n_terms)},
|
||||
**{f"op_{i}": lambda refs=refs: refs["operators"]()
|
||||
**{f"op_{i}": lambda refs=refs: refs["operator"](refs)()
|
||||
for i in range(n_terms - 1)}
|
||||
}
|
||||
}
|
||||
|
|
@ -107,25 +106,14 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
"term": lambda refs: {
|
||||
"template": "{sign}{value}",
|
||||
"parts": {
|
||||
"sign": lambda refs=refs: refs["sign"](),
|
||||
"value": "notation"
|
||||
"sign": lambda refs=refs: refs["sign_term"](refs)(),
|
||||
"value": lambda refs=refs: refs["format_number"](refs)(refs["number"](refs)())
|
||||
}
|
||||
},
|
||||
|
||||
# Notation structure
|
||||
"notation": lambda refs: {
|
||||
"template": {
|
||||
"regular": str(refs["number"](refs)()),
|
||||
"scientific": f"{float(refs['number'](refs)()):e}",
|
||||
"base2": f"0b{int(refs['number'](refs)()):b}",
|
||||
"base16": f"0x{int(refs['number'](refs)()):X}"
|
||||
}[refs["notation"]()],
|
||||
"parts": {}
|
||||
}
|
||||
},
|
||||
|
||||
# Define value generators
|
||||
"generators": {
|
||||
# Generate a number based on current settings
|
||||
"number": lambda refs: (
|
||||
lambda: (
|
||||
max_val := (10 ** refs["num_digits"]()) - 1,
|
||||
|
|
@ -134,6 +122,17 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
if refs["num_decimals"]() > 0 and refs["notation"]() in ["regular", "scientific"]
|
||||
else base_num
|
||||
)[-1]
|
||||
)
|
||||
}
|
||||
),
|
||||
# Generate an operator from available options
|
||||
"operator": lambda refs: lambda: refs["operators"](),
|
||||
# Generate a sign based on current settings
|
||||
"sign_term": lambda refs: lambda: refs["sign"](),
|
||||
# Format a number according to notation
|
||||
"format_number": lambda refs: lambda value: {
|
||||
"regular": str(value),
|
||||
"scientific": f"{float(value):e}",
|
||||
"base2": f"0b{int(value):b}",
|
||||
"base16": f"0x{int(value):X}"
|
||||
}[refs["notation"]()]
|
||||
},
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue