diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py
index 4d476bb1..d04c8fa6 100644
--- a/reasoning_gym/arithmetic/fraction_simplification.py
+++ b/reasoning_gym/arithmetic/fraction_simplification.py
@@ -12,6 +12,7 @@ class FractionSimplificationConfig:
     max_value: int = 100  # Maximum value for numerator/denominator
     min_factor: int = 2  # Minimum multiplication factor
     max_factor: int = 10  # Maximum multiplication factor
+    styles: Optional[List[str]] = None  # Allowed fraction formatting styles (None = all styles)
     seed: Optional[int] = None
     size: int = 500  # Virtual dataset size
 
@@ -21,6 +22,17 @@ class FractionSimplificationConfig:
         assert self.max_value > self.min_value, "max_value must be > min_value"
         assert self.min_factor >= 2, "min_factor must be at least 2"
         assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"
+
+        # Set default styles if none provided
+        if self.styles is None:
+            self.styles = ["plain", "latex_inline", "latex_frac", "latex_dfrac"]
+
+        # Validate styles; an empty list is rejected here because style selection
+        # draws rng.randint(0, len(styles) - 1), which raises on an empty range
+        assert len(self.styles) > 0, "styles must contain at least one style"
+        valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
+        for style in self.styles:
+            assert style in valid_styles, f"Invalid style: {style}. Must be one of {valid_styles}"
 
 
 class FractionSimplificationDataset:
@@ -104,9 +116,8 @@ class FractionSimplificationDataset:
 
         num, den, simple_num, simple_den = self._generate_fraction(rng)
 
-        # Choose a random style for this question
-        styles = ["plain", "latex_inline", "latex_frac", "latex_dfrac"]
-        style = styles[rng.randint(0, len(styles)-1)]
+        # Choose a random style from configured styles
+        style = self.config.styles[rng.randint(0, len(self.config.styles)-1)]
 
         # Format both question and answer in the same style
         question_fraction = self._format_fraction(num, den, style)
@@ -131,6 +142,7 @@ def fraction_simplification_dataset(
     max_value: int = 100,
     min_factor: int = 2,
     max_factor: int = 10,
+    styles: Optional[List[str]] = None,
     seed: Optional[int] = None,
     size: int = 500,
 ) -> FractionSimplificationDataset:
@@ -140,6 +152,7 @@ def fraction_simplification_dataset(
         max_value=max_value,
         min_factor=min_factor,
         max_factor=max_factor,
+        styles=styles,
         seed=seed,
         size=size,
     )