mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
use *args param for _define_attributes()
This commit is contained in:
parent
c9b39fdab1
commit
ab9f781d97
2 changed files with 21 additions and 23 deletions
|
|
@ -119,28 +119,26 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
(
|
||||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 5],
|
||||
default_level=0, # Start with 2 terms
|
||||
description="Maximum number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2, # Ensure at least 2 terms
|
||||
lower_field_name="min_terms",
|
||||
upper_field_name="max_terms",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="num_digits",
|
||||
levels=[1, 2, 4, 10],
|
||||
default_level=0, # Start with 1-digit numbers
|
||||
description="Number of digits in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1, # Ensure numbers are at least 1 digit
|
||||
lower_field_name="min_digits",
|
||||
upper_field_name="max_digits",
|
||||
),
|
||||
)
|
||||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 5],
|
||||
default_level=0, # Start with 2 terms
|
||||
description="Maximum number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2, # Ensure at least 2 terms
|
||||
lower_field_name="min_terms",
|
||||
upper_field_name="max_terms",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="num_digits",
|
||||
levels=[1, 2, 4, 10],
|
||||
default_level=0, # Start with 1-digit numbers
|
||||
description="Number of digits in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1, # Ensure numbers are at least 1 digit
|
||||
lower_field_name="min_digits",
|
||||
upper_field_name="max_digits",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class BaseCurriculum:
|
|||
raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist")
|
||||
return self._attributes[attr_name]
|
||||
|
||||
def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None:
|
||||
def _define_attributes(self, *attrs: tuple[AttributeDefinition, ...]) -> None:
|
||||
for attr in attrs:
|
||||
if attr.name in self.attributes:
|
||||
raise RuntimeError(f"Attribute with name {attr.name} is already defined.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue