mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Modify LCM dataset to generate numbers with LCM less than their product
This commit is contained in:
parent
e2957a744d
commit
387740b9bd
1 changed files with 20 additions and 11 deletions
|
|
@ -46,23 +46,32 @@ class LCMDataset:
|
|||
self._current_idx += 1
|
||||
return item
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> List[int]:
|
||||
"""Generate a list of random positive integers"""
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||
"""Generate a list of random positive integers and their LCM.
|
||||
Will try up to 3 times to find numbers with LCM < product."""
|
||||
def calculate_product(nums: List[int]) -> int:
|
||||
return reduce(lambda x, y: x * y, nums)
|
||||
|
||||
for _ in range(3): # Try up to 3 times to get LCM < product
|
||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
||||
for _ in range(num_count)]
|
||||
result = reduce(lcm, numbers)
|
||||
if result < calculate_product(numbers):
|
||||
return numbers, result
|
||||
|
||||
# If we failed to find LCM < product after 3 tries, generate one final set
|
||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
return [rng.randint(self.config.min_value, self.config.max_value)
|
||||
for _ in range(num_count)]
|
||||
|
||||
def _calculate_lcm(self, numbers: List[int]) -> int:
|
||||
"""Calculate the LCM of a list of numbers"""
|
||||
return reduce(lcm, numbers)
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
||||
for _ in range(num_count)]
|
||||
result = reduce(lcm, numbers)
|
||||
return numbers, result
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single LCM task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
numbers = self._generate_numbers(rng)
|
||||
result = self._calculate_lcm(numbers)
|
||||
|
||||
numbers, result = self._generate_numbers(rng)
|
||||
numbers_str = ", ".join(str(n) for n in numbers)
|
||||
|
||||
return {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue