diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index 817005d4..4d476bb1 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -62,6 +62,10 @@ class FractionSimplificationDataset: # Check if simplified fraction is within bounds if (self.config.min_value <= simplified_num <= self.config.max_value and self.config.min_value <= simplified_den <= self.config.max_value): + # Ensure numerator is smaller than denominator + if simplified_num > simplified_den: + simplified_num, simplified_den = simplified_den, simplified_num + # Multiply both by a random factor to create the unsimplified version factor = rng.randint(self.config.min_factor, self.config.max_factor) numerator = simplified_num * factor @@ -72,6 +76,11 @@ class FractionSimplificationDataset: # generate one that's guaranteed to be within bounds simplified_num = rng.randint(self.config.min_value, self.config.max_value) simplified_den = rng.randint(self.config.min_value, self.config.max_value) + + # Ensure numerator is smaller than denominator + if simplified_num > simplified_den: + simplified_num, simplified_den = simplified_den, simplified_num + factor = rng.randint(self.config.min_factor, self.config.max_factor) return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)