mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
[wip] more flexible api
This commit is contained in:
parent
46cdfc71cf
commit
19b697c89a
2 changed files with 51 additions and 37 deletions
|
|
@ -9,16 +9,18 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
class DecimalArithmeticDatasetConfig:
|
||||
"""Configuration for decimal arithmetic dataset generation"""
|
||||
|
||||
num_decimal_places: int = 6
|
||||
min_num_decimal_places: int = 6
|
||||
max_num_decimal_places: int = 6
|
||||
terms: int = 6
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.num_decimal_places > 0, "num_decimal_places must be positive"
|
||||
# def validate(self) -> None:
|
||||
# """Validate configuration parameters"""
|
||||
# assert self.num_decimal_places > 0, "num_decimal_places must be positive"
|
||||
|
||||
|
||||
def generate_arithmetic_problem(rng, num_decimal_places, operations=None):
|
||||
def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None):
|
||||
"""
|
||||
Generates simple arithmetic problems with decimal numbers formatted to a specific number of decimal places.
|
||||
|
||||
|
|
@ -34,24 +36,27 @@ def generate_arithmetic_problem(rng, num_decimal_places, operations=None):
|
|||
if operations is None:
|
||||
operations = ["+", "-", "*", "/"]
|
||||
|
||||
max_integer_part = 10 # Maximum whole number portion before decimal
|
||||
max_value = max_integer_part * (10**num_decimal_places)
|
||||
problem = ""
|
||||
|
||||
problem = None
|
||||
for term in range(0, terms):
|
||||
|
||||
# Generate random numbers with exact decimal places
|
||||
num1 = rng.randint(1, max_value) / (10**num_decimal_places)
|
||||
num2 = rng.randint(1, max_value) / (10**num_decimal_places)
|
||||
# Generate random numbers with exact decimal places
|
||||
ndp1 = rng.randint(min_num_decimal_places, max_num_decimal_places)
|
||||
max_integer_part = 10 # Maximum whole number portion before decimal
|
||||
max_value = max_integer_part * (10**ndp1)
|
||||
num1 = rng.randint(1, max_value) / (10**ndp1)
|
||||
|
||||
# Select random operation
|
||||
op = rng.choice(operations)
|
||||
# Select random operation
|
||||
op = rng.choice(operations)
|
||||
op = op if (term <= terms - 2) else ""
|
||||
|
||||
# Format numbers to ensure exact decimal places
|
||||
formatted_num1 = f"{num1:.{num_decimal_places}f}"
|
||||
formatted_num2 = f"{num2:.{num_decimal_places}f}"
|
||||
# Format numbers to ensure exact decimal places
|
||||
formatted_num1 = f"{num1:.{ndp1}f}"
|
||||
|
||||
problem = f"{formatted_num1} {op} {formatted_num2} = ?"
|
||||
problem = problem + f"{formatted_num1} { op }" + " "
|
||||
|
||||
problem = problem + "= ?"
|
||||
print(problem)
|
||||
return problem
|
||||
|
||||
|
||||
|
|
@ -80,7 +85,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
# Create deterministic RNG from base seed and idx
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
decimal_problem = generate_arithmetic_problem(rng, self.config.num_decimal_places)
|
||||
decimal_problem = generate_arithmetic_problem(
|
||||
rng,
|
||||
self.config.min_num_decimal_places,
|
||||
self.config.max_num_decimal_places,
|
||||
terms=self.config.terms,
|
||||
)
|
||||
answer = eval_floordiv(decimal_problem)
|
||||
|
||||
return {"question": decimal_problem, "answer": answer, "metadata": {}}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue