diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index 26c12041..6e761d01 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -39,17 +39,16 @@ class GCDDataset(ProceduralDataset): def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: """Generate a list of random positive integers and their GCD. Will try up to 3 times to find numbers with GCD > 1.""" - for _ in range(3): # Try up to 3 times to get GCD > 1 + + # Try up to 3 times to get GCD > 1 + for _ in range(3): num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)] result = reduce(gcd, numbers) if result > 1: - return numbers, result - - # If we failed to find GCD > 1 after 3 tries, generate one final set - num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) - numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)] - result = reduce(gcd, numbers) + break + + # Return the last generated numbers, whether they met the criteria or not return numbers, result def __getitem__(self, idx: int) -> dict: