fix: Ensure simplified fraction values stay within configured bounds

This commit is contained in:
Andreas Koepf (aider) 2025-01-24 09:11:07 +01:00
parent 1220118d95
commit 7cf3eb5f26

View file

@ -48,21 +48,33 @@ class FractionSimplificationDataset:
def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]:
"""Generate a random fraction and its simplified form.
Returns (numerator, denominator, simplified_num, simplified_den)"""
# Generate the simplified fraction first
# Try to generate valid fractions until we get one that meets our criteria
for _ in range(10): # Limit attempts to avoid infinite loop
# Generate the simplified fraction first
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
# Make sure they're coprime by dividing by their GCD
common = gcd(simplified_num, simplified_den)
simplified_num //= common
simplified_den //= common
# Check if simplified fraction is within bounds
if (self.config.min_value <= simplified_num <= self.config.max_value and
self.config.min_value <= simplified_den <= self.config.max_value):
# Multiply both by a random factor to create the unsimplified version
factor = rng.randint(self.config.min_factor, self.config.max_factor)
numerator = simplified_num * factor
denominator = simplified_den * factor
return numerator, denominator, simplified_num, simplified_den
# If we failed to find a good fraction after max attempts,
# generate one that's guaranteed to be within bounds
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
# Make sure they're coprime by dividing by their GCD
common = gcd(simplified_num, simplified_den)
simplified_num //= common
simplified_den //= common
# Multiply both by a random factor to create the unsimplified version
factor = rng.randint(self.config.min_factor, self.config.max_factor)
numerator = simplified_num * factor
denominator = simplified_den * factor
return numerator, denominator, simplified_num, simplified_den
return (simplified_num * factor, simplified_den * factor,
simplified_num, simplified_den)
def __getitem__(self, idx: int) -> dict:
"""Generate a single fraction simplification task"""