diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index d04c8fa6..0002edcd 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -1,7 +1,7 @@ """Fraction simplification task generator""" from dataclasses import dataclass from random import Random -from typing import List, Optional, Tuple +from typing import Optional, Tuple, Sequence from math import gcd @@ -12,7 +12,7 @@ class FractionSimplificationConfig: max_value: int = 100 # Maximum value for numerator/denominator min_factor: int = 2 # Minimum multiplication factor max_factor: int = 10 # Maximum multiplication factor - styles: List[str] = None # List of allowed fraction formatting styles + styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles seed: Optional[int] = None size: int = 500 # Virtual dataset size @@ -23,10 +23,6 @@ class FractionSimplificationConfig: assert self.min_factor >= 2, "min_factor must be at least 2" assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor" - # Set default styles if none provided - if self.styles is None: - self.styles = ["plain", "latex_inline", "latex_frac", "latex_dfrac"] - # Validate styles valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"} for style in self.styles: @@ -140,7 +136,7 @@ def fraction_simplification_dataset( max_value: int = 100, min_factor: int = 2, max_factor: int = 10, - styles: List[str] = None, + styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac"), seed: Optional[int] = None, size: int = 500, ) -> FractionSimplificationDataset: