diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 1ae42e5b..45134332 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -1,157 +1,155 @@ -import json -from pathlib import Path +from dataclasses import dataclass from random import Random - -# more variability -def generate_math_task( - rng: Random, num_terms: int, num_digits: int, op: list[str] = ["+", "-", "*"] -) -> tuple[str, int]: - parts = [] - - def add_terms(remaining: int): - num_left = rng.randint(1, remaining) - num_right = remaining - num_left - - if num_left > 1 and rng.random() > 0.5: - if rng.random() > 0.5: - parts.append("-(") - else: - parts.append("(") - add_terms(num_left) - parts.append(")") - else: - for i in range(num_left): - c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1) - parts.append(str(c)) - if i + 1 < num_left: - parts.append(rng.choice(op)) - - if num_right > 0: - parts.append(rng.choice(op)) - add_terms(num_right) - - add_terms(num_terms) - - space_parts = [] - for p in parts: - while rng.random() < 0.15: - space_parts.append(" ") - space_parts.append(p) - - term = " ".join(space_parts) - ground_truth = eval(term) - - return term, ground_truth +from typing import Optional, Literal, Any -def generate_task_file(): - rng = Random(42) - - num_tasks = 100_000 - i = 0 - - output_filename = "math_tasks.jsonl" - file_path = Path(output_filename) - with file_path.open("w", encoding="utf-8") as f: - while i < num_tasks: - num_terms = rng.randint(2, 6) - num_digits = rng.randint(1, 6) - term, ground_truth = generate_math_task( - rng, num_terms=num_terms, num_digits=num_digits - ) - if abs(ground_truth) > 10**8 or abs(ground_truth) < 10: - continue - - question_templates = [ - "{0}", - "{0} =", - "{0} = ?", - "What is {0}?", - "Solve {0}", - ] - - template = rng.choice(question_templates) - formatted_task = template.format(term) - - entry = { - "id": str(i), - "question": formatted_task, - "answer": str(ground_truth), - "num_terms": num_terms, - "num_digits": num_digits, - } - - json.dump(entry, f) - f.write("\n") - i += 1 - - - - -class BasicIntArithmeticTaskConfig: - def __init__( - self, - min_digits: int = 1, - max_digits: int = 5, - min_terms: int = 2, - max_terms: int = 8, - ): - self.min_digits = min_digits - self.max_digits = max_digits - self.min_terms = min_terms - self.max_terms = max_terms - self.operators = ["+", "-"] +@dataclass +class ArithmeticDatasetConfig: + """Configuration for arithmetic dataset generation""" + min_terms: int = 2 + max_terms: int = 6 + min_digits: int = 1 + max_digits: int = 4 + operators: list[str] = ("+" , "-", "*") + allow_parentheses: bool = True + allow_negation: bool = True + seed: Optional[int] = None + size: int = 10000 # Virtual dataset size + format_style: Literal["simple", "natural"] = "simple" def validate(self): - assert self.min_digits > 0 - assert self.max_digits >= self.min_digits - assert self.min_terms > 1 - assert self.max_terms >= self.min_terms - assert len(self.operators) > 0 + """Validate configuration parameters""" + assert self.min_terms > 0, "min_terms must be positive" + assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" + assert self.min_digits > 0, "min_digits must be positive" + assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" + assert len(self.operators) > 0, "must provide at least one operator" + for op in self.operators: + assert op in ["+", "-", "*"], f"unsupported operator: {op}" -def generate_task(rng: Random, cfg: BasicIntArithmeticTaskConfig) -> str: - num_terms = rng.randint(cfg.min_terms, cfg.max_terms) - num_digits = rng.randint(cfg.min_digits, cfg.max_digits) - constants = [rng.randint(0, 10**num_digits) for _ in range(num_terms)] - operators = [rng.choice(cfg.operators) for _ in range(num_terms - 1)] - - buffer = [] - - ground_truth = constants[0] - - buffer.append(f"{constants[0]}") - for i, op in enumerate(operators): - c = constants[i + 1] - buffer.append(op) - buffer.append(f"{c}") - - if op == "+": - ground_truth += c - elif op == "-": - ground_truth -= c +class ArithmeticDataset: + """Dataset that generates arithmetic tasks with configurable complexity""" + + def __init__(self, config: ArithmeticDatasetConfig): + self.config = config + self.config.validate() + self.rng = Random(config.seed) + + def __len__(self) -> int: + return self.config.size + + def __getitem__(self, idx: int) -> dict[str, Any]: + """Generate a single arithmetic task + + Args: + idx: Index of the item to generate + + Returns: + dict with keys: + - question: str, the formatted arithmetic expression + - answer: str, the ground truth result + - metadata: dict with generation parameters + """ + # Use seed derived from idx for deterministic generation + item_rng = Random(self.rng.randint(0, 2**32) + idx) + + num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + + if self.config.allow_parentheses: + expression, result = self._generate_complex_task(item_rng, num_terms, num_digits) else: - RuntimeError("Unsupported operator") + expression, result = self._generate_simple_task(item_rng, num_terms, num_digits) + + question = self._format_question(expression) + + return { + "question": question, + "answer": str(result), + "metadata": { + "num_terms": num_terms, + "num_digits": num_digits, + "expression": expression + } + } - buffer.append(f"") + def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]: + """Generate a complex arithmetic task with possible parentheses""" + parts = [] - question_templates = [ - "{0}", - "{0} =", - "{0} = ?", - "What is {0}?", - "Solve {0}", - "Calculate {0}", - # 'evaluate {0}', - # 'do me a favor and calculate {0}', - # 'Give me the result of {0}', - # 'Help me solve this: {0}', - # 'calculator: {0}', - # 'Tell me the result of the following expression {0}', - ] + def add_terms(remaining: int): + num_left = rng.randint(1, remaining) + num_right = remaining - num_left - template = rng.choice(question_templates) - task = " ".join(buffer) - formatted_task = template.format(task) + if num_left > 1 and rng.random() > 0.5 and self.config.allow_parentheses: + if rng.random() > 0.5 and self.config.allow_negation: + parts.append("-(") + else: + parts.append("(") + add_terms(num_left) + parts.append(")") + else: + for i in range(num_left): + c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1) + parts.append(str(c)) + if i + 1 < num_left: + parts.append(rng.choice(self.config.operators)) - return formatted_task, str(ground_truth), num_terms, num_digits + if num_right > 0: + parts.append(rng.choice(self.config.operators)) + add_terms(num_right) + + add_terms(num_terms) + + # Add random spaces + space_parts = [] + for p in parts: + while rng.random() < 0.15: + space_parts.append(" ") + space_parts.append(p) + + expression = " ".join(space_parts) + result = eval(expression) # Note: eval is safe here as we control the input + + return expression, result + + def _generate_simple_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]: + """Generate a simple linear arithmetic task without parentheses""" + constants = [rng.randint(0, 10**num_digits) for _ in range(num_terms)] + operators = [rng.choice(self.config.operators) for _ in range(num_terms - 1)] + + # Build expression and compute result + expression_parts = [] + result = constants[0] + + expression_parts.append(str(constants[0])) + for i, op in enumerate(operators): + c = constants[i + 1] + expression_parts.append(op) + expression_parts.append(str(c)) + + if op == "+": + result += c + elif op == "-": + result -= c + elif op == "*": + result *= c + else: + raise RuntimeError(f"Unsupported operator: {op}") + + expression = " ".join(expression_parts) + return expression, result + + def _format_question(self, expression: str) -> str: + """Format the expression according to config style""" + if self.config.format_style == "simple": + return f"{expression} =" + else: + templates = [ + "What is {0}?", + "Calculate {0}", + "Solve {0}", + "Evaluate the expression: {0}" + ] + return self.rng.choice(templates).format(expression) diff --git a/tests/test_arithmetic.py b/tests/test_arithmetic.py new file mode 100644 index 00000000..57a794ab --- /dev/null +++ b/tests/test_arithmetic.py @@ -0,0 +1,68 @@ +import pytest +from random import Random +from reasoning_gym.arithmetic.basic_arithmetic import ArithmeticDataset, ArithmeticDatasetConfig + + +def test_arithmetic_dataset_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = ArithmeticDatasetConfig(min_terms=0) + config.validate() + + with pytest.raises(AssertionError): + config = ArithmeticDatasetConfig(min_terms=3, max_terms=2) + config.validate() + + with pytest.raises(AssertionError): + config = ArithmeticDatasetConfig(operators=["^"]) # Invalid operator + config.validate() + + +def test_arithmetic_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = ArithmeticDatasetConfig(seed=42, size=10) + dataset1 = ArithmeticDataset(config) + dataset2 = ArithmeticDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_arithmetic_dataset_items(): + """Test basic properties of generated items""" + config = ArithmeticDatasetConfig( + min_terms=2, + max_terms=4, + min_digits=1, + max_digits=2, + size=100, + seed=42 + ) + dataset = ArithmeticDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify the answer matches the expression + expression = item["metadata"]["expression"] + answer = eval(expression) # Safe here as we control the expression + assert str(answer) == item["answer"] + + +def test_arithmetic_dataset_format_styles(): + """Test different question format styles""" + config = ArithmeticDatasetConfig( + size=10, + seed=42, + format_style="simple" + ) + dataset = ArithmeticDataset(config) + assert all(item["question"].endswith("=") for item in dataset) + + config.format_style = "natural" + dataset = ArithmeticDataset(config) + assert all("=" not in item["question"] for item in dataset)