diff --git a/reasoning_gym/arithmetic/lcm.py b/reasoning_gym/arithmetic/lcm.py index 3a840f49..6b5d24dc 100644 --- a/reasoning_gym/arithmetic/lcm.py +++ b/reasoning_gym/arithmetic/lcm.py @@ -46,23 +46,32 @@ class LCMDataset: self._current_idx += 1 return item - def _generate_numbers(self, rng: Random) -> List[int]: - """Generate a list of random positive integers""" + def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: + """Generate a list of random positive integers and their LCM. + Will try up to 3 times to find numbers with LCM < product.""" + def calculate_product(nums: List[int]) -> int: + return reduce(lambda x, y: x * y, nums) + + for _ in range(3): # Try up to 3 times to get LCM < product + num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) + numbers = [rng.randint(self.config.min_value, self.config.max_value) + for _ in range(num_count)] + result = reduce(lcm, numbers) + if result < calculate_product(numbers): + return numbers, result + + # If we failed to find LCM < product after 3 tries, generate one final set num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) - return [rng.randint(self.config.min_value, self.config.max_value) - for _ in range(num_count)] - - def _calculate_lcm(self, numbers: List[int]) -> int: - """Calculate the LCM of a list of numbers""" - return reduce(lcm, numbers) + numbers = [rng.randint(self.config.min_value, self.config.max_value) + for _ in range(num_count)] + result = reduce(lcm, numbers) + return numbers, result def __getitem__(self, idx: int) -> dict: """Generate a single LCM task""" rng = Random(self.seed + idx) - numbers = self._generate_numbers(rng) - result = self._calculate_lcm(numbers) - + numbers, result = self._generate_numbers(rng) numbers_str = ", ".join(str(n) for n in numbers) return {