def arithmetic_dataset(
    min_terms: int = 2,
    max_terms: int = 6,
    min_digits: int = 1,
    max_digits: int = 4,
    operators: list[str] = ("+", "-", "*"),
    allow_parentheses: bool = True,
    allow_negation: bool = True,
    seed: Optional[int] = None,
    size: int = 500,
    format_style: Literal["simple", "natural"] = "simple",
) -> ArithmeticDataset:
    """Build an ``ArithmeticDataset`` from individual keyword settings.

    Convenience factory: each argument is forwarded unchanged to the
    same-named field of ``ArithmeticDatasetConfig`` and the resulting
    config is wrapped in an ``ArithmeticDataset``.

    Args:
        min_terms: Minimum number of terms in expressions.
        max_terms: Maximum number of terms in expressions.
        min_digits: Minimum number of digits in numbers.
        max_digits: Maximum number of digits in numbers.
        operators: Operators to use ("+", "-", "*").
            NOTE(review): the annotation says ``list[str]`` but the default
            is a tuple; it is passed through to the config as-is.
        allow_parentheses: Whether to allow parentheses in expressions.
        allow_negation: Whether to allow negative numbers.
        seed: Random seed for reproducibility.
        size: Virtual size of the dataset.
        format_style: Style of question formatting ("simple" or "natural").

    Returns:
        The dataset instance built from the assembled configuration.
    """
    # Gather the settings under the config's field names, then expand them
    # as keyword arguments so the config sees exactly the caller's values.
    settings = {
        "min_terms": min_terms,
        "max_terms": max_terms,
        "min_digits": min_digits,
        "max_digits": max_digits,
        "operators": operators,
        "allow_parentheses": allow_parentheses,
        "allow_negation": allow_negation,
        "seed": seed,
        "size": size,
        "format_style": format_style,
    }
    return ArithmeticDataset(ArithmeticDatasetConfig(**settings))
def chain_sum(
    min_terms: int = 2,
    max_terms: int = 6,
    min_digits: int = 1,
    max_digits: int = 4,
    allow_negation: bool = False,
    seed: Optional[int] = None,
    size: int = 500,
) -> ChainSum:
    """Build a ``ChainSum`` dataset from individual keyword settings.

    Convenience factory: each argument is forwarded unchanged to the
    same-named field of ``ChainSumConfig`` and the resulting config is
    handed to ``ChainSum``.

    Args:
        min_terms: Minimum number of terms in expressions.
        max_terms: Maximum number of terms in expressions.
        min_digits: Minimum number of digits in numbers.
        max_digits: Maximum number of digits in numbers.
        allow_negation: Whether to allow negative numbers.
        seed: Random seed for reproducibility.
        size: Virtual size of the dataset.

    Returns:
        The dataset instance built from the assembled configuration.
    """
    return ChainSum(
        ChainSumConfig(
            min_terms=min_terms,
            max_terms=max_terms,
            min_digits=min_digits,
            max_digits=max_digits,
            allow_negation=allow_negation,
            seed=seed,
            size=size,
        )
    )