mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
style: Format code with consistent whitespace and remove unnecessary lines
This commit is contained in:
parent
38801a7e6f
commit
72ada57dc5
1 changed file with 14 additions and 24 deletions
|
|
@ -6,11 +6,12 @@ from typing import Optional, Literal, Any
|
|||
@dataclass
|
||||
class ArithmeticDatasetConfig:
|
||||
"""Configuration for arithmetic dataset generation"""
|
||||
|
||||
min_terms: int = 2
|
||||
max_terms: int = 6
|
||||
min_digits: int = 1
|
||||
max_digits: int = 4
|
||||
operators: list[str] = ("+" , "-", "*")
|
||||
operators: list[str] = ("+", "-", "*")
|
||||
allow_parentheses: bool = True
|
||||
allow_negation: bool = True
|
||||
seed: Optional[int] = None
|
||||
|
|
@ -30,22 +31,22 @@ class ArithmeticDatasetConfig:
|
|||
|
||||
class ArithmeticDataset:
|
||||
"""Dataset that generates arithmetic tasks with configurable complexity"""
|
||||
|
||||
|
||||
def __init__(self, config: ArithmeticDatasetConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
# Generate base seed if none provided
|
||||
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.config.size
|
||||
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
"""Generate a single arithmetic task
|
||||
|
||||
|
||||
Args:
|
||||
idx: Index of the item to generate
|
||||
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- question: str, the formatted arithmetic expression
|
||||
|
|
@ -54,25 +55,21 @@ class ArithmeticDataset:
|
|||
"""
|
||||
# Create deterministic RNG from base seed and idx
|
||||
item_rng = Random(self.seed + idx)
|
||||
|
||||
|
||||
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
|
||||
|
||||
|
||||
if self.config.allow_parentheses:
|
||||
expression, result = self._generate_complex_task(item_rng, num_terms, num_digits)
|
||||
else:
|
||||
expression, result = self._generate_simple_task(item_rng, num_terms, num_digits)
|
||||
|
||||
|
||||
question = self._format_question(expression)
|
||||
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"num_terms": num_terms,
|
||||
"num_digits": num_digits,
|
||||
"expression": expression
|
||||
}
|
||||
"metadata": {"num_terms": num_terms, "num_digits": num_digits, "expression": expression},
|
||||
}
|
||||
|
||||
def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
|
||||
|
|
@ -142,13 +139,11 @@ class ArithmeticDataset:
|
|||
expression = " ".join(expression_parts)
|
||||
return expression, result
|
||||
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
"""Make the dataset iterable"""
|
||||
self._current_idx = 0
|
||||
return self
|
||||
|
||||
|
||||
def __next__(self):
|
||||
"""Get next item in iteration"""
|
||||
if self._current_idx >= self.config.size:
|
||||
|
|
@ -162,12 +157,7 @@ class ArithmeticDataset:
|
|||
if self.config.format_style == "simple":
|
||||
return f"{expression} ="
|
||||
else:
|
||||
templates = [
|
||||
"What is {0}?",
|
||||
"Calculate {0}",
|
||||
"Solve {0}",
|
||||
"Evaluate the expression: {0}"
|
||||
]
|
||||
templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"]
|
||||
# Use deterministic RNG for template selection
|
||||
template_rng = Random(self.seed)
|
||||
return template_rng.choice(templates).format(expression)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue