def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]:
    """Generate a random fraction and its simplified form.

    A candidate numerator/denominator pair is drawn from
    ``[config.min_value, config.max_value]`` and reduced by its GCD.  The
    reduced pair is accepted when both parts are still inside that range;
    otherwise a new pair is drawn, up to a fixed number of attempts.  If no
    attempt is accepted, one more pair is drawn and reduced without the
    range check, so the result is always in lowest terms even when the
    configured range cannot hold a reduced pair (e.g. ``min_value ==
    max_value == 2`` can only ever reduce to ``1/1``).

    Args:
        rng: Random number generator used for every draw.

    Returns:
        ``(numerator, denominator, simplified_num, simplified_den)`` where
        the first two are the simplified parts multiplied by a common factor
        drawn from ``[config.min_factor, config.max_factor]``.
    """
    low = self.config.min_value
    high = self.config.max_value
    max_attempts = 10  # Bounded so an unsatisfiable range cannot loop forever

    for _ in range(max_attempts):
        # Generate the simplified fraction first
        simplified_num = rng.randint(low, high)
        simplified_den = rng.randint(low, high)

        # Make sure they're coprime by dividing by their GCD
        common = gcd(simplified_num, simplified_den)
        simplified_num //= common
        simplified_den //= common

        # Reduction can push a part below min_value; accept only in-range pairs
        if low <= simplified_num <= high and low <= simplified_den <= high:
            break
    else:
        # No in-range reduced pair was found.  Prefer a correct (fully
        # reduced) answer over an in-range one: returning an unreduced pair
        # as the "simplified" form would make the expected answer wrong.
        simplified_num = rng.randint(low, high)
        simplified_den = rng.randint(low, high)
        common = gcd(simplified_num, simplified_den)
        simplified_num //= common
        simplified_den //= common

    # Multiply both by a random factor to create the unsimplified version
    factor = rng.randint(self.config.min_factor, self.config.max_factor)
    return (
        simplified_num * factor,
        simplified_den * factor,
        simplified_num,
        simplified_den,
    )