mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
add ProductsDataset (multiplication tasks)
This commit is contained in:
parent
17485fad67
commit
1996ffa6d8
10 changed files with 56 additions and 56 deletions
|
|
@ -32,7 +32,7 @@ class ChainSumConfig:
|
|||
assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range"
|
||||
|
||||
|
||||
class ChainSum(ProceduralDataset):
|
||||
class ChainSumDataset(ProceduralDataset):
|
||||
"""Generates simple arithmetic tasks using only + and - operators"""
|
||||
|
||||
def __init__(self, config: ChainSumConfig):
|
||||
|
|
@ -51,16 +51,16 @@ class ChainSum(ProceduralDataset):
|
|||
- metadata: dict with generation parameters
|
||||
"""
|
||||
# Create deterministic RNG from base seed and idx
|
||||
item_rng = random.Random(self.seed + idx)
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
||||
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
|
||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
num_digits = rng.randint(self.config.min_digits, self.config.max_digits)
|
||||
|
||||
# Calculate value ranges based on number of digits
|
||||
min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit
|
||||
max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits
|
||||
|
||||
expression, result = self._generate_task(item_rng, num_terms, min_value, max_value)
|
||||
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||
|
||||
return {
|
||||
"question": f"{expression} =",
|
||||
|
|
@ -143,4 +143,4 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
||||
register_dataset("chain_sum", ChainSumDataset, ChainSumConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue