[wip] more flexible api

This commit is contained in:
Rich Jones 2025-02-19 03:29:40 +01:00
parent 46cdfc71cf
commit 19b697c89a
2 changed files with 51 additions and 37 deletions

View file

@ -9,16 +9,18 @@ from ..factory import ProceduralDataset, register_dataset
class DecimalArithmeticDatasetConfig:
"""Configuration for decimal arithmetic dataset generation"""
num_decimal_places: int = 6
min_num_decimal_places: int = 6
max_num_decimal_places: int = 6
terms: int = 6
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.num_decimal_places > 0, "num_decimal_places must be positive"
# def validate(self) -> None:
# """Validate configuration parameters"""
# assert self.num_decimal_places > 0, "num_decimal_places must be positive"
def generate_arithmetic_problem(rng, num_decimal_places, operations=None):
def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None):
"""
Generates simple arithmetic problems with decimal numbers formatted to a specific number of decimal places.
@ -34,24 +36,27 @@ def generate_arithmetic_problem(rng, num_decimal_places, operations=None):
if operations is None:
operations = ["+", "-", "*", "/"]
max_integer_part = 10 # Maximum whole number portion before decimal
max_value = max_integer_part * (10**num_decimal_places)
problem = ""
problem = None
for term in range(0, terms):
# Generate random numbers with exact decimal places
num1 = rng.randint(1, max_value) / (10**num_decimal_places)
num2 = rng.randint(1, max_value) / (10**num_decimal_places)
# Generate random numbers with exact decimal places
ndp1 = rng.randint(min_num_decimal_places, max_num_decimal_places)
max_integer_part = 10 # Maximum whole number portion before decimal
max_value = max_integer_part * (10**ndp1)
num1 = rng.randint(1, max_value) / (10**ndp1)
# Select random operation
op = rng.choice(operations)
# Select random operation
op = rng.choice(operations)
op = op if (term <= terms - 2) else ""
# Format numbers to ensure exact decimal places
formatted_num1 = f"{num1:.{num_decimal_places}f}"
formatted_num2 = f"{num2:.{num_decimal_places}f}"
# Format numbers to ensure exact decimal places
formatted_num1 = f"{num1:.{ndp1}f}"
problem = f"{formatted_num1} {op} {formatted_num2} = ?"
problem = problem + f"{formatted_num1} { op }" + " "
problem = problem + "= ?"
print(problem)
return problem
@ -80,7 +85,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
# Create deterministic RNG from base seed and idx
rng = Random(self.seed + idx)
decimal_problem = generate_arithmetic_problem(rng, self.config.num_decimal_places)
decimal_problem = generate_arithmetic_problem(
rng,
self.config.min_num_decimal_places,
self.config.max_num_decimal_places,
terms=self.config.terms,
)
answer = eval_floordiv(decimal_problem)
return {"question": decimal_problem, "answer": answer, "metadata": {}}