mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
formatting
This commit is contained in:
parent
98988c8481
commit
20069b2a7d
37 changed files with 504 additions and 666 deletions
|
|
@ -2,12 +2,7 @@
|
||||||
Reasoning Gym - A library of procedural dataset generators for training reasoning models
|
Reasoning Gym - A library of procedural dataset generators for training reasoning models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from . import arithmetic
|
from . import algorithmic, arithmetic, cognition, data, games, logic
|
||||||
from . import algorithmic
|
|
||||||
from . import cognition
|
|
||||||
from . import data
|
|
||||||
from . import games
|
|
||||||
from . import logic
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
__all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"]
|
__all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"]
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ Algorithmic tasks for training reasoning capabilities:
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset
|
from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset
|
||||||
from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset
|
from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset
|
||||||
|
|
||||||
from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset
|
from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset
|
||||||
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
|
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
|
||||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
|
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
|
||||||
|
|
@ -31,5 +32,5 @@ __all__ = [
|
||||||
"number_sorting_dataset",
|
"number_sorting_dataset",
|
||||||
"WordReversalConfig",
|
"WordReversalConfig",
|
||||||
"WordReversalDataset",
|
"WordReversalDataset",
|
||||||
"word_reversal_dataset"
|
"word_reversal_dataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,20 @@
|
||||||
"""Base conversion task generator"""
|
"""Base conversion task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseConversionConfig:
|
class BaseConversionConfig:
|
||||||
"""Configuration for base conversion task generation"""
|
"""Configuration for base conversion task generation"""
|
||||||
min_base: int = 2 # Minimum base (2=binary)
|
|
||||||
max_base: int = 16 # Maximum base (16=hex)
|
min_base: int = 2 # Minimum base (2=binary)
|
||||||
min_value: int = 0 # Minimum decimal value to convert
|
max_base: int = 16 # Maximum base (16=hex)
|
||||||
max_value: int = 1000 # Maximum decimal value to convert
|
min_value: int = 0 # Minimum decimal value to convert
|
||||||
|
max_value: int = 1000 # Maximum decimal value to convert
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -71,14 +74,14 @@ class BaseConversionDataset:
|
||||||
value, source_base, target_base = self._generate_conversion(rng)
|
value, source_base, target_base = self._generate_conversion(rng)
|
||||||
|
|
||||||
# Convert decimal to source base representation
|
# Convert decimal to source base representation
|
||||||
source_repr = format(value, f'x' if source_base == 16 else f'b' if source_base == 2 else '').strip()
|
source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip()
|
||||||
if source_base not in (2, 16):
|
if source_base not in (2, 16):
|
||||||
source_repr = format(value, f'{source_base}x').lower().strip()
|
source_repr = format(value, f"{source_base}x").lower().strip()
|
||||||
|
|
||||||
# Convert decimal to target base for answer
|
# Convert decimal to target base for answer
|
||||||
target_repr = format(value, f'x' if target_base == 16 else f'b' if target_base == 2 else '').strip()
|
target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip()
|
||||||
if target_base not in (2, 16):
|
if target_base not in (2, 16):
|
||||||
target_repr = format(value, f'{target_base}x').lower().strip()
|
target_repr = format(value, f"{target_base}x").lower().strip()
|
||||||
|
|
||||||
source_name = self._format_base_name(source_base)
|
source_name = self._format_base_name(source_base)
|
||||||
target_name = self._format_base_name(target_base)
|
target_name = self._format_base_name(target_base)
|
||||||
|
|
@ -94,8 +97,8 @@ class BaseConversionDataset:
|
||||||
"source_base": source_base,
|
"source_base": source_base,
|
||||||
"target_base": target_base,
|
"target_base": target_base,
|
||||||
"source_repr": source_repr,
|
"source_repr": source_repr,
|
||||||
"target_repr": target_repr
|
"target_repr": target_repr,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,21 @@
|
||||||
"""Letter counting task generator"""
|
"""Letter counting task generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from reasoning_gym.data import read_data_file
|
from reasoning_gym.data import read_data_file
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LetterCountingConfig:
|
class LetterCountingConfig:
|
||||||
"""Configuration for letter counting task generation"""
|
"""Configuration for letter counting task generation"""
|
||||||
min_words: int = 5 # Minimum words in span
|
|
||||||
max_words: int = 15 # Maximum words in span
|
min_words: int = 5 # Minimum words in span
|
||||||
|
max_words: int = 15 # Maximum words in span
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -31,7 +34,7 @@ class LetterCountingDataset:
|
||||||
# Load and preprocess text
|
# Load and preprocess text
|
||||||
text = read_data_file("in_the_year_2889.txt")
|
text = read_data_file("in_the_year_2889.txt")
|
||||||
# Extract words and clean them to contain only alphanumeric characters
|
# Extract words and clean them to contain only alphanumeric characters
|
||||||
self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()]
|
self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.config.size
|
return self.config.size
|
||||||
|
|
@ -54,12 +57,12 @@ class LetterCountingDataset:
|
||||||
# Select random span of words
|
# Select random span of words
|
||||||
span_length = rng.randint(self.config.min_words, self.config.max_words)
|
span_length = rng.randint(self.config.min_words, self.config.max_words)
|
||||||
start_idx = rng.randint(0, len(self.words) - span_length)
|
start_idx = rng.randint(0, len(self.words) - span_length)
|
||||||
span = self.words[start_idx:start_idx + span_length]
|
span = self.words[start_idx : start_idx + span_length]
|
||||||
|
|
||||||
# Get all unique letters from span
|
# Get all unique letters from span
|
||||||
letters = set(''.join(span).lower())
|
letters = set("".join(span).lower())
|
||||||
if not letters:
|
if not letters:
|
||||||
letters = {'a'} # Fallback if span has no letters
|
letters = {"a"} # Fallback if span has no letters
|
||||||
|
|
||||||
# Select random letter that appears in the span
|
# Select random letter that appears in the span
|
||||||
target_letter = rng.choice(list(letters))
|
target_letter = rng.choice(list(letters))
|
||||||
|
|
@ -70,11 +73,7 @@ class LetterCountingDataset:
|
||||||
return {
|
return {
|
||||||
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
|
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
|
||||||
"answer": str(count),
|
"answer": str(count),
|
||||||
"metadata": {
|
"metadata": {"span_length": span_length, "target_letter": target_letter, "span": span},
|
||||||
"span_length": span_length,
|
|
||||||
"target_letter": target_letter,
|
|
||||||
"span": span
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,23 @@
|
||||||
"""Number filtering task generator"""
|
"""Number filtering task generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NumberFilteringConfig:
|
class NumberFilteringConfig:
|
||||||
"""Configuration for number filtering task generation"""
|
"""Configuration for number filtering task generation"""
|
||||||
min_numbers: int = 3 # Minimum numbers in list
|
|
||||||
max_numbers: int = 10 # Maximum numbers in list
|
min_numbers: int = 3 # Minimum numbers in list
|
||||||
min_decimals: int = 0 # Minimum decimal places
|
max_numbers: int = 10 # Maximum numbers in list
|
||||||
max_decimals: int = 4 # Maximum decimal places
|
min_decimals: int = 0 # Minimum decimal places
|
||||||
min_value: float = -100.0 # Minimum number value
|
max_decimals: int = 4 # Maximum decimal places
|
||||||
max_value: float = 100.0 # Maximum number value
|
min_value: float = -100.0 # Minimum number value
|
||||||
|
max_value: float = 100.0 # Maximum number value
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -96,15 +99,17 @@ class NumberFilteringDataset:
|
||||||
result_strs = [str_numbers[numbers.index(n)] for n in result]
|
result_strs = [str_numbers[numbers.index(n)] for n in result]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": (f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
|
"question": (
|
||||||
f"in this list: {str_numbers}"),
|
f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
|
||||||
|
f"in this list: {str_numbers}"
|
||||||
|
),
|
||||||
"answer": str(result_strs) if result_strs else "[]",
|
"answer": str(result_strs) if result_strs else "[]",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"original_numbers": str_numbers,
|
"original_numbers": str_numbers,
|
||||||
"filter_value": filter_str,
|
"filter_value": filter_str,
|
||||||
"operation": f"{keep_remove}_{larger_smaller}",
|
"operation": f"{keep_remove}_{larger_smaller}",
|
||||||
"result": result_strs
|
"result": result_strs,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,23 @@
|
||||||
"""Number sorting task generator"""
|
"""Number sorting task generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NumberSortingConfig:
|
class NumberSortingConfig:
|
||||||
"""Configuration for number sorting task generation"""
|
"""Configuration for number sorting task generation"""
|
||||||
min_numbers: int = 3 # Minimum numbers to sort
|
|
||||||
max_numbers: int = 10 # Maximum numbers to sort
|
min_numbers: int = 3 # Minimum numbers to sort
|
||||||
min_decimals: int = 0 # Minimum decimal places
|
max_numbers: int = 10 # Maximum numbers to sort
|
||||||
max_decimals: int = 2 # Maximum decimal places
|
min_decimals: int = 0 # Minimum decimal places
|
||||||
|
max_decimals: int = 2 # Maximum decimal places
|
||||||
min_value: float = -100.0 # Minimum value
|
min_value: float = -100.0 # Minimum value
|
||||||
max_value: float = 100.0 # Maximum value
|
max_value: float = 100.0 # Maximum value
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -82,7 +85,7 @@ class NumberSortingDataset:
|
||||||
desc_numbers = sorted(numbers, reverse=True)
|
desc_numbers = sorted(numbers, reverse=True)
|
||||||
|
|
||||||
# Format answers as string lists
|
# Format answers as string lists
|
||||||
decimals = len(number_strs[0].split('.')[-1]) if '.' in number_strs[0] else 0
|
decimals = len(number_strs[0].split(".")[-1]) if "." in number_strs[0] else 0
|
||||||
asc_answer = [self._format_number(n, decimals) for n in asc_numbers]
|
asc_answer = [self._format_number(n, decimals) for n in asc_numbers]
|
||||||
desc_answer = [self._format_number(n, decimals) for n in desc_numbers]
|
desc_answer = [self._format_number(n, decimals) for n in desc_numbers]
|
||||||
|
|
||||||
|
|
@ -94,11 +97,7 @@ class NumberSortingDataset:
|
||||||
return {
|
return {
|
||||||
"question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
|
"question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
"metadata": {
|
"metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
|
||||||
"original_numbers": number_strs,
|
|
||||||
"direction": direction,
|
|
||||||
"sorted_numbers": answer
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,21 @@
|
||||||
"""Word reversal task generator"""
|
"""Word reversal task generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from reasoning_gym.data import read_data_file
|
from reasoning_gym.data import read_data_file
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WordReversalConfig:
|
class WordReversalConfig:
|
||||||
"""Configuration for word reversal task generation"""
|
"""Configuration for word reversal task generation"""
|
||||||
min_words: int = 3 # Minimum words in list
|
|
||||||
max_words: int = 8 # Maximum words in list
|
min_words: int = 3 # Minimum words in list
|
||||||
|
max_words: int = 8 # Maximum words in list
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -31,7 +34,7 @@ class WordReversalDataset:
|
||||||
# Load and preprocess text
|
# Load and preprocess text
|
||||||
text = read_data_file("in_the_year_2889.txt")
|
text = read_data_file("in_the_year_2889.txt")
|
||||||
# Extract words and clean them to contain only alphanumeric characters
|
# Extract words and clean them to contain only alphanumeric characters
|
||||||
self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()]
|
self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.config.size
|
return self.config.size
|
||||||
|
|
@ -63,10 +66,7 @@ class WordReversalDataset:
|
||||||
return {
|
return {
|
||||||
"question": f"Reverse this list of words: {question}",
|
"question": f"Reverse this list of words: {question}",
|
||||||
"answer": answer,
|
"answer": answer,
|
||||||
"metadata": {
|
"metadata": {"num_words": num_words, "words": words},
|
||||||
"num_words": num_words,
|
|
||||||
"words": words
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,11 @@ Arithmetic tasks for training reasoning capabilities:
|
||||||
|
|
||||||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
|
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
|
||||||
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
|
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
|
||||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset
|
from .fraction_simplification import (
|
||||||
|
FractionSimplificationConfig,
|
||||||
|
FractionSimplificationDataset,
|
||||||
|
fraction_simplification_dataset,
|
||||||
|
)
|
||||||
from .gcd import GCDConfig, GCDDataset, gcd_dataset
|
from .gcd import GCDConfig, GCDDataset, gcd_dataset
|
||||||
from .lcm import LCMConfig, LCMDataset, lcm_dataset
|
from .lcm import LCMConfig, LCMDataset, lcm_dataset
|
||||||
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
|
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
|
||||||
|
|
@ -35,5 +39,5 @@ __all__ = [
|
||||||
"leg_counting_dataset",
|
"leg_counting_dataset",
|
||||||
"PrimeFactorizationConfig",
|
"PrimeFactorizationConfig",
|
||||||
"PrimeFactorizationDataset",
|
"PrimeFactorizationDataset",
|
||||||
"prime_factorization_dataset"
|
"prime_factorization_dataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -145,7 +146,6 @@ class BasicArithmeticDataset(ProceduralDataset):
|
||||||
expression = " ".join(expression_parts)
|
expression = " ".join(expression_parts)
|
||||||
return expression, result
|
return expression, result
|
||||||
|
|
||||||
|
|
||||||
def _format_question(self, rng: Random, expression: str) -> str:
|
def _format_question(self, rng: Random, expression: str) -> str:
|
||||||
"""Format the expression according to config style"""
|
"""Format the expression according to config style"""
|
||||||
if self.config.format_style == "simple":
|
if self.config.format_style == "simple":
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -70,7 +71,6 @@ class ChainSum(ProceduralDataset):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
|
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
|
||||||
"""Generate a chain sum task
|
"""Generate a chain sum task
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,24 @@
|
||||||
"""Fraction simplification task generator"""
|
"""Fraction simplification task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
|
||||||
from typing import Optional, Tuple, Sequence
|
|
||||||
from ..dataset import ProceduralDataset
|
|
||||||
from math import gcd
|
from math import gcd
|
||||||
|
from random import Random
|
||||||
|
from typing import Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FractionSimplificationConfig:
|
class FractionSimplificationConfig:
|
||||||
"""Configuration for fraction simplification task generation"""
|
"""Configuration for fraction simplification task generation"""
|
||||||
min_value: int = 1 # Minimum value for numerator/denominator
|
|
||||||
max_value: int = 1000 # Maximum value for numerator/denominator
|
min_value: int = 1 # Minimum value for numerator/denominator
|
||||||
min_factor: int = 1 # Minimum multiplication factor
|
max_value: int = 1000 # Maximum value for numerator/denominator
|
||||||
max_factor: int = 100 # Maximum multiplication factor
|
min_factor: int = 1 # Minimum multiplication factor
|
||||||
|
max_factor: int = 100 # Maximum multiplication factor
|
||||||
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
|
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -53,8 +56,10 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
simplified_den //= common
|
simplified_den //= common
|
||||||
|
|
||||||
# Check if simplified fraction is within bounds
|
# Check if simplified fraction is within bounds
|
||||||
if (self.config.min_value <= simplified_num <= self.config.max_value and
|
if (
|
||||||
self.config.min_value <= simplified_den <= self.config.max_value):
|
self.config.min_value <= simplified_num <= self.config.max_value
|
||||||
|
and self.config.min_value <= simplified_den <= self.config.max_value
|
||||||
|
):
|
||||||
# Ensure numerator is smaller than denominator
|
# Ensure numerator is smaller than denominator
|
||||||
if simplified_num > simplified_den:
|
if simplified_num > simplified_den:
|
||||||
simplified_num, simplified_den = simplified_den, simplified_num
|
simplified_num, simplified_den = simplified_den, simplified_num
|
||||||
|
|
@ -75,8 +80,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
simplified_num, simplified_den = simplified_den, simplified_num
|
simplified_num, simplified_den = simplified_den, simplified_num
|
||||||
|
|
||||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||||
return (simplified_num * factor, simplified_den * factor,
|
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
|
||||||
simplified_num, simplified_den)
|
|
||||||
|
|
||||||
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
||||||
"""Format a fraction in various styles"""
|
"""Format a fraction in various styles"""
|
||||||
|
|
@ -99,7 +103,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
||||||
|
|
||||||
# Choose a random style from configured styles
|
# Choose a random style from configured styles
|
||||||
style = self.config.styles[rng.randint(0, len(self.config.styles)-1)]
|
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
|
||||||
|
|
||||||
# Format both question and answer in the same style
|
# Format both question and answer in the same style
|
||||||
question_fraction = self._format_fraction(num, den, style)
|
question_fraction = self._format_fraction(num, den, style)
|
||||||
|
|
@ -114,8 +118,8 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
"simplified_numerator": simple_num,
|
"simplified_numerator": simple_num,
|
||||||
"simplified_denominator": simple_den,
|
"simplified_denominator": simple_den,
|
||||||
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
||||||
"style": style
|
"style": style,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,24 @@
|
||||||
"""Greatest Common Divisor (GCD) task generator"""
|
"""Greatest Common Divisor (GCD) task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import reduce
|
||||||
|
from math import gcd
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
from math import gcd
|
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GCDConfig:
|
class GCDConfig:
|
||||||
"""Configuration for GCD task generation"""
|
"""Configuration for GCD task generation"""
|
||||||
min_numbers: int = 2 # Minimum numbers to find GCD of
|
|
||||||
max_numbers: int = 2 # Maximum numbers to find GCD of
|
min_numbers: int = 2 # Minimum numbers to find GCD of
|
||||||
min_value: int = 1 # Minimum value for each number
|
max_numbers: int = 2 # Maximum numbers to find GCD of
|
||||||
max_value: int = 1000 # Maximum value for each number
|
min_value: int = 1 # Minimum value for each number
|
||||||
|
max_value: int = 1000 # Maximum value for each number
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -38,16 +41,14 @@ class GCDDataset(ProceduralDataset):
|
||||||
Will try up to 3 times to find numbers with GCD > 1."""
|
Will try up to 3 times to find numbers with GCD > 1."""
|
||||||
for _ in range(3): # Try up to 3 times to get GCD > 1
|
for _ in range(3): # Try up to 3 times to get GCD > 1
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||||
for _ in range(num_count)]
|
|
||||||
result = reduce(gcd, numbers)
|
result = reduce(gcd, numbers)
|
||||||
if result > 1:
|
if result > 1:
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
# If we failed to find GCD > 1 after 3 tries, generate one final set
|
# If we failed to find GCD > 1 after 3 tries, generate one final set
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||||
for _ in range(num_count)]
|
|
||||||
result = reduce(gcd, numbers)
|
result = reduce(gcd, numbers)
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
|
|
@ -61,10 +62,7 @@ class GCDDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
||||||
"answer": str(result),
|
"answer": str(result),
|
||||||
"metadata": {
|
"metadata": {"numbers": numbers, "result": result},
|
||||||
"numbers": numbers,
|
|
||||||
"result": result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,24 @@
|
||||||
"""Least Common Multiple (LCM) task generator"""
|
"""Least Common Multiple (LCM) task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import reduce
|
||||||
|
from math import lcm
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
from math import lcm
|
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LCMConfig:
|
class LCMConfig:
|
||||||
"""Configuration for LCM task generation"""
|
"""Configuration for LCM task generation"""
|
||||||
min_numbers: int = 2 # Minimum numbers to find LCM of
|
|
||||||
max_numbers: int = 2 # Maximum numbers to find LCM of
|
min_numbers: int = 2 # Minimum numbers to find LCM of
|
||||||
min_value: int = 1 # Minimum value for each number
|
max_numbers: int = 2 # Maximum numbers to find LCM of
|
||||||
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
|
min_value: int = 1 # Minimum value for each number
|
||||||
|
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -36,21 +39,20 @@ class LCMDataset(ProceduralDataset):
|
||||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||||
"""Generate a list of random positive integers and their LCM.
|
"""Generate a list of random positive integers and their LCM.
|
||||||
Will try up to 3 times to find numbers with LCM < product."""
|
Will try up to 3 times to find numbers with LCM < product."""
|
||||||
|
|
||||||
def calculate_product(nums: List[int]) -> int:
|
def calculate_product(nums: List[int]) -> int:
|
||||||
return reduce(lambda x, y: x * y, nums)
|
return reduce(lambda x, y: x * y, nums)
|
||||||
|
|
||||||
for _ in range(3): # Try up to 3 times to get LCM < product
|
for _ in range(3): # Try up to 3 times to get LCM < product
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||||
for _ in range(num_count)]
|
|
||||||
result = reduce(lcm, numbers)
|
result = reduce(lcm, numbers)
|
||||||
if result < calculate_product(numbers):
|
if result < calculate_product(numbers):
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
# If we failed to find LCM < product after 3 tries, generate one final set
|
# If we failed to find LCM < product after 3 tries, generate one final set
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||||
for _ in range(num_count)]
|
|
||||||
result = reduce(lcm, numbers)
|
result = reduce(lcm, numbers)
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
|
|
@ -64,10 +66,7 @@ class LCMDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
|
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
|
||||||
"answer": str(result),
|
"answer": str(result),
|
||||||
"metadata": {
|
"metadata": {"numbers": numbers, "result": result},
|
||||||
"numbers": numbers,
|
|
||||||
"result": result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
"""Leg counting task generator"""
|
"""Leg counting task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
ANIMALS = {
|
ANIMALS = {
|
||||||
|
|
@ -52,14 +54,16 @@ ANIMALS = {
|
||||||
"woodlouse": 14,
|
"woodlouse": 14,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LegCountingConfig:
|
class LegCountingConfig:
|
||||||
"""Configuration for leg counting task generation"""
|
"""Configuration for leg counting task generation"""
|
||||||
min_animals: int = 2 # Minimum number of animals in problem
|
|
||||||
max_animals: int = 5 # Maximum number of animals
|
min_animals: int = 2 # Minimum number of animals in problem
|
||||||
max_instances: int = 3 # Maximum instances of each animal
|
max_animals: int = 5 # Maximum number of animals
|
||||||
|
max_instances: int = 3 # Maximum instances of each animal
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -109,10 +113,7 @@ class LegCountingDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": question,
|
"question": question,
|
||||||
"answer": str(total_legs),
|
"answer": str(total_legs),
|
||||||
"metadata": {
|
"metadata": {"animals": animals, "total_legs": total_legs},
|
||||||
"animals": animals,
|
|
||||||
"total_legs": total_legs
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,20 @@
|
||||||
"""Prime factorization task generator"""
|
"""Prime factorization task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PrimeFactorizationConfig:
|
class PrimeFactorizationConfig:
|
||||||
"""Configuration for prime factorization task generation"""
|
"""Configuration for prime factorization task generation"""
|
||||||
min_value: int = 2 # Minimum number to factorize
|
|
||||||
max_value: int = 1000 # Maximum number to factorize
|
min_value: int = 2 # Minimum number to factorize
|
||||||
|
max_value: int = 1000 # Maximum number to factorize
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -55,13 +59,12 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
||||||
answer = " × ".join(map(str, factors))
|
answer = " × ".join(map(str, factors))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": (f"Find the prime factorization of {number}. Write the factors separated by × "
|
"question": (
|
||||||
f"(Example: for 12 the answer would be: 2 × 2 × 3)"),
|
f"Find the prime factorization of {number}. Write the factors separated by × "
|
||||||
|
f"(Example: for 12 the answer would be: 2 × 2 × 3)"
|
||||||
|
),
|
||||||
"answer": answer,
|
"answer": answer,
|
||||||
"metadata": {
|
"metadata": {"number": number, "factors": factors},
|
||||||
"number": number,
|
|
||||||
"factors": factors
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ class SequenceConfig:
|
||||||
class PatternRule:
|
class PatternRule:
|
||||||
"""Represents a composable sequence pattern rule"""
|
"""Represents a composable sequence pattern rule"""
|
||||||
|
|
||||||
def __init__(self, operations: List[Operation], parameters: List[int], subrules: List['PatternRule'] = None):
|
def __init__(self, operations: List[Operation], parameters: List[int], subrules: List["PatternRule"] = None):
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
self.parameters = parameters
|
self.parameters = parameters
|
||||||
self.subrules = subrules or []
|
self.subrules = subrules or []
|
||||||
|
|
@ -66,14 +66,14 @@ class PatternRule:
|
||||||
elif op == Operation.COMPOSE:
|
elif op == Operation.COMPOSE:
|
||||||
# Apply each subrule in sequence, passing the result through
|
# Apply each subrule in sequence, passing the result through
|
||||||
for subrule in self.subrules:
|
for subrule in self.subrules:
|
||||||
temp_sequence = sequence[:position + 1]
|
temp_sequence = sequence[: position + 1]
|
||||||
temp_sequence[-1] = result # Use current result as input
|
temp_sequence[-1] = result # Use current result as input
|
||||||
result = subrule.apply(temp_sequence, position)
|
result = subrule.apply(temp_sequence, position)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compose(cls, rules: List['PatternRule']) -> 'PatternRule':
|
def compose(cls, rules: List["PatternRule"]) -> "PatternRule":
|
||||||
"""Create a new rule that composes multiple rules together"""
|
"""Create a new rule that composes multiple rules together"""
|
||||||
return cls([Operation.COMPOSE], [0], subrules=rules)
|
return cls([Operation.COMPOSE], [0], subrules=rules)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from importlib import resources
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
def get_data_file_path(filename: str) -> Path:
|
def get_data_file_path(filename: str) -> Path:
|
||||||
"""Get the path to a data file in the package.
|
"""Get the path to a data file in the package.
|
||||||
|
|
||||||
|
|
@ -18,7 +19,8 @@ def get_data_file_path(filename: str) -> Path:
|
||||||
>>> with open(path) as f:
|
>>> with open(path) as f:
|
||||||
... content = f.read()
|
... content = f.read()
|
||||||
"""
|
"""
|
||||||
return resources.files('reasoning_gym.data').joinpath(filename)
|
return resources.files("reasoning_gym.data").joinpath(filename)
|
||||||
|
|
||||||
|
|
||||||
def read_data_file(filename: str) -> str:
|
def read_data_file(filename: str) -> str:
|
||||||
"""Read the contents of a data file in the package.
|
"""Read the contents of a data file in the package.
|
||||||
|
|
@ -32,6 +34,7 @@ def read_data_file(filename: str) -> str:
|
||||||
Example:
|
Example:
|
||||||
>>> content = read_data_file("pg19362.txt")
|
>>> content = read_data_file("pg19362.txt")
|
||||||
"""
|
"""
|
||||||
return resources.files('reasoning_gym.data').joinpath(filename).read_text()
|
return resources.files("reasoning_gym.data").joinpath(filename).read_text()
|
||||||
|
|
||||||
__all__ = ['get_data_file_path', 'read_data_file']
|
|
||||||
|
__all__ = ["get_data_file_path", "read_data_file"]
|
||||||
|
|
|
||||||
|
|
@ -1048,5 +1048,3 @@ This website includes information about Project Gutenberg™,
|
||||||
including how to make donations to the Project Gutenberg Literary
|
including how to make donations to the Project Gutenberg Literary
|
||||||
Archive Foundation, how to help produce our new eBooks, and how to
|
Archive Foundation, how to help produce our new eBooks, and how to
|
||||||
subscribe to our email newsletter to hear about new eBooks.
|
subscribe to our email newsletter to hear about new eBooks.
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
"""Base class for procedural dataset generators"""
|
"""Base class for procedural dataset generators"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sized, Iterable
|
from collections.abc import Iterable, Sized
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional, Iterator, Dict, Any
|
from typing import Any, Dict, Iterator, Optional
|
||||||
|
|
||||||
|
|
||||||
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
||||||
|
|
|
||||||
|
|
@ -14,5 +14,5 @@ __all__ = [
|
||||||
"mini_sudoku_dataset",
|
"mini_sudoku_dataset",
|
||||||
"SudokuConfig",
|
"SudokuConfig",
|
||||||
"SudokuDataset",
|
"SudokuDataset",
|
||||||
"sudoku_dataset"
|
"sudoku_dataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,19 @@
|
||||||
"""Mini Sudoku (4x4) puzzle generator"""
|
"""Mini Sudoku (4x4) puzzle generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MiniSudokuConfig:
|
class MiniSudokuConfig:
|
||||||
"""Configuration for 4x4 sudoku puzzle generation"""
|
"""Configuration for 4x4 sudoku puzzle generation"""
|
||||||
min_empty: int = 8 # Minimum number of empty cells
|
|
||||||
max_empty: int = 12 # Maximum number of empty cells
|
min_empty: int = 8 # Minimum number of empty cells
|
||||||
|
max_empty: int = 12 # Maximum number of empty cells
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -142,11 +145,7 @@ class MiniSudokuDataset:
|
||||||
return {
|
return {
|
||||||
"question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}",
|
"question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}",
|
||||||
"answer": solution_str,
|
"answer": solution_str,
|
||||||
"metadata": {
|
"metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
|
||||||
"puzzle": puzzle,
|
|
||||||
"solution": solved_board,
|
|
||||||
"num_empty": num_empty
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,19 @@
|
||||||
"""Sudoku puzzle generator"""
|
"""Sudoku puzzle generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SudokuConfig:
|
class SudokuConfig:
|
||||||
"""Configuration for sudoku puzzle generation"""
|
"""Configuration for sudoku puzzle generation"""
|
||||||
min_empty: int = 30 # Minimum number of empty cells
|
|
||||||
max_empty: int = 50 # Maximum number of empty cells
|
min_empty: int = 30 # Minimum number of empty cells
|
||||||
|
max_empty: int = 50 # Maximum number of empty cells
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -132,11 +135,7 @@ class SudokuDataset:
|
||||||
return {
|
return {
|
||||||
"question": f"Solve this Sudoku puzzle:\n{puzzle_str}",
|
"question": f"Solve this Sudoku puzzle:\n{puzzle_str}",
|
||||||
"answer": solution_str,
|
"answer": solution_str,
|
||||||
"metadata": {
|
"metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
|
||||||
"puzzle": puzzle,
|
|
||||||
"solution": solved_board,
|
|
||||||
"num_empty": num_empty
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import pytest
|
|
||||||
from random import Random
|
from random import Random
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,14 +32,7 @@ def test_arithmetic_dataset_deterministic():
|
||||||
|
|
||||||
def test_arithmetic_dataset_items():
|
def test_arithmetic_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = BasicArithmeticDatasetConfig(
|
config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
|
||||||
min_terms=2,
|
|
||||||
max_terms=4,
|
|
||||||
min_digits=1,
|
|
||||||
max_digits=2,
|
|
||||||
size=100,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -62,7 +57,7 @@ def test_arithmetic_dataset_format_styles():
|
||||||
min_terms=2,
|
min_terms=2,
|
||||||
max_terms=3, # Keep expressions simple for testing
|
max_terms=3, # Keep expressions simple for testing
|
||||||
min_digits=1,
|
min_digits=1,
|
||||||
max_digits=2
|
max_digits=2,
|
||||||
)
|
)
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
assert all(item["question"].endswith("=") for item in dataset)
|
assert all(item["question"].endswith("=") for item in dataset)
|
||||||
|
|
@ -74,12 +69,7 @@ def test_arithmetic_dataset_format_styles():
|
||||||
|
|
||||||
def test_arithmetic_dataset_iteration():
|
def test_arithmetic_dataset_iteration():
|
||||||
"""Test that iteration respects dataset size"""
|
"""Test that iteration respects dataset size"""
|
||||||
config = BasicArithmeticDatasetConfig(
|
config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
|
||||||
min_terms=2,
|
|
||||||
max_terms=2,
|
|
||||||
size=5, # Small size for testing
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
|
|
||||||
# Test manual iteration
|
# Test manual iteration
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for base conversion task generation"""
|
"""Tests for base conversion task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.base_conversion import (
|
from reasoning_gym.algorithmic.base_conversion import BaseConversionConfig, BaseConversionDataset
|
||||||
BaseConversionConfig,
|
|
||||||
BaseConversionDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_base_conversion_config_validation():
|
def test_base_conversion_config_validation():
|
||||||
|
|
@ -38,14 +36,7 @@ def test_base_conversion_dataset_deterministic():
|
||||||
|
|
||||||
def test_base_conversion_dataset_items():
|
def test_base_conversion_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = BaseConversionConfig(
|
config = BaseConversionConfig(min_base=2, max_base=16, min_value=0, max_value=1000, size=10, seed=42)
|
||||||
min_base=2,
|
|
||||||
max_base=16,
|
|
||||||
min_value=0,
|
|
||||||
max_value=1000,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = BaseConversionDataset(config)
|
dataset = BaseConversionDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -74,9 +65,9 @@ def test_base_conversion_dataset_items():
|
||||||
# Verify conversion correctness
|
# Verify conversion correctness
|
||||||
decimal_value = item["metadata"]["decimal_value"]
|
decimal_value = item["metadata"]["decimal_value"]
|
||||||
target_base = item["metadata"]["target_base"]
|
target_base = item["metadata"]["target_base"]
|
||||||
expected = format(decimal_value, 'x' if target_base == 16 else 'b' if target_base == 2 else '').strip()
|
expected = format(decimal_value, "x" if target_base == 16 else "b" if target_base == 2 else "").strip()
|
||||||
if target_base not in (2, 16):
|
if target_base not in (2, 16):
|
||||||
expected = format(decimal_value, f'{target_base}x').lower().strip()
|
expected = format(decimal_value, f"{target_base}x").lower().strip()
|
||||||
assert item["answer"] == expected
|
assert item["answer"] == expected
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -100,7 +91,7 @@ def test_base_conversion_special_bases():
|
||||||
min_value=0,
|
min_value=0,
|
||||||
max_value=255, # Use small range for predictable results
|
max_value=255, # Use small range for predictable results
|
||||||
size=100,
|
size=100,
|
||||||
seed=42
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = BaseConversionDataset(config)
|
dataset = BaseConversionDataset(config)
|
||||||
|
|
||||||
|
|
@ -112,11 +103,11 @@ def test_base_conversion_special_bases():
|
||||||
if item["metadata"]["target_base"] == 2:
|
if item["metadata"]["target_base"] == 2:
|
||||||
binary_found = True
|
binary_found = True
|
||||||
# Verify binary format
|
# Verify binary format
|
||||||
assert all(c in '01' for c in item["answer"])
|
assert all(c in "01" for c in item["answer"])
|
||||||
elif item["metadata"]["target_base"] == 16:
|
elif item["metadata"]["target_base"] == 16:
|
||||||
hex_found = True
|
hex_found = True
|
||||||
# Verify hex format
|
# Verify hex format
|
||||||
assert all(c in '0123456789abcdef' for c in item["answer"])
|
assert all(c in "0123456789abcdef" for c in item["answer"])
|
||||||
|
|
||||||
assert binary_found, "No binary conversion tasks generated"
|
assert binary_found, "No binary conversion tasks generated"
|
||||||
assert hex_found, "No hexadecimal conversion tasks generated"
|
assert hex_found, "No hexadecimal conversion tasks generated"
|
||||||
|
|
@ -130,7 +121,7 @@ def test_base_conversion_formatting():
|
||||||
min_value=10, # Ensure multi-digit numbers
|
min_value=10, # Ensure multi-digit numbers
|
||||||
max_value=1000,
|
max_value=1000,
|
||||||
size=10,
|
size=10,
|
||||||
seed=42
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = BaseConversionDataset(config)
|
dataset = BaseConversionDataset(config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
|
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -25,14 +26,7 @@ def test_chain_sum_deterministic():
|
||||||
|
|
||||||
def test_chain_sum_items():
|
def test_chain_sum_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
|
||||||
min_terms=2,
|
|
||||||
max_terms=4,
|
|
||||||
min_digits=1,
|
|
||||||
max_digits=2,
|
|
||||||
size=100,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -60,7 +54,7 @@ def test_chain_sum_number_ranges():
|
||||||
min_digits=3, # Should generate numbers >= 100
|
min_digits=3, # Should generate numbers >= 100
|
||||||
max_digits=3, # Should generate numbers <= 999
|
max_digits=3, # Should generate numbers <= 999
|
||||||
size=50,
|
size=50,
|
||||||
seed=42
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
|
|
||||||
|
|
@ -74,16 +68,8 @@ def test_chain_sum_number_ranges():
|
||||||
else:
|
else:
|
||||||
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
||||||
|
|
||||||
|
|
||||||
# Test 1-digit numbers
|
# Test 1-digit numbers
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
|
||||||
min_terms=2,
|
|
||||||
max_terms=2,
|
|
||||||
min_digits=1,
|
|
||||||
max_digits=1,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
|
|
@ -95,16 +81,11 @@ def test_chain_sum_number_ranges():
|
||||||
else:
|
else:
|
||||||
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
|
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
|
||||||
|
|
||||||
|
|
||||||
def test_chain_sum_negation():
|
def test_chain_sum_negation():
|
||||||
"""Test that allow_negation controls number ranges"""
|
"""Test that allow_negation controls number ranges"""
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(
|
||||||
min_terms=2,
|
min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True
|
||||||
max_terms=2,
|
|
||||||
min_digits=2,
|
|
||||||
max_digits=2,
|
|
||||||
size=100,
|
|
||||||
seed=42,
|
|
||||||
allow_negation=True
|
|
||||||
)
|
)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
|
|
||||||
|
|
@ -115,7 +96,7 @@ def test_chain_sum_negation():
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
expression = item["metadata"]["expression"]
|
expression = item["metadata"]["expression"]
|
||||||
numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith('-') and n[1:].isdigit())]
|
numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith("-") and n[1:].isdigit())]
|
||||||
|
|
||||||
for num in numbers:
|
for num in numbers:
|
||||||
if num > 0:
|
if num > 0:
|
||||||
|
|
@ -129,12 +110,7 @@ def test_chain_sum_negation():
|
||||||
|
|
||||||
def test_chain_sum_iteration():
|
def test_chain_sum_iteration():
|
||||||
"""Test that iteration respects dataset size"""
|
"""Test that iteration respects dataset size"""
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
|
||||||
min_terms=2,
|
|
||||||
max_terms=2,
|
|
||||||
size=5, # Small size for testing
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
|
|
||||||
# Test manual iteration
|
# Test manual iteration
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import pytest
|
|
||||||
from math import gcd
|
from math import gcd
|
||||||
from reasoning_gym.arithmetic import FractionSimplificationDataset, FractionSimplificationConfig
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset
|
||||||
|
|
||||||
|
|
||||||
def test_fraction_config_validation():
|
def test_fraction_config_validation():
|
||||||
|
|
@ -34,14 +36,7 @@ def test_fraction_deterministic():
|
||||||
|
|
||||||
def test_fraction_items():
|
def test_fraction_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = FractionSimplificationConfig(
|
config = FractionSimplificationConfig(min_value=1, max_value=20, min_factor=2, max_factor=5, size=50, seed=42)
|
||||||
min_value=1,
|
|
||||||
max_value=20,
|
|
||||||
min_factor=2,
|
|
||||||
max_factor=5,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = FractionSimplificationDataset(config)
|
dataset = FractionSimplificationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -79,14 +74,7 @@ def test_fraction_items():
|
||||||
|
|
||||||
def test_fraction_ranges():
|
def test_fraction_ranges():
|
||||||
"""Test that generated numbers respect value constraints"""
|
"""Test that generated numbers respect value constraints"""
|
||||||
config = FractionSimplificationConfig(
|
config = FractionSimplificationConfig(min_value=5, max_value=15, min_factor=3, max_factor=4, size=20, seed=42)
|
||||||
min_value=5,
|
|
||||||
max_value=15,
|
|
||||||
min_factor=3,
|
|
||||||
max_factor=4,
|
|
||||||
size=20,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = FractionSimplificationDataset(config)
|
dataset = FractionSimplificationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -125,14 +113,7 @@ def test_fraction_iteration():
|
||||||
|
|
||||||
def test_fraction_numerator_smaller():
|
def test_fraction_numerator_smaller():
|
||||||
"""Test that numerators are always smaller than denominators"""
|
"""Test that numerators are always smaller than denominators"""
|
||||||
config = FractionSimplificationConfig(
|
config = FractionSimplificationConfig(min_value=1, max_value=100, min_factor=2, max_factor=5, size=50, seed=42)
|
||||||
min_value=1,
|
|
||||||
max_value=100,
|
|
||||||
min_factor=2,
|
|
||||||
max_factor=5,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = FractionSimplificationDataset(config)
|
dataset = FractionSimplificationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -140,9 +121,11 @@ def test_fraction_numerator_smaller():
|
||||||
metadata = item["metadata"]
|
metadata = item["metadata"]
|
||||||
|
|
||||||
# Check original fraction
|
# Check original fraction
|
||||||
assert metadata["numerator"] <= metadata["denominator"], \
|
assert (
|
||||||
f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}"
|
metadata["numerator"] <= metadata["denominator"]
|
||||||
|
), f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}"
|
||||||
|
|
||||||
# Check simplified fraction
|
# Check simplified fraction
|
||||||
assert metadata["simplified_numerator"] <= metadata["simplified_denominator"], \
|
assert (
|
||||||
f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"
|
metadata["simplified_numerator"] <= metadata["simplified_denominator"]
|
||||||
|
), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import pytest
|
|
||||||
from math import gcd
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from reasoning_gym.arithmetic import GCDDataset, GCDConfig
|
from math import gcd
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic import GCDConfig, GCDDataset
|
||||||
|
|
||||||
|
|
||||||
def test_gcd_config_validation():
|
def test_gcd_config_validation():
|
||||||
|
|
@ -35,14 +37,7 @@ def test_gcd_deterministic():
|
||||||
|
|
||||||
def test_gcd_items():
|
def test_gcd_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = GCDConfig(
|
config = GCDConfig(min_numbers=2, max_numbers=4, min_value=1, max_value=100, size=50, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=4,
|
|
||||||
min_value=1,
|
|
||||||
max_value=100,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = GCDDataset(config)
|
dataset = GCDDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -70,14 +65,7 @@ def test_gcd_items():
|
||||||
|
|
||||||
def test_gcd_number_ranges():
|
def test_gcd_number_ranges():
|
||||||
"""Test that generated numbers respect value constraints"""
|
"""Test that generated numbers respect value constraints"""
|
||||||
config = GCDConfig(
|
config = GCDConfig(min_numbers=2, max_numbers=2, min_value=50, max_value=100, size=20, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=2,
|
|
||||||
min_value=50,
|
|
||||||
max_value=100,
|
|
||||||
size=20,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = GCDDataset(config)
|
dataset = GCDDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -109,14 +97,7 @@ def test_gcd_iteration():
|
||||||
|
|
||||||
def test_gcd_special_cases():
|
def test_gcd_special_cases():
|
||||||
"""Test some special GCD cases"""
|
"""Test some special GCD cases"""
|
||||||
config = GCDConfig(
|
config = GCDConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=100, size=100, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=2,
|
|
||||||
min_value=1,
|
|
||||||
max_value=100,
|
|
||||||
size=100,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = GCDDataset(config)
|
dataset = GCDDataset(config)
|
||||||
|
|
||||||
# Track if we see some interesting GCD cases
|
# Track if we see some interesting GCD cases
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import pytest
|
|
||||||
from math import lcm
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from reasoning_gym.arithmetic import LCMDataset, LCMConfig
|
from math import lcm
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic import LCMConfig, LCMDataset
|
||||||
|
|
||||||
|
|
||||||
def test_lcm_config_validation():
|
def test_lcm_config_validation():
|
||||||
|
|
@ -36,12 +38,7 @@ def test_lcm_deterministic():
|
||||||
def test_lcm_items():
|
def test_lcm_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = LCMConfig(
|
config = LCMConfig(
|
||||||
min_numbers=2,
|
min_numbers=2, max_numbers=4, min_value=1, max_value=20, size=50, seed=42 # Keep small for testing
|
||||||
max_numbers=4,
|
|
||||||
min_value=1,
|
|
||||||
max_value=20, # Keep small for testing
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
)
|
||||||
dataset = LCMDataset(config)
|
dataset = LCMDataset(config)
|
||||||
|
|
||||||
|
|
@ -70,14 +67,7 @@ def test_lcm_items():
|
||||||
|
|
||||||
def test_lcm_number_ranges():
|
def test_lcm_number_ranges():
|
||||||
"""Test that generated numbers respect value constraints"""
|
"""Test that generated numbers respect value constraints"""
|
||||||
config = LCMConfig(
|
config = LCMConfig(min_numbers=2, max_numbers=2, min_value=5, max_value=15, size=20, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=2,
|
|
||||||
min_value=5,
|
|
||||||
max_value=15,
|
|
||||||
size=20,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = LCMDataset(config)
|
dataset = LCMDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -109,14 +99,7 @@ def test_lcm_iteration():
|
||||||
|
|
||||||
def test_lcm_special_cases():
|
def test_lcm_special_cases():
|
||||||
"""Test some special LCM cases"""
|
"""Test some special LCM cases"""
|
||||||
config = LCMConfig(
|
config = LCMConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=20, size=100, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=2,
|
|
||||||
min_value=1,
|
|
||||||
max_value=20,
|
|
||||||
size=100,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = LCMDataset(config)
|
dataset = LCMDataset(config)
|
||||||
|
|
||||||
# Track if we see some interesting LCM cases
|
# Track if we see some interesting LCM cases
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,8 @@
|
||||||
"""Tests for leg counting task generation"""
|
"""Tests for leg counting task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.leg_counting import (
|
from reasoning_gym.arithmetic.leg_counting import ANIMALS, LegCountingConfig, LegCountingDataset
|
||||||
LegCountingConfig,
|
|
||||||
LegCountingDataset,
|
|
||||||
ANIMALS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_leg_counting_config_validation():
|
def test_leg_counting_config_validation():
|
||||||
|
|
@ -35,13 +32,7 @@ def test_leg_counting_dataset_deterministic():
|
||||||
|
|
||||||
def test_leg_counting_dataset_items():
|
def test_leg_counting_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = LegCountingConfig(
|
config = LegCountingConfig(min_animals=2, max_animals=4, max_instances=2, size=10, seed=42)
|
||||||
min_animals=2,
|
|
||||||
max_animals=4,
|
|
||||||
max_instances=2,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = LegCountingDataset(config)
|
dataset = LegCountingDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for letter counting task generation"""
|
"""Tests for letter counting task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.letter_counting import (
|
from reasoning_gym.algorithmic.letter_counting import LetterCountingConfig, LetterCountingDataset
|
||||||
LetterCountingConfig,
|
|
||||||
LetterCountingDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_letter_counting_config_validation():
|
def test_letter_counting_config_validation():
|
||||||
|
|
@ -30,12 +28,7 @@ def test_letter_counting_dataset_deterministic():
|
||||||
|
|
||||||
def test_letter_counting_dataset_items():
|
def test_letter_counting_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = LetterCountingConfig(
|
config = LetterCountingConfig(min_words=3, max_words=6, size=10, seed=42)
|
||||||
min_words=3,
|
|
||||||
max_words=6,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = LetterCountingDataset(config)
|
dataset = LetterCountingDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for mini sudoku puzzle generation"""
|
"""Tests for mini sudoku puzzle generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.games.mini_sudoku import (
|
from reasoning_gym.games.mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
|
||||||
MiniSudokuConfig,
|
|
||||||
MiniSudokuDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mini_sudoku_config_validation():
|
def test_mini_sudoku_config_validation():
|
||||||
|
|
@ -34,12 +32,7 @@ def test_mini_sudoku_dataset_deterministic():
|
||||||
|
|
||||||
def test_mini_sudoku_dataset_items():
|
def test_mini_sudoku_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = MiniSudokuConfig(
|
config = MiniSudokuConfig(min_empty=8, max_empty=12, size=10, seed=42)
|
||||||
min_empty=8,
|
|
||||||
max_empty=12,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = MiniSudokuDataset(config)
|
dataset = MiniSudokuDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -94,12 +87,7 @@ def test_mini_sudoku_dataset_iteration():
|
||||||
|
|
||||||
def test_mini_sudoku_board_generation():
|
def test_mini_sudoku_board_generation():
|
||||||
"""Test that generated boards are valid"""
|
"""Test that generated boards are valid"""
|
||||||
config = MiniSudokuConfig(
|
config = MiniSudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board
|
||||||
min_empty=0, # Force complete board
|
|
||||||
max_empty=0,
|
|
||||||
size=5,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = MiniSudokuDataset(config)
|
dataset = MiniSudokuDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -127,7 +115,7 @@ def is_valid_solution(board: list[list[int]]) -> bool:
|
||||||
box = []
|
box = []
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
box.append(board[box_i*2 + i][box_j*2 + j])
|
box.append(board[box_i * 2 + i][box_j * 2 + j])
|
||||||
if set(box) != set(range(1, 5)):
|
if set(box) != set(range(1, 5)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for number filtering task generation"""
|
"""Tests for number filtering task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.number_filtering import (
|
from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||||
NumberFilteringConfig,
|
|
||||||
NumberFilteringDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_number_filtering_config_validation():
|
def test_number_filtering_config_validation():
|
||||||
|
|
@ -39,14 +37,7 @@ def test_number_filtering_dataset_deterministic():
|
||||||
def test_number_filtering_dataset_items():
|
def test_number_filtering_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = NumberFilteringConfig(
|
config = NumberFilteringConfig(
|
||||||
min_numbers=3,
|
min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
|
||||||
max_numbers=6,
|
|
||||||
min_decimals=1,
|
|
||||||
max_decimals=3,
|
|
||||||
min_value=-10.0,
|
|
||||||
max_value=10.0,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
)
|
||||||
dataset = NumberFilteringDataset(config)
|
dataset = NumberFilteringDataset(config)
|
||||||
|
|
||||||
|
|
@ -71,7 +62,7 @@ def test_number_filtering_dataset_items():
|
||||||
|
|
||||||
# Verify decimal places
|
# Verify decimal places
|
||||||
for num in numbers:
|
for num in numbers:
|
||||||
decimal_places = len(num.split('.')[-1]) if '.' in num else 0
|
decimal_places = len(num.split(".")[-1]) if "." in num else 0
|
||||||
assert decimal_places >= config.min_decimals
|
assert decimal_places >= config.min_decimals
|
||||||
assert decimal_places <= config.max_decimals
|
assert decimal_places <= config.max_decimals
|
||||||
|
|
||||||
|
|
@ -117,11 +108,11 @@ def test_number_filtering_precision():
|
||||||
min_value=0.0,
|
min_value=0.0,
|
||||||
max_value=1.0,
|
max_value=1.0,
|
||||||
size=1,
|
size=1,
|
||||||
seed=42
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = NumberFilteringDataset(config)
|
dataset = NumberFilteringDataset(config)
|
||||||
item = dataset[0]
|
item = dataset[0]
|
||||||
|
|
||||||
# Check that string representations maintain precision
|
# Check that string representations maintain precision
|
||||||
for num in item["metadata"]["original_numbers"]:
|
for num in item["metadata"]["original_numbers"]:
|
||||||
assert len(num.split('.')[-1]) == 2
|
assert len(num.split(".")[-1]) == 2
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for number sorting task generation"""
|
"""Tests for number sorting task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.number_sorting import (
|
from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||||
NumberSortingConfig,
|
|
||||||
NumberSortingDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_number_sorting_config_validation():
|
def test_number_sorting_config_validation():
|
||||||
|
|
@ -39,14 +37,7 @@ def test_number_sorting_dataset_deterministic():
|
||||||
def test_number_sorting_dataset_items():
|
def test_number_sorting_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = NumberSortingConfig(
|
config = NumberSortingConfig(
|
||||||
min_numbers=3,
|
min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
|
||||||
max_numbers=6,
|
|
||||||
min_decimals=1,
|
|
||||||
max_decimals=3,
|
|
||||||
min_value=-10.0,
|
|
||||||
max_value=10.0,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
)
|
||||||
dataset = NumberSortingDataset(config)
|
dataset = NumberSortingDataset(config)
|
||||||
|
|
||||||
|
|
@ -70,7 +61,7 @@ def test_number_sorting_dataset_items():
|
||||||
|
|
||||||
# Verify decimal places
|
# Verify decimal places
|
||||||
for num in numbers:
|
for num in numbers:
|
||||||
decimal_places = len(num.split('.')[-1]) if '.' in num else 0
|
decimal_places = len(num.split(".")[-1]) if "." in num else 0
|
||||||
assert decimal_places >= config.min_decimals
|
assert decimal_places >= config.min_decimals
|
||||||
assert decimal_places <= config.max_decimals
|
assert decimal_places <= config.max_decimals
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for prime factorization task generation"""
|
"""Tests for prime factorization task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.prime_factorization import (
|
from reasoning_gym.arithmetic.prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
||||||
PrimeFactorizationConfig,
|
|
||||||
PrimeFactorizationDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prime_factorization_config_validation():
|
def test_prime_factorization_config_validation():
|
||||||
|
|
@ -30,12 +28,7 @@ def test_prime_factorization_dataset_deterministic():
|
||||||
|
|
||||||
def test_prime_factorization_dataset_items():
|
def test_prime_factorization_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = PrimeFactorizationConfig(
|
config = PrimeFactorizationConfig(min_value=2, max_value=100, size=10, seed=42)
|
||||||
min_value=2,
|
|
||||||
max_value=100,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = PrimeFactorizationDataset(config)
|
dataset = PrimeFactorizationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -83,12 +76,7 @@ def test_prime_factorization_dataset_iteration():
|
||||||
|
|
||||||
def test_prime_factorization_known_values():
|
def test_prime_factorization_known_values():
|
||||||
"""Test factorization of known values"""
|
"""Test factorization of known values"""
|
||||||
config = PrimeFactorizationConfig(
|
config = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42) # Force specific number
|
||||||
min_value=12,
|
|
||||||
max_value=12, # Force specific number
|
|
||||||
size=1,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = PrimeFactorizationDataset(config)
|
dataset = PrimeFactorizationDataset(config)
|
||||||
item = dataset[0]
|
item = dataset[0]
|
||||||
|
|
||||||
|
|
@ -101,7 +89,7 @@ def is_prime(n: int) -> bool:
|
||||||
"""Helper function to check if a number is prime"""
|
"""Helper function to check if a number is prime"""
|
||||||
if n < 2:
|
if n < 2:
|
||||||
return False
|
return False
|
||||||
for i in range(2, int(n ** 0.5) + 1):
|
for i in range(2, int(n**0.5) + 1):
|
||||||
if n % i == 0:
|
if n % i == 0:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ def test_pattern_rule():
|
||||||
|
|
||||||
# Test rule composition
|
# Test rule composition
|
||||||
rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number
|
rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number
|
||||||
rule2 = PatternRule([Operation.ADD], [3]) # Add 3
|
rule2 = PatternRule([Operation.ADD], [3]) # Add 3
|
||||||
composed = PatternRule.compose([rule1, rule2])
|
composed = PatternRule.compose([rule1, rule2])
|
||||||
assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3
|
assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for sudoku puzzle generation"""
|
"""Tests for sudoku puzzle generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.games.sudoku import (
|
from reasoning_gym.games.sudoku import SudokuConfig, SudokuDataset
|
||||||
SudokuConfig,
|
|
||||||
SudokuDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sudoku_config_validation():
|
def test_sudoku_config_validation():
|
||||||
|
|
@ -34,12 +32,7 @@ def test_sudoku_dataset_deterministic():
|
||||||
|
|
||||||
def test_sudoku_dataset_items():
|
def test_sudoku_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = SudokuConfig(
|
config = SudokuConfig(min_empty=30, max_empty=40, size=10, seed=42)
|
||||||
min_empty=30,
|
|
||||||
max_empty=40,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = SudokuDataset(config)
|
dataset = SudokuDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -94,12 +87,7 @@ def test_sudoku_dataset_iteration():
|
||||||
|
|
||||||
def test_sudoku_board_generation():
|
def test_sudoku_board_generation():
|
||||||
"""Test that generated boards are valid"""
|
"""Test that generated boards are valid"""
|
||||||
config = SudokuConfig(
|
config = SudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board
|
||||||
min_empty=0, # Force complete board
|
|
||||||
max_empty=0,
|
|
||||||
size=5,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = SudokuDataset(config)
|
dataset = SudokuDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -127,7 +115,7 @@ def is_valid_solution(board: list[list[int]]) -> bool:
|
||||||
box = []
|
box = []
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
for j in range(3):
|
for j in range(3):
|
||||||
box.append(board[box_i*3 + i][box_j*3 + j])
|
box.append(board[box_i * 3 + i][box_j * 3 + j])
|
||||||
if set(box) != set(range(1, 10)):
|
if set(box) != set(range(1, 10)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for word reversal task generation"""
|
"""Tests for word reversal task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.word_reversal import (
|
from reasoning_gym.algorithmic.word_reversal import WordReversalConfig, WordReversalDataset
|
||||||
WordReversalConfig,
|
|
||||||
WordReversalDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_word_reversal_config_validation():
|
def test_word_reversal_config_validation():
|
||||||
|
|
@ -30,12 +28,7 @@ def test_word_reversal_dataset_deterministic():
|
||||||
|
|
||||||
def test_word_reversal_dataset_items():
|
def test_word_reversal_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = WordReversalConfig(
|
config = WordReversalConfig(min_words=3, max_words=6, size=10, seed=42)
|
||||||
min_words=3,
|
|
||||||
max_words=6,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = WordReversalDataset(config)
|
dataset = WordReversalDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue