diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index 0a9d1849..581dae37 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -76,21 +76,57 @@ class FractionSimplificationDataset: return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den) + def _format_fraction(self, num: int, den: int, style: str = "plain") -> str: + """Format a fraction in various styles""" + if style == "plain": + return f"{num}/{den}" + elif style == "latex_inline": + return f"${num}/{den}$" + elif style == "latex_frac": + return f"$\\frac{{{num}}}{{{den}}}$" + elif style == "latex_dfrac": + return f"$\\dfrac{{{num}}}{{{den}}}$" + else: + raise ValueError(f"Unknown fraction style: {style}") + def __getitem__(self, idx: int) -> dict: """Generate a single fraction simplification task""" rng = Random(self.seed + idx) num, den, simple_num, simple_den = self._generate_fraction(rng) + # Choose a random style for this question + styles = ["plain", "latex_inline", "latex_frac", "latex_dfrac"] + style = styles[rng.randint(0, len(styles)-1)] + + # Format both question and answer in the same style + question_fraction = self._format_fraction(num, den, style) + answer_fraction = self._format_fraction(simple_num, simple_den, style) + return { - "question": f"Simplify the fraction {num}/{den} to its lowest terms", - "answer": f"{simple_num}/{simple_den}", + "question": f"Simplify the fraction {question_fraction} to its lowest terms", + "answer": answer_fraction, "metadata": { "numerator": num, "denominator": den, "simplified_numerator": simple_num, "simplified_denominator": simple_den, - "reduction_factor": num // simple_num # Will be same as den // simple_den + "reduction_factor": num // simple_num, # Will be same as den // simple_den + "style": style, + "formats": { + "question": { + "plain": self._format_fraction(num, den, "plain"), + "latex_inline": self._format_fraction(num, den, "latex_inline"), + "latex_frac": self._format_fraction(num, den, "latex_frac"), + "latex_dfrac": self._format_fraction(num, den, "latex_dfrac") + }, + "answer": { + "plain": self._format_fraction(simple_num, simple_den, "plain"), + "latex_inline": self._format_fraction(simple_num, simple_den, "latex_inline"), + "latex_frac": self._format_fraction(simple_num, simple_den, "latex_frac"), + "latex_dfrac": self._format_fraction(simple_num, simple_den, "latex_dfrac") + } + } } }