diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 01d387a6..969df820 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -119,28 +119,26 @@ class ChainSumCurriculum(BaseCurriculum): # Define attributes self._define_attributes( - ( - RangeAttributeDefinition( - name="num_terms", - levels=[2, 3, 4, 5], - default_level=0, # Start with 2 terms - description="Maximum number of terms in the expression", - attr_type=AttributeType.APPEND, - min_value=2, # Ensure at least 2 terms - lower_field_name="min_terms", - upper_field_name="max_terms", - ), - RangeAttributeDefinition( - name="num_digits", - levels=[1, 2, 4, 10], - default_level=0, # Start with 1-digit numbers - description="Number of digits in each operand", - attr_type=AttributeType.APPEND, - min_value=1, # Ensure numbers are at least 1 digit - lower_field_name="min_digits", - upper_field_name="max_digits", - ), - ) + RangeAttributeDefinition( + name="num_terms", + levels=[2, 3, 4, 5], + default_level=0, # Start with 2 terms + description="Maximum number of terms in the expression", + attr_type=AttributeType.APPEND, + min_value=2, # Ensure at least 2 terms + lower_field_name="min_terms", + upper_field_name="max_terms", + ), + RangeAttributeDefinition( + name="num_digits", + levels=[1, 2, 4, 10], + default_level=0, # Start with 1-digit numbers + description="Number of digits in each operand", + attr_type=AttributeType.APPEND, + min_value=1, # Ensure numbers are at least 1 digit + lower_field_name="min_digits", + upper_field_name="max_digits", + ), ) diff --git a/reasoning_gym/coaching/base_curriculum.py b/reasoning_gym/coaching/base_curriculum.py index 8d619869..b3e97672 100644 --- a/reasoning_gym/coaching/base_curriculum.py +++ b/reasoning_gym/coaching/base_curriculum.py @@ -34,7 +34,7 @@ class BaseCurriculum: raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist") return self._attributes[attr_name] - def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None: + def _define_attributes(self, *attrs: tuple[AttributeDefinition, ...]) -> None: for attr in attrs: if attr.name in self.attributes: raise RuntimeError(f"Attribute with name {attr.name} is already defined.")