mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Add configurable formatting styles for fraction simplification dataset
This commit is contained in:
parent
1567776fd8
commit
9826d98fcf
1 changed files with 14 additions and 3 deletions
|
|
@ -12,6 +12,7 @@ class FractionSimplificationConfig:
|
|||
max_value: int = 100 # Maximum value for numerator/denominator
|
||||
min_factor: int = 2 # Minimum multiplication factor
|
||||
max_factor: int = 10 # Maximum multiplication factor
|
||||
styles: List[str] = None # List of allowed fraction formatting styles
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
|
|
@ -21,6 +22,15 @@ class FractionSimplificationConfig:
|
|||
assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||
assert self.min_factor >= 2, "min_factor must be at least 2"
|
||||
assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"
|
||||
|
||||
# Set default styles if none provided
|
||||
if self.styles is None:
|
||||
self.styles = ["plain", "latex_inline", "latex_frac", "latex_dfrac"]
|
||||
|
||||
# Validate styles
|
||||
valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
|
||||
for style in self.styles:
|
||||
assert style in valid_styles, f"Invalid style: {style}. Must be one of {valid_styles}"
|
||||
|
||||
|
||||
class FractionSimplificationDataset:
|
||||
|
|
@ -104,9 +114,8 @@ class FractionSimplificationDataset:
|
|||
|
||||
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
||||
|
||||
# Choose a random style for this question
|
||||
styles = ["plain", "latex_inline", "latex_frac", "latex_dfrac"]
|
||||
style = styles[rng.randint(0, len(styles)-1)]
|
||||
# Choose a random style from configured styles
|
||||
style = self.config.styles[rng.randint(0, len(self.config.styles)-1)]
|
||||
|
||||
# Format both question and answer in the same style
|
||||
question_fraction = self._format_fraction(num, den, style)
|
||||
|
|
@ -131,6 +140,7 @@ def fraction_simplification_dataset(
|
|||
max_value: int = 100,
|
||||
min_factor: int = 2,
|
||||
max_factor: int = 10,
|
||||
styles: List[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
) -> FractionSimplificationDataset:
|
||||
|
|
@ -140,6 +150,7 @@ def fraction_simplification_dataset(
|
|||
max_value=max_value,
|
||||
min_factor=min_factor,
|
||||
max_factor=max_factor,
|
||||
styles=styles,
|
||||
seed=seed,
|
||||
size=size,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue