diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index 882e3eff..0465d212 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -46,23 +46,29 @@ class GCDDataset: self._current_idx += 1 return item - def _generate_numbers(self, rng: Random) -> List[int]: - """Generate a list of random positive integers""" + def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: + """Generate a list of random positive integers and their GCD. + Will try up to 3 times to find numbers with GCD > 1.""" + for _ in range(3): # Try up to 3 times to get GCD > 1 + num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) + numbers = [rng.randint(self.config.min_value, self.config.max_value) + for _ in range(num_count)] + result = reduce(gcd, numbers) + if result > 1: + return numbers, result + + # If we failed to find GCD > 1 after 3 tries, generate one final set num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) - return [rng.randint(self.config.min_value, self.config.max_value) - for _ in range(num_count)] - - def _calculate_gcd(self, numbers: List[int]) -> int: - """Calculate the GCD of a list of numbers""" - return reduce(gcd, numbers) + numbers = [rng.randint(self.config.min_value, self.config.max_value) + for _ in range(num_count)] + result = reduce(gcd, numbers) + return numbers, result def __getitem__(self, idx: int) -> dict: """Generate a single GCD task""" rng = Random(self.seed + idx) - numbers = self._generate_numbers(rng) - result = self._calculate_gcd(numbers) - + numbers, result = self._generate_numbers(rng) numbers_str = ", ".join(str(n) for n in numbers) return {