mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
formatting
This commit is contained in:
parent
98988c8481
commit
20069b2a7d
37 changed files with 504 additions and 666 deletions
|
|
@ -2,12 +2,7 @@
|
||||||
Reasoning Gym - A library of procedural dataset generators for training reasoning models
|
Reasoning Gym - A library of procedural dataset generators for training reasoning models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from . import arithmetic
|
from . import algorithmic, arithmetic, cognition, data, games, logic
|
||||||
from . import algorithmic
|
|
||||||
from . import cognition
|
|
||||||
from . import data
|
|
||||||
from . import games
|
|
||||||
from . import logic
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
__all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"]
|
__all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"]
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ Algorithmic tasks for training reasoning capabilities:
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset
|
from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset
|
||||||
from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset
|
from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset
|
||||||
|
|
||||||
from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset
|
from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset
|
||||||
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
|
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
|
||||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
|
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
|
||||||
|
|
@ -20,8 +21,8 @@ __all__ = [
|
||||||
"BaseConversionDataset",
|
"BaseConversionDataset",
|
||||||
"base_conversion_dataset",
|
"base_conversion_dataset",
|
||||||
"chain_sum_dataset",
|
"chain_sum_dataset",
|
||||||
"LetterCountingConfig",
|
"LetterCountingConfig",
|
||||||
"LetterCountingDataset",
|
"LetterCountingDataset",
|
||||||
"letter_counting_dataset",
|
"letter_counting_dataset",
|
||||||
"NumberFilteringConfig",
|
"NumberFilteringConfig",
|
||||||
"NumberFilteringDataset",
|
"NumberFilteringDataset",
|
||||||
|
|
@ -31,5 +32,5 @@ __all__ = [
|
||||||
"number_sorting_dataset",
|
"number_sorting_dataset",
|
||||||
"WordReversalConfig",
|
"WordReversalConfig",
|
||||||
"WordReversalDataset",
|
"WordReversalDataset",
|
||||||
"word_reversal_dataset"
|
"word_reversal_dataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,20 @@
|
||||||
"""Base conversion task generator"""
|
"""Base conversion task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseConversionConfig:
|
class BaseConversionConfig:
|
||||||
"""Configuration for base conversion task generation"""
|
"""Configuration for base conversion task generation"""
|
||||||
min_base: int = 2 # Minimum base (2=binary)
|
|
||||||
max_base: int = 16 # Maximum base (16=hex)
|
min_base: int = 2 # Minimum base (2=binary)
|
||||||
min_value: int = 0 # Minimum decimal value to convert
|
max_base: int = 16 # Maximum base (16=hex)
|
||||||
max_value: int = 1000 # Maximum decimal value to convert
|
min_value: int = 0 # Minimum decimal value to convert
|
||||||
|
max_value: int = 1000 # Maximum decimal value to convert
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -55,37 +58,37 @@ class BaseConversionDataset:
|
||||||
def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]:
|
def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]:
|
||||||
"""Generate random value and source/target bases"""
|
"""Generate random value and source/target bases"""
|
||||||
value = rng.randint(self.config.min_value, self.config.max_value)
|
value = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
|
|
||||||
# Choose source and target bases
|
# Choose source and target bases
|
||||||
source_base = rng.randint(self.config.min_base, self.config.max_base)
|
source_base = rng.randint(self.config.min_base, self.config.max_base)
|
||||||
target_base = rng.randint(self.config.min_base, self.config.max_base)
|
target_base = rng.randint(self.config.min_base, self.config.max_base)
|
||||||
while target_base == source_base: # Ensure different bases
|
while target_base == source_base: # Ensure different bases
|
||||||
target_base = rng.randint(self.config.min_base, self.config.max_base)
|
target_base = rng.randint(self.config.min_base, self.config.max_base)
|
||||||
|
|
||||||
return value, source_base, target_base
|
return value, source_base, target_base
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single base conversion task"""
|
"""Generate a single base conversion task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
value, source_base, target_base = self._generate_conversion(rng)
|
value, source_base, target_base = self._generate_conversion(rng)
|
||||||
|
|
||||||
# Convert decimal to source base representation
|
# Convert decimal to source base representation
|
||||||
source_repr = format(value, f'x' if source_base == 16 else f'b' if source_base == 2 else '').strip()
|
source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip()
|
||||||
if source_base not in (2, 16):
|
if source_base not in (2, 16):
|
||||||
source_repr = format(value, f'{source_base}x').lower().strip()
|
source_repr = format(value, f"{source_base}x").lower().strip()
|
||||||
|
|
||||||
# Convert decimal to target base for answer
|
# Convert decimal to target base for answer
|
||||||
target_repr = format(value, f'x' if target_base == 16 else f'b' if target_base == 2 else '').strip()
|
target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip()
|
||||||
if target_base not in (2, 16):
|
if target_base not in (2, 16):
|
||||||
target_repr = format(value, f'{target_base}x').lower().strip()
|
target_repr = format(value, f"{target_base}x").lower().strip()
|
||||||
|
|
||||||
source_name = self._format_base_name(source_base)
|
source_name = self._format_base_name(source_base)
|
||||||
target_name = self._format_base_name(target_base)
|
target_name = self._format_base_name(target_base)
|
||||||
|
|
||||||
# Add hint for bases > 10 about using lowercase letters
|
# Add hint for bases > 10 about using lowercase letters
|
||||||
hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else ""
|
hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else ""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}",
|
"question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}",
|
||||||
"answer": target_repr,
|
"answer": target_repr,
|
||||||
|
|
@ -94,8 +97,8 @@ class BaseConversionDataset:
|
||||||
"source_base": source_base,
|
"source_base": source_base,
|
||||||
"target_base": target_base,
|
"target_base": target_base,
|
||||||
"source_repr": source_repr,
|
"source_repr": source_repr,
|
||||||
"target_repr": target_repr
|
"target_repr": target_repr,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,21 @@
|
||||||
"""Letter counting task generator"""
|
"""Letter counting task generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from reasoning_gym.data import read_data_file
|
from reasoning_gym.data import read_data_file
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LetterCountingConfig:
|
class LetterCountingConfig:
|
||||||
"""Configuration for letter counting task generation"""
|
"""Configuration for letter counting task generation"""
|
||||||
min_words: int = 5 # Minimum words in span
|
|
||||||
max_words: int = 15 # Maximum words in span
|
min_words: int = 5 # Minimum words in span
|
||||||
|
max_words: int = 15 # Maximum words in span
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -27,11 +30,11 @@ class LetterCountingDataset:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.config.validate()
|
self.config.validate()
|
||||||
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
||||||
|
|
||||||
# Load and preprocess text
|
# Load and preprocess text
|
||||||
text = read_data_file("in_the_year_2889.txt")
|
text = read_data_file("in_the_year_2889.txt")
|
||||||
# Extract words and clean them to contain only alphanumeric characters
|
# Extract words and clean them to contain only alphanumeric characters
|
||||||
self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()]
|
self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.config.size
|
return self.config.size
|
||||||
|
|
@ -50,31 +53,27 @@ class LetterCountingDataset:
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single letter counting task"""
|
"""Generate a single letter counting task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
# Select random span of words
|
# Select random span of words
|
||||||
span_length = rng.randint(self.config.min_words, self.config.max_words)
|
span_length = rng.randint(self.config.min_words, self.config.max_words)
|
||||||
start_idx = rng.randint(0, len(self.words) - span_length)
|
start_idx = rng.randint(0, len(self.words) - span_length)
|
||||||
span = self.words[start_idx:start_idx + span_length]
|
span = self.words[start_idx : start_idx + span_length]
|
||||||
|
|
||||||
# Get all unique letters from span
|
# Get all unique letters from span
|
||||||
letters = set(''.join(span).lower())
|
letters = set("".join(span).lower())
|
||||||
if not letters:
|
if not letters:
|
||||||
letters = {'a'} # Fallback if span has no letters
|
letters = {"a"} # Fallback if span has no letters
|
||||||
|
|
||||||
# Select random letter that appears in the span
|
# Select random letter that appears in the span
|
||||||
target_letter = rng.choice(list(letters))
|
target_letter = rng.choice(list(letters))
|
||||||
|
|
||||||
# Count occurrences
|
# Count occurrences
|
||||||
count = sum(word.lower().count(target_letter) for word in span)
|
count = sum(word.lower().count(target_letter) for word in span)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
|
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
|
||||||
"answer": str(count),
|
"answer": str(count),
|
||||||
"metadata": {
|
"metadata": {"span_length": span_length, "target_letter": target_letter, "span": span},
|
||||||
"span_length": span_length,
|
|
||||||
"target_letter": target_letter,
|
|
||||||
"span": span
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,23 @@
|
||||||
"""Number filtering task generator"""
|
"""Number filtering task generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NumberFilteringConfig:
|
class NumberFilteringConfig:
|
||||||
"""Configuration for number filtering task generation"""
|
"""Configuration for number filtering task generation"""
|
||||||
min_numbers: int = 3 # Minimum numbers in list
|
|
||||||
max_numbers: int = 10 # Maximum numbers in list
|
min_numbers: int = 3 # Minimum numbers in list
|
||||||
min_decimals: int = 0 # Minimum decimal places
|
max_numbers: int = 10 # Maximum numbers in list
|
||||||
max_decimals: int = 4 # Maximum decimal places
|
min_decimals: int = 0 # Minimum decimal places
|
||||||
min_value: float = -100.0 # Minimum number value
|
max_decimals: int = 4 # Maximum decimal places
|
||||||
max_value: float = 100.0 # Maximum number value
|
min_value: float = -100.0 # Minimum number value
|
||||||
|
max_value: float = 100.0 # Maximum number value
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -56,23 +59,23 @@ class NumberFilteringDataset:
|
||||||
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = []
|
numbers = []
|
||||||
str_numbers = []
|
str_numbers = []
|
||||||
|
|
||||||
for _ in range(count):
|
for _ in range(count):
|
||||||
num = rng.uniform(self.config.min_value, self.config.max_value)
|
num = rng.uniform(self.config.min_value, self.config.max_value)
|
||||||
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
||||||
str_num = self._format_number(num, decimals)
|
str_num = self._format_number(num, decimals)
|
||||||
numbers.append(float(str_num)) # Convert back to simulate precision loss
|
numbers.append(float(str_num)) # Convert back to simulate precision loss
|
||||||
str_numbers.append(str_num)
|
str_numbers.append(str_num)
|
||||||
|
|
||||||
return numbers, str_numbers
|
return numbers, str_numbers
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single number filtering task"""
|
"""Generate a single number filtering task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
# Generate numbers and their string representations
|
# Generate numbers and their string representations
|
||||||
numbers, str_numbers = self._generate_numbers(rng)
|
numbers, str_numbers = self._generate_numbers(rng)
|
||||||
|
|
||||||
# Determine filter value between min and max of generated numbers
|
# Determine filter value between min and max of generated numbers
|
||||||
min_val = min(numbers)
|
min_val = min(numbers)
|
||||||
max_val = max(numbers)
|
max_val = max(numbers)
|
||||||
|
|
@ -80,31 +83,33 @@ class NumberFilteringDataset:
|
||||||
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
||||||
filter_str = self._format_number(filter_value, decimals)
|
filter_str = self._format_number(filter_value, decimals)
|
||||||
filter_value = float(filter_str) # Convert back to simulate precision loss
|
filter_value = float(filter_str) # Convert back to simulate precision loss
|
||||||
|
|
||||||
# Randomly choose filter operation
|
# Randomly choose filter operation
|
||||||
keep_larger = rng.choice([True, False])
|
keep_larger = rng.choice([True, False])
|
||||||
larger_smaller = "larger" if keep_larger else "smaller"
|
larger_smaller = "larger" if keep_larger else "smaller"
|
||||||
keep_remove = "keep" if rng.choice([True, False]) else "remove"
|
keep_remove = "keep" if rng.choice([True, False]) else "remove"
|
||||||
|
|
||||||
# Apply filter based on chosen operation
|
# Apply filter based on chosen operation
|
||||||
if keep_remove == "keep":
|
if keep_remove == "keep":
|
||||||
result = [n for n in numbers if (n > filter_value if keep_larger else n < filter_value)]
|
result = [n for n in numbers if (n > filter_value if keep_larger else n < filter_value)]
|
||||||
else: # remove
|
else: # remove
|
||||||
result = [n for n in numbers if (n <= filter_value if keep_larger else n >= filter_value)]
|
result = [n for n in numbers if (n <= filter_value if keep_larger else n >= filter_value)]
|
||||||
|
|
||||||
# Format results as strings with original precision
|
# Format results as strings with original precision
|
||||||
result_strs = [str_numbers[numbers.index(n)] for n in result]
|
result_strs = [str_numbers[numbers.index(n)] for n in result]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": (f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
|
"question": (
|
||||||
f"in this list: {str_numbers}"),
|
f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
|
||||||
|
f"in this list: {str_numbers}"
|
||||||
|
),
|
||||||
"answer": str(result_strs) if result_strs else "[]",
|
"answer": str(result_strs) if result_strs else "[]",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"original_numbers": str_numbers,
|
"original_numbers": str_numbers,
|
||||||
"filter_value": filter_str,
|
"filter_value": filter_str,
|
||||||
"operation": f"{keep_remove}_{larger_smaller}",
|
"operation": f"{keep_remove}_{larger_smaller}",
|
||||||
"result": result_strs
|
"result": result_strs,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,23 @@
|
||||||
"""Number sorting task generator"""
|
"""Number sorting task generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NumberSortingConfig:
|
class NumberSortingConfig:
|
||||||
"""Configuration for number sorting task generation"""
|
"""Configuration for number sorting task generation"""
|
||||||
min_numbers: int = 3 # Minimum numbers to sort
|
|
||||||
max_numbers: int = 10 # Maximum numbers to sort
|
min_numbers: int = 3 # Minimum numbers to sort
|
||||||
min_decimals: int = 0 # Minimum decimal places
|
max_numbers: int = 10 # Maximum numbers to sort
|
||||||
max_decimals: int = 2 # Maximum decimal places
|
min_decimals: int = 0 # Minimum decimal places
|
||||||
|
max_decimals: int = 2 # Maximum decimal places
|
||||||
min_value: float = -100.0 # Minimum value
|
min_value: float = -100.0 # Minimum value
|
||||||
max_value: float = 100.0 # Maximum value
|
max_value: float = 100.0 # Maximum value
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -57,10 +60,10 @@ class NumberSortingDataset:
|
||||||
"""Generate list of numbers and their string representations"""
|
"""Generate list of numbers and their string representations"""
|
||||||
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
||||||
|
|
||||||
numbers = []
|
numbers = []
|
||||||
number_strs = []
|
number_strs = []
|
||||||
|
|
||||||
for _ in range(count):
|
for _ in range(count):
|
||||||
num = rng.uniform(self.config.min_value, self.config.max_value)
|
num = rng.uniform(self.config.min_value, self.config.max_value)
|
||||||
num_str = self._format_number(num, decimals)
|
num_str = self._format_number(num, decimals)
|
||||||
|
|
@ -68,37 +71,33 @@ class NumberSortingDataset:
|
||||||
num = float(num_str)
|
num = float(num_str)
|
||||||
numbers.append(num)
|
numbers.append(num)
|
||||||
number_strs.append(num_str)
|
number_strs.append(num_str)
|
||||||
|
|
||||||
return numbers, number_strs
|
return numbers, number_strs
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single sorting task"""
|
"""Generate a single sorting task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
numbers, number_strs = self._generate_numbers(rng)
|
numbers, number_strs = self._generate_numbers(rng)
|
||||||
|
|
||||||
# Generate both ascending and descending answers
|
# Generate both ascending and descending answers
|
||||||
asc_numbers = sorted(numbers)
|
asc_numbers = sorted(numbers)
|
||||||
desc_numbers = sorted(numbers, reverse=True)
|
desc_numbers = sorted(numbers, reverse=True)
|
||||||
|
|
||||||
# Format answers as string lists
|
# Format answers as string lists
|
||||||
decimals = len(number_strs[0].split('.')[-1]) if '.' in number_strs[0] else 0
|
decimals = len(number_strs[0].split(".")[-1]) if "." in number_strs[0] else 0
|
||||||
asc_answer = [self._format_number(n, decimals) for n in asc_numbers]
|
asc_answer = [self._format_number(n, decimals) for n in asc_numbers]
|
||||||
desc_answer = [self._format_number(n, decimals) for n in desc_numbers]
|
desc_answer = [self._format_number(n, decimals) for n in desc_numbers]
|
||||||
|
|
||||||
# Randomly choose ascending or descending
|
# Randomly choose ascending or descending
|
||||||
is_ascending = rng.choice([True, False])
|
is_ascending = rng.choice([True, False])
|
||||||
direction = "ascending" if is_ascending else "descending"
|
direction = "ascending" if is_ascending else "descending"
|
||||||
answer = asc_answer if is_ascending else desc_answer
|
answer = asc_answer if is_ascending else desc_answer
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
|
"question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
"metadata": {
|
"metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
|
||||||
"original_numbers": number_strs,
|
|
||||||
"direction": direction,
|
|
||||||
"sorted_numbers": answer
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,21 @@
|
||||||
"""Word reversal task generator"""
|
"""Word reversal task generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from reasoning_gym.data import read_data_file
|
from reasoning_gym.data import read_data_file
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WordReversalConfig:
|
class WordReversalConfig:
|
||||||
"""Configuration for word reversal task generation"""
|
"""Configuration for word reversal task generation"""
|
||||||
min_words: int = 3 # Minimum words in list
|
|
||||||
max_words: int = 8 # Maximum words in list
|
min_words: int = 3 # Minimum words in list
|
||||||
|
max_words: int = 8 # Maximum words in list
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -27,11 +30,11 @@ class WordReversalDataset:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.config.validate()
|
self.config.validate()
|
||||||
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
||||||
|
|
||||||
# Load and preprocess text
|
# Load and preprocess text
|
||||||
text = read_data_file("in_the_year_2889.txt")
|
text = read_data_file("in_the_year_2889.txt")
|
||||||
# Extract words and clean them to contain only alphanumeric characters
|
# Extract words and clean them to contain only alphanumeric characters
|
||||||
self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()]
|
self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.config.size
|
return self.config.size
|
||||||
|
|
@ -50,23 +53,20 @@ class WordReversalDataset:
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single word reversal task"""
|
"""Generate a single word reversal task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
# Select random words
|
# Select random words
|
||||||
num_words = rng.randint(self.config.min_words, self.config.max_words)
|
num_words = rng.randint(self.config.min_words, self.config.max_words)
|
||||||
word_indices = rng.sample(range(len(self.words)), num_words)
|
word_indices = rng.sample(range(len(self.words)), num_words)
|
||||||
words = [self.words[i] for i in word_indices]
|
words = [self.words[i] for i in word_indices]
|
||||||
|
|
||||||
# Create question and answer
|
# Create question and answer
|
||||||
question = ", ".join(words)
|
question = ", ".join(words)
|
||||||
answer = ", ".join(reversed(words))
|
answer = ", ".join(reversed(words))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Reverse this list of words: {question}",
|
"question": f"Reverse this list of words: {question}",
|
||||||
"answer": answer,
|
"answer": answer,
|
||||||
"metadata": {
|
"metadata": {"num_words": num_words, "words": words},
|
||||||
"num_words": num_words,
|
|
||||||
"words": words
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,11 @@ Arithmetic tasks for training reasoning capabilities:
|
||||||
|
|
||||||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
|
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
|
||||||
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
|
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
|
||||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset
|
from .fraction_simplification import (
|
||||||
|
FractionSimplificationConfig,
|
||||||
|
FractionSimplificationDataset,
|
||||||
|
fraction_simplification_dataset,
|
||||||
|
)
|
||||||
from .gcd import GCDConfig, GCDDataset, gcd_dataset
|
from .gcd import GCDConfig, GCDDataset, gcd_dataset
|
||||||
from .lcm import LCMConfig, LCMDataset, lcm_dataset
|
from .lcm import LCMConfig, LCMDataset, lcm_dataset
|
||||||
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
|
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
|
||||||
|
|
@ -25,7 +29,7 @@ __all__ = [
|
||||||
"FractionSimplificationDataset",
|
"FractionSimplificationDataset",
|
||||||
"fraction_simplification_dataset",
|
"fraction_simplification_dataset",
|
||||||
"GCDConfig",
|
"GCDConfig",
|
||||||
"GCDDataset",
|
"GCDDataset",
|
||||||
"gcd_dataset",
|
"gcd_dataset",
|
||||||
"LCMConfig",
|
"LCMConfig",
|
||||||
"LCMDataset",
|
"LCMDataset",
|
||||||
|
|
@ -35,5 +39,5 @@ __all__ = [
|
||||||
"leg_counting_dataset",
|
"leg_counting_dataset",
|
||||||
"PrimeFactorizationConfig",
|
"PrimeFactorizationConfig",
|
||||||
"PrimeFactorizationDataset",
|
"PrimeFactorizationDataset",
|
||||||
"prime_factorization_dataset"
|
"prime_factorization_dataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -145,7 +146,6 @@ class BasicArithmeticDataset(ProceduralDataset):
|
||||||
expression = " ".join(expression_parts)
|
expression = " ".join(expression_parts)
|
||||||
return expression, result
|
return expression, result
|
||||||
|
|
||||||
|
|
||||||
def _format_question(self, rng: Random, expression: str) -> str:
|
def _format_question(self, rng: Random, expression: str) -> str:
|
||||||
"""Format the expression according to config style"""
|
"""Format the expression according to config style"""
|
||||||
if self.config.format_style == "simple":
|
if self.config.format_style == "simple":
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -70,7 +71,6 @@ class ChainSum(ProceduralDataset):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
|
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
|
||||||
"""Generate a chain sum task
|
"""Generate a chain sum task
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,24 @@
|
||||||
"""Fraction simplification task generator"""
|
"""Fraction simplification task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
|
||||||
from typing import Optional, Tuple, Sequence
|
|
||||||
from ..dataset import ProceduralDataset
|
|
||||||
from math import gcd
|
from math import gcd
|
||||||
|
from random import Random
|
||||||
|
from typing import Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FractionSimplificationConfig:
|
class FractionSimplificationConfig:
|
||||||
"""Configuration for fraction simplification task generation"""
|
"""Configuration for fraction simplification task generation"""
|
||||||
min_value: int = 1 # Minimum value for numerator/denominator
|
|
||||||
max_value: int = 1000 # Maximum value for numerator/denominator
|
min_value: int = 1 # Minimum value for numerator/denominator
|
||||||
min_factor: int = 1 # Minimum multiplication factor
|
max_value: int = 1000 # Maximum value for numerator/denominator
|
||||||
max_factor: int = 100 # Maximum multiplication factor
|
min_factor: int = 1 # Minimum multiplication factor
|
||||||
|
max_factor: int = 100 # Maximum multiplication factor
|
||||||
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
|
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -23,7 +26,7 @@ class FractionSimplificationConfig:
|
||||||
assert self.max_value > self.min_value, "max_value must be > min_value"
|
assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||||
assert self.min_factor >= 1, "min_factor must be at least 1"
|
assert self.min_factor >= 1, "min_factor must be at least 1"
|
||||||
assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"
|
assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"
|
||||||
|
|
||||||
# Validate styles
|
# Validate styles
|
||||||
valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
|
valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
|
||||||
for style in self.styles:
|
for style in self.styles:
|
||||||
|
|
@ -46,37 +49,38 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
# Generate the simplified fraction first
|
# Generate the simplified fraction first
|
||||||
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
|
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
|
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
|
|
||||||
# Make sure they're coprime by dividing by their GCD
|
# Make sure they're coprime by dividing by their GCD
|
||||||
common = gcd(simplified_num, simplified_den)
|
common = gcd(simplified_num, simplified_den)
|
||||||
simplified_num //= common
|
simplified_num //= common
|
||||||
simplified_den //= common
|
simplified_den //= common
|
||||||
|
|
||||||
# Check if simplified fraction is within bounds
|
# Check if simplified fraction is within bounds
|
||||||
if (self.config.min_value <= simplified_num <= self.config.max_value and
|
if (
|
||||||
self.config.min_value <= simplified_den <= self.config.max_value):
|
self.config.min_value <= simplified_num <= self.config.max_value
|
||||||
|
and self.config.min_value <= simplified_den <= self.config.max_value
|
||||||
|
):
|
||||||
# Ensure numerator is smaller than denominator
|
# Ensure numerator is smaller than denominator
|
||||||
if simplified_num > simplified_den:
|
if simplified_num > simplified_den:
|
||||||
simplified_num, simplified_den = simplified_den, simplified_num
|
simplified_num, simplified_den = simplified_den, simplified_num
|
||||||
|
|
||||||
# Multiply both by a random factor to create the unsimplified version
|
# Multiply both by a random factor to create the unsimplified version
|
||||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||||
numerator = simplified_num * factor
|
numerator = simplified_num * factor
|
||||||
denominator = simplified_den * factor
|
denominator = simplified_den * factor
|
||||||
return numerator, denominator, simplified_num, simplified_den
|
return numerator, denominator, simplified_num, simplified_den
|
||||||
|
|
||||||
# If we failed to find a good fraction after max attempts,
|
# If we failed to find a good fraction after max attempts,
|
||||||
# generate one that's guaranteed to be within bounds
|
# generate one that's guaranteed to be within bounds
|
||||||
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
|
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
|
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
|
|
||||||
# Ensure numerator is smaller than denominator
|
# Ensure numerator is smaller than denominator
|
||||||
if simplified_num > simplified_den:
|
if simplified_num > simplified_den:
|
||||||
simplified_num, simplified_den = simplified_den, simplified_num
|
simplified_num, simplified_den = simplified_den, simplified_num
|
||||||
|
|
||||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||||
return (simplified_num * factor, simplified_den * factor,
|
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
|
||||||
simplified_num, simplified_den)
|
|
||||||
|
|
||||||
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
||||||
"""Format a fraction in various styles"""
|
"""Format a fraction in various styles"""
|
||||||
|
|
@ -95,16 +99,16 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single fraction simplification task"""
|
"""Generate a single fraction simplification task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
||||||
|
|
||||||
# Choose a random style from configured styles
|
# Choose a random style from configured styles
|
||||||
style = self.config.styles[rng.randint(0, len(self.config.styles)-1)]
|
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
|
||||||
|
|
||||||
# Format both question and answer in the same style
|
# Format both question and answer in the same style
|
||||||
question_fraction = self._format_fraction(num, den, style)
|
question_fraction = self._format_fraction(num, den, style)
|
||||||
answer_fraction = self._format_fraction(simple_num, simple_den, style)
|
answer_fraction = self._format_fraction(simple_num, simple_den, style)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
|
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
|
||||||
"answer": answer_fraction,
|
"answer": answer_fraction,
|
||||||
|
|
@ -114,8 +118,8 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
"simplified_numerator": simple_num,
|
"simplified_numerator": simple_num,
|
||||||
"simplified_denominator": simple_den,
|
"simplified_denominator": simple_den,
|
||||||
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
||||||
"style": style
|
"style": style,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,24 @@
|
||||||
"""Greatest Common Divisor (GCD) task generator"""
|
"""Greatest Common Divisor (GCD) task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import reduce
|
||||||
|
from math import gcd
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
from math import gcd
|
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GCDConfig:
|
class GCDConfig:
|
||||||
"""Configuration for GCD task generation"""
|
"""Configuration for GCD task generation"""
|
||||||
min_numbers: int = 2 # Minimum numbers to find GCD of
|
|
||||||
max_numbers: int = 2 # Maximum numbers to find GCD of
|
min_numbers: int = 2 # Minimum numbers to find GCD of
|
||||||
min_value: int = 1 # Minimum value for each number
|
max_numbers: int = 2 # Maximum numbers to find GCD of
|
||||||
max_value: int = 1000 # Maximum value for each number
|
min_value: int = 1 # Minimum value for each number
|
||||||
|
max_value: int = 1000 # Maximum value for each number
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -38,33 +41,28 @@ class GCDDataset(ProceduralDataset):
|
||||||
Will try up to 3 times to find numbers with GCD > 1."""
|
Will try up to 3 times to find numbers with GCD > 1."""
|
||||||
for _ in range(3): # Try up to 3 times to get GCD > 1
|
for _ in range(3): # Try up to 3 times to get GCD > 1
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||||
for _ in range(num_count)]
|
|
||||||
result = reduce(gcd, numbers)
|
result = reduce(gcd, numbers)
|
||||||
if result > 1:
|
if result > 1:
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
# If we failed to find GCD > 1 after 3 tries, generate one final set
|
# If we failed to find GCD > 1 after 3 tries, generate one final set
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||||
for _ in range(num_count)]
|
|
||||||
result = reduce(gcd, numbers)
|
result = reduce(gcd, numbers)
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single GCD task"""
|
"""Generate a single GCD task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
numbers, result = self._generate_numbers(rng)
|
numbers, result = self._generate_numbers(rng)
|
||||||
numbers_str = ", ".join(str(n) for n in numbers)
|
numbers_str = ", ".join(str(n) for n in numbers)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
||||||
"answer": str(result),
|
"answer": str(result),
|
||||||
"metadata": {
|
"metadata": {"numbers": numbers, "result": result},
|
||||||
"numbers": numbers,
|
|
||||||
"result": result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,24 @@
|
||||||
"""Least Common Multiple (LCM) task generator"""
|
"""Least Common Multiple (LCM) task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import reduce
|
||||||
|
from math import lcm
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
from math import lcm
|
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LCMConfig:
|
class LCMConfig:
|
||||||
"""Configuration for LCM task generation"""
|
"""Configuration for LCM task generation"""
|
||||||
min_numbers: int = 2 # Minimum numbers to find LCM of
|
|
||||||
max_numbers: int = 2 # Maximum numbers to find LCM of
|
min_numbers: int = 2 # Minimum numbers to find LCM of
|
||||||
min_value: int = 1 # Minimum value for each number
|
max_numbers: int = 2 # Maximum numbers to find LCM of
|
||||||
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
|
min_value: int = 1 # Minimum value for each number
|
||||||
|
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -36,38 +39,34 @@ class LCMDataset(ProceduralDataset):
|
||||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||||
"""Generate a list of random positive integers and their LCM.
|
"""Generate a list of random positive integers and their LCM.
|
||||||
Will try up to 3 times to find numbers with LCM < product."""
|
Will try up to 3 times to find numbers with LCM < product."""
|
||||||
|
|
||||||
def calculate_product(nums: List[int]) -> int:
|
def calculate_product(nums: List[int]) -> int:
|
||||||
return reduce(lambda x, y: x * y, nums)
|
return reduce(lambda x, y: x * y, nums)
|
||||||
|
|
||||||
for _ in range(3): # Try up to 3 times to get LCM < product
|
for _ in range(3): # Try up to 3 times to get LCM < product
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||||
for _ in range(num_count)]
|
|
||||||
result = reduce(lcm, numbers)
|
result = reduce(lcm, numbers)
|
||||||
if result < calculate_product(numbers):
|
if result < calculate_product(numbers):
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
# If we failed to find LCM < product after 3 tries, generate one final set
|
# If we failed to find LCM < product after 3 tries, generate one final set
|
||||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||||
for _ in range(num_count)]
|
|
||||||
result = reduce(lcm, numbers)
|
result = reduce(lcm, numbers)
|
||||||
return numbers, result
|
return numbers, result
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single LCM task"""
|
"""Generate a single LCM task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
numbers, result = self._generate_numbers(rng)
|
numbers, result = self._generate_numbers(rng)
|
||||||
numbers_str = ", ".join(str(n) for n in numbers)
|
numbers_str = ", ".join(str(n) for n in numbers)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
|
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
|
||||||
"answer": str(result),
|
"answer": str(result),
|
||||||
"metadata": {
|
"metadata": {"numbers": numbers, "result": result},
|
||||||
"numbers": numbers,
|
|
||||||
"result": result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
"""Leg counting task generator"""
|
"""Leg counting task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
ANIMALS = {
|
ANIMALS = {
|
||||||
|
|
@ -52,14 +54,16 @@ ANIMALS = {
|
||||||
"woodlouse": 14,
|
"woodlouse": 14,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LegCountingConfig:
|
class LegCountingConfig:
|
||||||
"""Configuration for leg counting task generation"""
|
"""Configuration for leg counting task generation"""
|
||||||
min_animals: int = 2 # Minimum number of animals in problem
|
|
||||||
max_animals: int = 5 # Maximum number of animals
|
min_animals: int = 2 # Minimum number of animals in problem
|
||||||
max_instances: int = 3 # Maximum instances of each animal
|
max_animals: int = 5 # Maximum number of animals
|
||||||
|
max_instances: int = 3 # Maximum instances of each animal
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -80,39 +84,36 @@ class LegCountingDataset(ProceduralDataset):
|
||||||
"""Generate a random set of animals and their counts"""
|
"""Generate a random set of animals and their counts"""
|
||||||
num_types = rng.randint(self.config.min_animals, self.config.max_animals)
|
num_types = rng.randint(self.config.min_animals, self.config.max_animals)
|
||||||
animals = {}
|
animals = {}
|
||||||
|
|
||||||
# Select random animals
|
# Select random animals
|
||||||
selected_animals = rng.sample(list(ANIMALS.keys()), num_types)
|
selected_animals = rng.sample(list(ANIMALS.keys()), num_types)
|
||||||
for animal in selected_animals:
|
for animal in selected_animals:
|
||||||
count = rng.randint(1, self.config.max_instances)
|
count = rng.randint(1, self.config.max_instances)
|
||||||
animals[animal] = count
|
animals[animal] = count
|
||||||
|
|
||||||
return animals
|
return animals
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single leg counting task"""
|
"""Generate a single leg counting task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
# Generate random animals and their counts
|
# Generate random animals and their counts
|
||||||
animals = self._generate_animals(rng)
|
animals = self._generate_animals(rng)
|
||||||
|
|
||||||
# Calculate total legs
|
# Calculate total legs
|
||||||
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
|
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
|
||||||
|
|
||||||
# Format animal counts for question
|
# Format animal counts for question
|
||||||
animal_list = []
|
animal_list = []
|
||||||
for animal, count in animals.items():
|
for animal, count in animals.items():
|
||||||
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
|
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
|
||||||
|
|
||||||
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
|
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": question,
|
"question": question,
|
||||||
"answer": str(total_legs),
|
"answer": str(total_legs),
|
||||||
"metadata": {
|
"metadata": {"animals": animals, "total_legs": total_legs},
|
||||||
"animals": animals,
|
|
||||||
"total_legs": total_legs
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,20 @@
|
||||||
"""Prime factorization task generator"""
|
"""Prime factorization task generator"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PrimeFactorizationConfig:
|
class PrimeFactorizationConfig:
|
||||||
"""Configuration for prime factorization task generation"""
|
"""Configuration for prime factorization task generation"""
|
||||||
min_value: int = 2 # Minimum number to factorize
|
|
||||||
max_value: int = 1000 # Maximum number to factorize
|
min_value: int = 2 # Minimum number to factorize
|
||||||
|
max_value: int = 1000 # Maximum number to factorize
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -44,24 +48,23 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single prime factorization task"""
|
"""Generate a single prime factorization task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
# Generate random number to factorize
|
# Generate random number to factorize
|
||||||
number = rng.randint(self.config.min_value, self.config.max_value)
|
number = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
|
|
||||||
# Calculate prime factors
|
# Calculate prime factors
|
||||||
factors = self._prime_factors(number)
|
factors = self._prime_factors(number)
|
||||||
|
|
||||||
# Format answer as multiplication of prime factors
|
# Format answer as multiplication of prime factors
|
||||||
answer = " × ".join(map(str, factors))
|
answer = " × ".join(map(str, factors))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": (f"Find the prime factorization of {number}. Write the factors separated by × "
|
"question": (
|
||||||
f"(Example: for 12 the answer would be: 2 × 2 × 3)"),
|
f"Find the prime factorization of {number}. Write the factors separated by × "
|
||||||
|
f"(Example: for 12 the answer would be: 2 × 2 × 3)"
|
||||||
|
),
|
||||||
"answer": answer,
|
"answer": answer,
|
||||||
"metadata": {
|
"metadata": {"number": number, "factors": factors},
|
||||||
"number": number,
|
|
||||||
"factors": factors
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ class SequenceConfig:
|
||||||
class PatternRule:
|
class PatternRule:
|
||||||
"""Represents a composable sequence pattern rule"""
|
"""Represents a composable sequence pattern rule"""
|
||||||
|
|
||||||
def __init__(self, operations: List[Operation], parameters: List[int], subrules: List['PatternRule'] = None):
|
def __init__(self, operations: List[Operation], parameters: List[int], subrules: List["PatternRule"] = None):
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
self.parameters = parameters
|
self.parameters = parameters
|
||||||
self.subrules = subrules or []
|
self.subrules = subrules or []
|
||||||
|
|
@ -66,14 +66,14 @@ class PatternRule:
|
||||||
elif op == Operation.COMPOSE:
|
elif op == Operation.COMPOSE:
|
||||||
# Apply each subrule in sequence, passing the result through
|
# Apply each subrule in sequence, passing the result through
|
||||||
for subrule in self.subrules:
|
for subrule in self.subrules:
|
||||||
temp_sequence = sequence[:position + 1]
|
temp_sequence = sequence[: position + 1]
|
||||||
temp_sequence[-1] = result # Use current result as input
|
temp_sequence[-1] = result # Use current result as input
|
||||||
result = subrule.apply(temp_sequence, position)
|
result = subrule.apply(temp_sequence, position)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compose(cls, rules: List['PatternRule']) -> 'PatternRule':
|
def compose(cls, rules: List["PatternRule"]) -> "PatternRule":
|
||||||
"""Create a new rule that composes multiple rules together"""
|
"""Create a new rule that composes multiple rules together"""
|
||||||
return cls([Operation.COMPOSE], [0], subrules=rules)
|
return cls([Operation.COMPOSE], [0], subrules=rules)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,34 +4,37 @@ from importlib import resources
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
def get_data_file_path(filename: str) -> Path:
|
def get_data_file_path(filename: str) -> Path:
|
||||||
"""Get the path to a data file in the package.
|
"""Get the path to a data file in the package.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: Name of the file in the data directory
|
filename: Name of the file in the data directory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path object pointing to the data file
|
Path object pointing to the data file
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> path = get_data_file_path("pg19362.txt")
|
>>> path = get_data_file_path("pg19362.txt")
|
||||||
>>> with open(path) as f:
|
>>> with open(path) as f:
|
||||||
... content = f.read()
|
... content = f.read()
|
||||||
"""
|
"""
|
||||||
return resources.files('reasoning_gym.data').joinpath(filename)
|
return resources.files("reasoning_gym.data").joinpath(filename)
|
||||||
|
|
||||||
|
|
||||||
def read_data_file(filename: str) -> str:
|
def read_data_file(filename: str) -> str:
|
||||||
"""Read the contents of a data file in the package.
|
"""Read the contents of a data file in the package.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: Name of the file in the data directory
|
filename: Name of the file in the data directory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
String contents of the file
|
String contents of the file
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> content = read_data_file("pg19362.txt")
|
>>> content = read_data_file("pg19362.txt")
|
||||||
"""
|
"""
|
||||||
return resources.files('reasoning_gym.data').joinpath(filename).read_text()
|
return resources.files("reasoning_gym.data").joinpath(filename).read_text()
|
||||||
|
|
||||||
__all__ = ['get_data_file_path', 'read_data_file']
|
|
||||||
|
__all__ = ["get_data_file_path", "read_data_file"]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
The Project Gutenberg eBook of In the year 2889
|
The Project Gutenberg eBook of In the year 2889
|
||||||
|
|
||||||
This ebook is for the use of anyone anywhere in the United States and
|
This ebook is for the use of anyone anywhere in the United States and
|
||||||
most other parts of the world at no cost and with almost no restrictions
|
most other parts of the world at no cost and with almost no restrictions
|
||||||
whatsoever. You may copy it, give it away or re-use it under the terms
|
whatsoever. You may copy it, give it away or re-use it under the terms
|
||||||
|
|
@ -702,7 +702,7 @@ End of Project Gutenberg's In the Year 2889, by Jules Verne and Michel Verne
|
||||||
*** END OF THE PROJECT GUTENBERG EBOOK IN THE YEAR 2889 ***
|
*** END OF THE PROJECT GUTENBERG EBOOK IN THE YEAR 2889 ***
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Updated editions will replace the previous one—the old editions will
|
Updated editions will replace the previous one—the old editions will
|
||||||
be renamed.
|
be renamed.
|
||||||
|
|
@ -807,7 +807,7 @@ performed, viewed, copied or distributed:
|
||||||
at www.gutenberg.org. If you
|
at www.gutenberg.org. If you
|
||||||
are not located in the United States, you will have to check the laws
|
are not located in the United States, you will have to check the laws
|
||||||
of the country where you are located before using this eBook.
|
of the country where you are located before using this eBook.
|
||||||
|
|
||||||
1.E.2. If an individual Project Gutenberg™ electronic work is
|
1.E.2. If an individual Project Gutenberg™ electronic work is
|
||||||
derived from texts not protected by U.S. copyright law (does not
|
derived from texts not protected by U.S. copyright law (does not
|
||||||
contain a notice indicating that it is posted with permission of the
|
contain a notice indicating that it is posted with permission of the
|
||||||
|
|
@ -869,7 +869,7 @@ provided that:
|
||||||
Gutenberg Literary Archive Foundation at the address specified in
|
Gutenberg Literary Archive Foundation at the address specified in
|
||||||
Section 4, “Information about donations to the Project Gutenberg
|
Section 4, “Information about donations to the Project Gutenberg
|
||||||
Literary Archive Foundation.”
|
Literary Archive Foundation.”
|
||||||
|
|
||||||
• You provide a full refund of any money paid by a user who notifies
|
• You provide a full refund of any money paid by a user who notifies
|
||||||
you in writing (or by e-mail) within 30 days of receipt that s/he
|
you in writing (or by e-mail) within 30 days of receipt that s/he
|
||||||
does not agree to the terms of the full Project Gutenberg™
|
does not agree to the terms of the full Project Gutenberg™
|
||||||
|
|
@ -877,15 +877,15 @@ provided that:
|
||||||
copies of the works possessed in a physical medium and discontinue
|
copies of the works possessed in a physical medium and discontinue
|
||||||
all use of and all access to other copies of Project Gutenberg™
|
all use of and all access to other copies of Project Gutenberg™
|
||||||
works.
|
works.
|
||||||
|
|
||||||
• You provide, in accordance with paragraph 1.F.3, a full refund of
|
• You provide, in accordance with paragraph 1.F.3, a full refund of
|
||||||
any money paid for a work or a replacement copy, if a defect in the
|
any money paid for a work or a replacement copy, if a defect in the
|
||||||
electronic work is discovered and reported to you within 90 days of
|
electronic work is discovered and reported to you within 90 days of
|
||||||
receipt of the work.
|
receipt of the work.
|
||||||
|
|
||||||
• You comply with all other terms of this agreement for free
|
• You comply with all other terms of this agreement for free
|
||||||
distribution of Project Gutenberg™ works.
|
distribution of Project Gutenberg™ works.
|
||||||
|
|
||||||
|
|
||||||
1.E.9. If you wish to charge a fee or distribute a Project
|
1.E.9. If you wish to charge a fee or distribute a Project
|
||||||
Gutenberg™ electronic work or group of works on different terms than
|
Gutenberg™ electronic work or group of works on different terms than
|
||||||
|
|
@ -1048,5 +1048,3 @@ This website includes information about Project Gutenberg™,
|
||||||
including how to make donations to the Project Gutenberg Literary
|
including how to make donations to the Project Gutenberg Literary
|
||||||
Archive Foundation, how to help produce our new eBooks, and how to
|
Archive Foundation, how to help produce our new eBooks, and how to
|
||||||
subscribe to our email newsletter to hear about new eBooks.
|
subscribe to our email newsletter to hear about new eBooks.
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,28 @@
|
||||||
"""Base class for procedural dataset generators"""
|
"""Base class for procedural dataset generators"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sized, Iterable
|
from collections.abc import Iterable, Sized
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional, Iterator, Dict, Any
|
from typing import Any, Dict, Iterator, Optional
|
||||||
|
|
||||||
|
|
||||||
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
||||||
"""Abstract base class for procedural dataset generators"""
|
"""Abstract base class for procedural dataset generators"""
|
||||||
|
|
||||||
def __init__(self, seed: Optional[int] = None, size: int = 500):
|
def __init__(self, seed: Optional[int] = None, size: int = 500):
|
||||||
"""Initialize the dataset with optional seed and size"""
|
"""Initialize the dataset with optional seed and size"""
|
||||||
self.size = size
|
self.size = size
|
||||||
self.seed = seed if seed is not None else Random().randint(0, 2**32)
|
self.seed = seed if seed is not None else Random().randint(0, 2**32)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return the virtual size of the dataset"""
|
"""Return the virtual size of the dataset"""
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""Make the dataset iterable"""
|
"""Make the dataset iterable"""
|
||||||
self._current_idx = 0
|
self._current_idx = 0
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self) -> Dict[str, Any]:
|
def __next__(self) -> Dict[str, Any]:
|
||||||
"""Get next item in iteration"""
|
"""Get next item in iteration"""
|
||||||
if self._current_idx >= self.size:
|
if self._current_idx >= self.size:
|
||||||
|
|
@ -29,14 +30,14 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
||||||
item = self[self._current_idx]
|
item = self[self._current_idx]
|
||||||
self._current_idx += 1
|
self._current_idx += 1
|
||||||
return item
|
return item
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single dataset item
|
"""Generate a single dataset item
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
idx: Index of the item to generate
|
idx: Index of the item to generate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict containing at least:
|
dict containing at least:
|
||||||
- question: str
|
- question: str
|
||||||
|
|
|
||||||
|
|
@ -14,5 +14,5 @@ __all__ = [
|
||||||
"mini_sudoku_dataset",
|
"mini_sudoku_dataset",
|
||||||
"SudokuConfig",
|
"SudokuConfig",
|
||||||
"SudokuDataset",
|
"SudokuDataset",
|
||||||
"sudoku_dataset"
|
"sudoku_dataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,19 @@
|
||||||
"""Mini Sudoku (4x4) puzzle generator"""
|
"""Mini Sudoku (4x4) puzzle generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MiniSudokuConfig:
|
class MiniSudokuConfig:
|
||||||
"""Configuration for 4x4 sudoku puzzle generation"""
|
"""Configuration for 4x4 sudoku puzzle generation"""
|
||||||
min_empty: int = 8 # Minimum number of empty cells
|
|
||||||
max_empty: int = 12 # Maximum number of empty cells
|
min_empty: int = 8 # Minimum number of empty cells
|
||||||
|
max_empty: int = 12 # Maximum number of empty cells
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -45,11 +48,11 @@ class MiniSudokuDataset:
|
||||||
# Check row
|
# Check row
|
||||||
if num in board[row]:
|
if num in board[row]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check column
|
# Check column
|
||||||
if num in [board[i][col] for i in range(4)]:
|
if num in [board[i][col] for i in range(4)]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check 2x2 box
|
# Check 2x2 box
|
||||||
box_row, box_col = 2 * (row // 2), 2 * (col // 2)
|
box_row, box_col = 2 * (row // 2), 2 * (col // 2)
|
||||||
for i in range(box_row, box_row + 2):
|
for i in range(box_row, box_row + 2):
|
||||||
|
|
@ -63,7 +66,7 @@ class MiniSudokuDataset:
|
||||||
empty = self._find_empty(board)
|
empty = self._find_empty(board)
|
||||||
if not empty:
|
if not empty:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
row, col = empty
|
row, col = empty
|
||||||
for num in range(1, 5):
|
for num in range(1, 5):
|
||||||
if self._is_valid(board, row, col, num):
|
if self._is_valid(board, row, col, num):
|
||||||
|
|
@ -84,7 +87,7 @@ class MiniSudokuDataset:
|
||||||
def _generate_solved_board(self, rng: Random) -> List[List[int]]:
|
def _generate_solved_board(self, rng: Random) -> List[List[int]]:
|
||||||
"""Generate a complete solved mini sudoku board"""
|
"""Generate a complete solved mini sudoku board"""
|
||||||
board = [[0] * 4 for _ in range(4)]
|
board = [[0] * 4 for _ in range(4)]
|
||||||
|
|
||||||
# Try multiple times to generate a valid board
|
# Try multiple times to generate a valid board
|
||||||
max_attempts = 100
|
max_attempts = 100
|
||||||
for _ in range(max_attempts):
|
for _ in range(max_attempts):
|
||||||
|
|
@ -92,7 +95,7 @@ class MiniSudokuDataset:
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
for j in range(4):
|
for j in range(4):
|
||||||
board[i][j] = 0
|
board[i][j] = 0
|
||||||
|
|
||||||
# Fill diagonal boxes first (they are independent)
|
# Fill diagonal boxes first (they are independent)
|
||||||
for i in range(0, 4, 2):
|
for i in range(0, 4, 2):
|
||||||
nums = list(range(1, 5))
|
nums = list(range(1, 5))
|
||||||
|
|
@ -102,11 +105,11 @@ class MiniSudokuDataset:
|
||||||
for c in range(i, i + 2):
|
for c in range(i, i + 2):
|
||||||
board[r][c] = nums[pos]
|
board[r][c] = nums[pos]
|
||||||
pos += 1
|
pos += 1
|
||||||
|
|
||||||
# Try to solve the rest
|
# Try to solve the rest
|
||||||
if self._solve(board):
|
if self._solve(board):
|
||||||
return board
|
return board
|
||||||
|
|
||||||
raise RuntimeError("Failed to generate valid mini sudoku board")
|
raise RuntimeError("Failed to generate valid mini sudoku board")
|
||||||
|
|
||||||
def _create_puzzle(self, solved_board: List[List[int]], num_empty: int, rng: Random) -> List[List[int]]:
|
def _create_puzzle(self, solved_board: List[List[int]], num_empty: int, rng: Random) -> List[List[int]]:
|
||||||
|
|
@ -114,10 +117,10 @@ class MiniSudokuDataset:
|
||||||
puzzle = [row[:] for row in solved_board]
|
puzzle = [row[:] for row in solved_board]
|
||||||
cells = [(i, j) for i in range(4) for j in range(4)]
|
cells = [(i, j) for i in range(4) for j in range(4)]
|
||||||
rng.shuffle(cells)
|
rng.shuffle(cells)
|
||||||
|
|
||||||
for i, j in cells[:num_empty]:
|
for i, j in cells[:num_empty]:
|
||||||
puzzle[i][j] = 0
|
puzzle[i][j] = 0
|
||||||
|
|
||||||
return puzzle
|
return puzzle
|
||||||
|
|
||||||
def _board_to_string(self, board: List[List[int]]) -> str:
|
def _board_to_string(self, board: List[List[int]]) -> str:
|
||||||
|
|
@ -127,26 +130,22 @@ class MiniSudokuDataset:
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single mini sudoku puzzle"""
|
"""Generate a single mini sudoku puzzle"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
# Generate solved board
|
# Generate solved board
|
||||||
solved_board = self._generate_solved_board(rng)
|
solved_board = self._generate_solved_board(rng)
|
||||||
|
|
||||||
# Create puzzle by removing numbers
|
# Create puzzle by removing numbers
|
||||||
num_empty = rng.randint(self.config.min_empty, self.config.max_empty)
|
num_empty = rng.randint(self.config.min_empty, self.config.max_empty)
|
||||||
puzzle = self._create_puzzle(solved_board, num_empty, rng)
|
puzzle = self._create_puzzle(solved_board, num_empty, rng)
|
||||||
|
|
||||||
# Format as strings
|
# Format as strings
|
||||||
puzzle_str = self._board_to_string(puzzle)
|
puzzle_str = self._board_to_string(puzzle)
|
||||||
solution_str = self._board_to_string(solved_board)
|
solution_str = self._board_to_string(solved_board)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}",
|
"question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}",
|
||||||
"answer": solution_str,
|
"answer": solution_str,
|
||||||
"metadata": {
|
"metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
|
||||||
"puzzle": puzzle,
|
|
||||||
"solution": solved_board,
|
|
||||||
"num_empty": num_empty
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,19 @@
|
||||||
"""Sudoku puzzle generator"""
|
"""Sudoku puzzle generator"""
|
||||||
from dataclasses import dataclass
|
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SudokuConfig:
|
class SudokuConfig:
|
||||||
"""Configuration for sudoku puzzle generation"""
|
"""Configuration for sudoku puzzle generation"""
|
||||||
min_empty: int = 30 # Minimum number of empty cells
|
|
||||||
max_empty: int = 50 # Maximum number of empty cells
|
min_empty: int = 30 # Minimum number of empty cells
|
||||||
|
max_empty: int = 50 # Maximum number of empty cells
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
|
|
@ -45,11 +48,11 @@ class SudokuDataset:
|
||||||
# Check row
|
# Check row
|
||||||
if num in board[row]:
|
if num in board[row]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check column
|
# Check column
|
||||||
if num in [board[i][col] for i in range(9)]:
|
if num in [board[i][col] for i in range(9)]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check 3x3 box
|
# Check 3x3 box
|
||||||
box_row, box_col = 3 * (row // 3), 3 * (col // 3)
|
box_row, box_col = 3 * (row // 3), 3 * (col // 3)
|
||||||
for i in range(box_row, box_row + 3):
|
for i in range(box_row, box_row + 3):
|
||||||
|
|
@ -63,7 +66,7 @@ class SudokuDataset:
|
||||||
empty = self._find_empty(board)
|
empty = self._find_empty(board)
|
||||||
if not empty:
|
if not empty:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
row, col = empty
|
row, col = empty
|
||||||
for num in range(1, 10):
|
for num in range(1, 10):
|
||||||
if self._is_valid(board, row, col, num):
|
if self._is_valid(board, row, col, num):
|
||||||
|
|
@ -84,7 +87,7 @@ class SudokuDataset:
|
||||||
def _generate_solved_board(self, rng: Random) -> List[List[int]]:
|
def _generate_solved_board(self, rng: Random) -> List[List[int]]:
|
||||||
"""Generate a complete solved sudoku board"""
|
"""Generate a complete solved sudoku board"""
|
||||||
board = [[0] * 9 for _ in range(9)]
|
board = [[0] * 9 for _ in range(9)]
|
||||||
|
|
||||||
# Fill diagonal boxes first (they are independent)
|
# Fill diagonal boxes first (they are independent)
|
||||||
for i in range(0, 9, 3):
|
for i in range(0, 9, 3):
|
||||||
nums = list(range(1, 10))
|
nums = list(range(1, 10))
|
||||||
|
|
@ -94,7 +97,7 @@ class SudokuDataset:
|
||||||
for c in range(i, i + 3):
|
for c in range(i, i + 3):
|
||||||
board[r][c] = nums[pos]
|
board[r][c] = nums[pos]
|
||||||
pos += 1
|
pos += 1
|
||||||
|
|
||||||
# Solve the rest
|
# Solve the rest
|
||||||
self._solve(board)
|
self._solve(board)
|
||||||
return board
|
return board
|
||||||
|
|
@ -104,10 +107,10 @@ class SudokuDataset:
|
||||||
puzzle = [row[:] for row in solved_board]
|
puzzle = [row[:] for row in solved_board]
|
||||||
cells = [(i, j) for i in range(9) for j in range(9)]
|
cells = [(i, j) for i in range(9) for j in range(9)]
|
||||||
rng.shuffle(cells)
|
rng.shuffle(cells)
|
||||||
|
|
||||||
for i, j in cells[:num_empty]:
|
for i, j in cells[:num_empty]:
|
||||||
puzzle[i][j] = 0
|
puzzle[i][j] = 0
|
||||||
|
|
||||||
return puzzle
|
return puzzle
|
||||||
|
|
||||||
def _board_to_string(self, board: List[List[int]]) -> str:
|
def _board_to_string(self, board: List[List[int]]) -> str:
|
||||||
|
|
@ -117,26 +120,22 @@ class SudokuDataset:
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single sudoku puzzle"""
|
"""Generate a single sudoku puzzle"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
# Generate solved board
|
# Generate solved board
|
||||||
solved_board = self._generate_solved_board(rng)
|
solved_board = self._generate_solved_board(rng)
|
||||||
|
|
||||||
# Create puzzle by removing numbers
|
# Create puzzle by removing numbers
|
||||||
num_empty = rng.randint(self.config.min_empty, self.config.max_empty)
|
num_empty = rng.randint(self.config.min_empty, self.config.max_empty)
|
||||||
puzzle = self._create_puzzle(solved_board, num_empty, rng)
|
puzzle = self._create_puzzle(solved_board, num_empty, rng)
|
||||||
|
|
||||||
# Format as strings
|
# Format as strings
|
||||||
puzzle_str = self._board_to_string(puzzle)
|
puzzle_str = self._board_to_string(puzzle)
|
||||||
solution_str = self._board_to_string(solved_board)
|
solution_str = self._board_to_string(solved_board)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Solve this Sudoku puzzle:\n{puzzle_str}",
|
"question": f"Solve this Sudoku puzzle:\n{puzzle_str}",
|
||||||
"answer": solution_str,
|
"answer": solution_str,
|
||||||
"metadata": {
|
"metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
|
||||||
"puzzle": puzzle,
|
|
||||||
"solution": solved_board,
|
|
||||||
"num_empty": num_empty
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import pytest
|
|
||||||
from random import Random
|
from random import Random
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -8,11 +10,11 @@ def test_arithmetic_dataset_config_validation():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = BasicArithmeticDatasetConfig(min_terms=0)
|
config = BasicArithmeticDatasetConfig(min_terms=0)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = BasicArithmeticDatasetConfig(min_terms=3, max_terms=2)
|
config = BasicArithmeticDatasetConfig(min_terms=3, max_terms=2)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = BasicArithmeticDatasetConfig(operators=["^"]) # Invalid operator
|
config = BasicArithmeticDatasetConfig(operators=["^"]) # Invalid operator
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -23,30 +25,23 @@ def test_arithmetic_dataset_deterministic():
|
||||||
config = BasicArithmeticDatasetConfig(seed=42, size=10)
|
config = BasicArithmeticDatasetConfig(seed=42, size=10)
|
||||||
dataset1 = BasicArithmeticDataset(config)
|
dataset1 = BasicArithmeticDataset(config)
|
||||||
dataset2 = BasicArithmeticDataset(config)
|
dataset2 = BasicArithmeticDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset1)):
|
for i in range(len(dataset1)):
|
||||||
assert dataset1[i] == dataset2[i]
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
def test_arithmetic_dataset_items():
|
def test_arithmetic_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = BasicArithmeticDatasetConfig(
|
config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
|
||||||
min_terms=2,
|
|
||||||
max_terms=4,
|
|
||||||
min_digits=1,
|
|
||||||
max_digits=2,
|
|
||||||
size=100,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
assert isinstance(item, dict)
|
assert isinstance(item, dict)
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Verify the answer matches the expression
|
# Verify the answer matches the expression
|
||||||
expression = item["metadata"]["expression"]
|
expression = item["metadata"]["expression"]
|
||||||
answer = eval(expression) # Safe here as we control the expression
|
answer = eval(expression) # Safe here as we control the expression
|
||||||
|
|
@ -62,11 +57,11 @@ def test_arithmetic_dataset_format_styles():
|
||||||
min_terms=2,
|
min_terms=2,
|
||||||
max_terms=3, # Keep expressions simple for testing
|
max_terms=3, # Keep expressions simple for testing
|
||||||
min_digits=1,
|
min_digits=1,
|
||||||
max_digits=2
|
max_digits=2,
|
||||||
)
|
)
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
assert all(item["question"].endswith("=") for item in dataset)
|
assert all(item["question"].endswith("=") for item in dataset)
|
||||||
|
|
||||||
config.format_style = "natural"
|
config.format_style = "natural"
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
assert all("=" not in item["question"] for item in dataset)
|
assert all("=" not in item["question"] for item in dataset)
|
||||||
|
|
@ -74,24 +69,19 @@ def test_arithmetic_dataset_format_styles():
|
||||||
|
|
||||||
def test_arithmetic_dataset_iteration():
|
def test_arithmetic_dataset_iteration():
|
||||||
"""Test that iteration respects dataset size"""
|
"""Test that iteration respects dataset size"""
|
||||||
config = BasicArithmeticDatasetConfig(
|
config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
|
||||||
min_terms=2,
|
|
||||||
max_terms=2,
|
|
||||||
size=5, # Small size for testing
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
|
|
||||||
# Test manual iteration
|
# Test manual iteration
|
||||||
items = []
|
items = []
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
items.append(item)
|
items.append(item)
|
||||||
assert len(items) == config.size, "Iterator should yield exactly size items"
|
assert len(items) == config.size, "Iterator should yield exactly size items"
|
||||||
|
|
||||||
# Test list conversion
|
# Test list conversion
|
||||||
items = list(dataset)
|
items = list(dataset)
|
||||||
assert len(items) == config.size, "Iterator should yield exactly size items"
|
assert len(items) == config.size, "Iterator should yield exactly size items"
|
||||||
|
|
||||||
# Test multiple iterations
|
# Test multiple iterations
|
||||||
first_items = list(dataset)
|
first_items = list(dataset)
|
||||||
second_items = list(dataset)
|
second_items = list(dataset)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for base conversion task generation"""
|
"""Tests for base conversion task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.base_conversion import (
|
from reasoning_gym.algorithmic.base_conversion import BaseConversionConfig, BaseConversionDataset
|
||||||
BaseConversionConfig,
|
|
||||||
BaseConversionDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_base_conversion_config_validation():
|
def test_base_conversion_config_validation():
|
||||||
|
|
@ -38,14 +36,7 @@ def test_base_conversion_dataset_deterministic():
|
||||||
|
|
||||||
def test_base_conversion_dataset_items():
|
def test_base_conversion_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = BaseConversionConfig(
|
config = BaseConversionConfig(min_base=2, max_base=16, min_value=0, max_value=1000, size=10, seed=42)
|
||||||
min_base=2,
|
|
||||||
max_base=16,
|
|
||||||
min_value=0,
|
|
||||||
max_value=1000,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = BaseConversionDataset(config)
|
dataset = BaseConversionDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -55,28 +46,28 @@ def test_base_conversion_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "decimal_value" in item["metadata"]
|
assert "decimal_value" in item["metadata"]
|
||||||
assert "source_base" in item["metadata"]
|
assert "source_base" in item["metadata"]
|
||||||
assert "target_base" in item["metadata"]
|
assert "target_base" in item["metadata"]
|
||||||
assert "source_repr" in item["metadata"]
|
assert "source_repr" in item["metadata"]
|
||||||
assert "target_repr" in item["metadata"]
|
assert "target_repr" in item["metadata"]
|
||||||
|
|
||||||
# Verify value range
|
# Verify value range
|
||||||
assert config.min_value <= item["metadata"]["decimal_value"] <= config.max_value
|
assert config.min_value <= item["metadata"]["decimal_value"] <= config.max_value
|
||||||
|
|
||||||
# Verify base range
|
# Verify base range
|
||||||
assert config.min_base <= item["metadata"]["source_base"] <= config.max_base
|
assert config.min_base <= item["metadata"]["source_base"] <= config.max_base
|
||||||
assert config.min_base <= item["metadata"]["target_base"] <= config.max_base
|
assert config.min_base <= item["metadata"]["target_base"] <= config.max_base
|
||||||
assert item["metadata"]["source_base"] != item["metadata"]["target_base"]
|
assert item["metadata"]["source_base"] != item["metadata"]["target_base"]
|
||||||
|
|
||||||
# Verify conversion correctness
|
# Verify conversion correctness
|
||||||
decimal_value = item["metadata"]["decimal_value"]
|
decimal_value = item["metadata"]["decimal_value"]
|
||||||
target_base = item["metadata"]["target_base"]
|
target_base = item["metadata"]["target_base"]
|
||||||
expected = format(decimal_value, 'x' if target_base == 16 else 'b' if target_base == 2 else '').strip()
|
expected = format(decimal_value, "x" if target_base == 16 else "b" if target_base == 2 else "").strip()
|
||||||
if target_base not in (2, 16):
|
if target_base not in (2, 16):
|
||||||
expected = format(decimal_value, f'{target_base}x').lower().strip()
|
expected = format(decimal_value, f"{target_base}x").lower().strip()
|
||||||
assert item["answer"] == expected
|
assert item["answer"] == expected
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -100,24 +91,24 @@ def test_base_conversion_special_bases():
|
||||||
min_value=0,
|
min_value=0,
|
||||||
max_value=255, # Use small range for predictable results
|
max_value=255, # Use small range for predictable results
|
||||||
size=100,
|
size=100,
|
||||||
seed=42
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = BaseConversionDataset(config)
|
dataset = BaseConversionDataset(config)
|
||||||
|
|
||||||
binary_found = False
|
binary_found = False
|
||||||
hex_found = False
|
hex_found = False
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
if item["metadata"]["target_base"] == 2:
|
if item["metadata"]["target_base"] == 2:
|
||||||
binary_found = True
|
binary_found = True
|
||||||
# Verify binary format
|
# Verify binary format
|
||||||
assert all(c in '01' for c in item["answer"])
|
assert all(c in "01" for c in item["answer"])
|
||||||
elif item["metadata"]["target_base"] == 16:
|
elif item["metadata"]["target_base"] == 16:
|
||||||
hex_found = True
|
hex_found = True
|
||||||
# Verify hex format
|
# Verify hex format
|
||||||
assert all(c in '0123456789abcdef' for c in item["answer"])
|
assert all(c in "0123456789abcdef" for c in item["answer"])
|
||||||
|
|
||||||
assert binary_found, "No binary conversion tasks generated"
|
assert binary_found, "No binary conversion tasks generated"
|
||||||
assert hex_found, "No hexadecimal conversion tasks generated"
|
assert hex_found, "No hexadecimal conversion tasks generated"
|
||||||
|
|
||||||
|
|
@ -130,10 +121,10 @@ def test_base_conversion_formatting():
|
||||||
min_value=10, # Ensure multi-digit numbers
|
min_value=10, # Ensure multi-digit numbers
|
||||||
max_value=1000,
|
max_value=1000,
|
||||||
size=10,
|
size=10,
|
||||||
seed=42
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = BaseConversionDataset(config)
|
dataset = BaseConversionDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
# Verify lowercase letters are used
|
# Verify lowercase letters are used
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
|
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -7,7 +8,7 @@ def test_chain_sum_config_validation():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = ChainSumConfig(min_terms=0)
|
config = ChainSumConfig(min_terms=0)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = ChainSumConfig(min_terms=3, max_terms=2)
|
config = ChainSumConfig(min_terms=3, max_terms=2)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -18,34 +19,27 @@ def test_chain_sum_deterministic():
|
||||||
config = ChainSumConfig(seed=42, size=10)
|
config = ChainSumConfig(seed=42, size=10)
|
||||||
dataset1 = ChainSum(config)
|
dataset1 = ChainSum(config)
|
||||||
dataset2 = ChainSum(config)
|
dataset2 = ChainSum(config)
|
||||||
|
|
||||||
for i in range(len(dataset1)):
|
for i in range(len(dataset1)):
|
||||||
assert dataset1[i] == dataset2[i]
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
def test_chain_sum_items():
|
def test_chain_sum_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
|
||||||
min_terms=2,
|
|
||||||
max_terms=4,
|
|
||||||
min_digits=1,
|
|
||||||
max_digits=2,
|
|
||||||
size=100,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
assert isinstance(item, dict)
|
assert isinstance(item, dict)
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Verify only + and - are used
|
# Verify only + and - are used
|
||||||
expression = item["metadata"]["expression"]
|
expression = item["metadata"]["expression"]
|
||||||
assert all(op in ["+", "-", " "] or op.isdigit() for op in expression)
|
assert all(op in ["+", "-", " "] or op.isdigit() for op in expression)
|
||||||
|
|
||||||
# Verify the answer matches the expression
|
# Verify the answer matches the expression
|
||||||
answer = eval(expression) # Safe here as we control the expression
|
answer = eval(expression) # Safe here as we control the expression
|
||||||
assert str(answer) == item["answer"]
|
assert str(answer) == item["answer"]
|
||||||
|
|
@ -60,10 +54,10 @@ def test_chain_sum_number_ranges():
|
||||||
min_digits=3, # Should generate numbers >= 100
|
min_digits=3, # Should generate numbers >= 100
|
||||||
max_digits=3, # Should generate numbers <= 999
|
max_digits=3, # Should generate numbers <= 999
|
||||||
size=50,
|
size=50,
|
||||||
seed=42
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
expression = item["metadata"]["expression"]
|
expression = item["metadata"]["expression"]
|
||||||
|
|
@ -74,16 +68,8 @@ def test_chain_sum_number_ranges():
|
||||||
else:
|
else:
|
||||||
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
||||||
|
|
||||||
|
|
||||||
# Test 1-digit numbers
|
# Test 1-digit numbers
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
|
||||||
min_terms=2,
|
|
||||||
max_terms=2,
|
|
||||||
min_digits=1,
|
|
||||||
max_digits=1,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
|
|
@ -95,58 +81,48 @@ def test_chain_sum_number_ranges():
|
||||||
else:
|
else:
|
||||||
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
|
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
|
||||||
|
|
||||||
|
|
||||||
def test_chain_sum_negation():
|
def test_chain_sum_negation():
|
||||||
"""Test that allow_negation controls number ranges"""
|
"""Test that allow_negation controls number ranges"""
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(
|
||||||
min_terms=2,
|
min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True
|
||||||
max_terms=2,
|
|
||||||
min_digits=2,
|
|
||||||
max_digits=2,
|
|
||||||
size=100,
|
|
||||||
seed=42,
|
|
||||||
allow_negation=True
|
|
||||||
)
|
)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
|
|
||||||
# Track if we see both positive and negative numbers
|
# Track if we see both positive and negative numbers
|
||||||
has_positive = False
|
has_positive = False
|
||||||
has_negative = False
|
has_negative = False
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
expression = item["metadata"]["expression"]
|
expression = item["metadata"]["expression"]
|
||||||
numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith('-') and n[1:].isdigit())]
|
numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith("-") and n[1:].isdigit())]
|
||||||
|
|
||||||
for num in numbers:
|
for num in numbers:
|
||||||
if num > 0:
|
if num > 0:
|
||||||
has_positive = True
|
has_positive = True
|
||||||
if num < 0:
|
if num < 0:
|
||||||
has_negative = True
|
has_negative = True
|
||||||
|
|
||||||
# With enough samples and allow_negation=True, we should see both positive and negative numbers
|
# With enough samples and allow_negation=True, we should see both positive and negative numbers
|
||||||
assert has_positive and has_negative, "Expected both positive and negative numbers with allow_negation=True"
|
assert has_positive and has_negative, "Expected both positive and negative numbers with allow_negation=True"
|
||||||
|
|
||||||
|
|
||||||
def test_chain_sum_iteration():
|
def test_chain_sum_iteration():
|
||||||
"""Test that iteration respects dataset size"""
|
"""Test that iteration respects dataset size"""
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
|
||||||
min_terms=2,
|
|
||||||
max_terms=2,
|
|
||||||
size=5, # Small size for testing
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSum(config)
|
||||||
|
|
||||||
# Test manual iteration
|
# Test manual iteration
|
||||||
items = []
|
items = []
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
items.append(item)
|
items.append(item)
|
||||||
assert len(items) == config.size, "Iterator should yield exactly size items"
|
assert len(items) == config.size, "Iterator should yield exactly size items"
|
||||||
|
|
||||||
# Test list conversion
|
# Test list conversion
|
||||||
items = list(dataset)
|
items = list(dataset)
|
||||||
assert len(items) == config.size, "Iterator should yield exactly size items"
|
assert len(items) == config.size, "Iterator should yield exactly size items"
|
||||||
|
|
||||||
# Test multiple iterations
|
# Test multiple iterations
|
||||||
first_items = list(dataset)
|
first_items = list(dataset)
|
||||||
second_items = list(dataset)
|
second_items = list(dataset)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import pytest
|
|
||||||
from math import gcd
|
from math import gcd
|
||||||
from reasoning_gym.arithmetic import FractionSimplificationDataset, FractionSimplificationConfig
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset
|
||||||
|
|
||||||
|
|
||||||
def test_fraction_config_validation():
|
def test_fraction_config_validation():
|
||||||
|
|
@ -8,15 +10,15 @@ def test_fraction_config_validation():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = FractionSimplificationConfig(min_value=0) # Should be positive
|
config = FractionSimplificationConfig(min_value=0) # Should be positive
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = FractionSimplificationConfig(min_value=100, max_value=50) # max should be > min
|
config = FractionSimplificationConfig(min_value=100, max_value=50) # max should be > min
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = FractionSimplificationConfig(min_factor=0) # Should be >= 1
|
config = FractionSimplificationConfig(min_factor=0) # Should be >= 1
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = FractionSimplificationConfig(min_factor=5, max_factor=3) # max should be >= min
|
config = FractionSimplificationConfig(min_factor=5, max_factor=3) # max should be >= min
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -27,30 +29,23 @@ def test_fraction_deterministic():
|
||||||
config = FractionSimplificationConfig(seed=42, size=10)
|
config = FractionSimplificationConfig(seed=42, size=10)
|
||||||
dataset1 = FractionSimplificationDataset(config)
|
dataset1 = FractionSimplificationDataset(config)
|
||||||
dataset2 = FractionSimplificationDataset(config)
|
dataset2 = FractionSimplificationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset1)):
|
for i in range(len(dataset1)):
|
||||||
assert dataset1[i] == dataset2[i]
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
def test_fraction_items():
|
def test_fraction_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = FractionSimplificationConfig(
|
config = FractionSimplificationConfig(min_value=1, max_value=20, min_factor=2, max_factor=5, size=50, seed=42)
|
||||||
min_value=1,
|
|
||||||
max_value=20,
|
|
||||||
min_factor=2,
|
|
||||||
max_factor=5,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = FractionSimplificationDataset(config)
|
dataset = FractionSimplificationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
assert isinstance(item, dict)
|
assert isinstance(item, dict)
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Verify the metadata contains all expected fields
|
# Verify the metadata contains all expected fields
|
||||||
metadata = item["metadata"]
|
metadata = item["metadata"]
|
||||||
assert "numerator" in metadata
|
assert "numerator" in metadata
|
||||||
|
|
@ -58,45 +53,38 @@ def test_fraction_items():
|
||||||
assert "simplified_numerator" in metadata
|
assert "simplified_numerator" in metadata
|
||||||
assert "simplified_denominator" in metadata
|
assert "simplified_denominator" in metadata
|
||||||
assert "reduction_factor" in metadata
|
assert "reduction_factor" in metadata
|
||||||
|
|
||||||
# Verify the numbers are within configured range
|
# Verify the numbers are within configured range
|
||||||
assert config.min_value <= metadata["simplified_numerator"] <= config.max_value
|
assert config.min_value <= metadata["simplified_numerator"] <= config.max_value
|
||||||
assert config.min_value <= metadata["simplified_denominator"] <= config.max_value
|
assert config.min_value <= metadata["simplified_denominator"] <= config.max_value
|
||||||
|
|
||||||
# Verify the reduction is correct
|
# Verify the reduction is correct
|
||||||
num = metadata["numerator"]
|
num = metadata["numerator"]
|
||||||
den = metadata["denominator"]
|
den = metadata["denominator"]
|
||||||
simple_num = metadata["simplified_numerator"]
|
simple_num = metadata["simplified_numerator"]
|
||||||
simple_den = metadata["simplified_denominator"]
|
simple_den = metadata["simplified_denominator"]
|
||||||
factor = metadata["reduction_factor"]
|
factor = metadata["reduction_factor"]
|
||||||
|
|
||||||
assert num == simple_num * factor
|
assert num == simple_num * factor
|
||||||
assert den == simple_den * factor
|
assert den == simple_den * factor
|
||||||
|
|
||||||
# Verify the simplified fraction is actually in lowest terms
|
# Verify the simplified fraction is actually in lowest terms
|
||||||
assert gcd(simple_num, simple_den) == 1
|
assert gcd(simple_num, simple_den) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_fraction_ranges():
|
def test_fraction_ranges():
|
||||||
"""Test that generated numbers respect value constraints"""
|
"""Test that generated numbers respect value constraints"""
|
||||||
config = FractionSimplificationConfig(
|
config = FractionSimplificationConfig(min_value=5, max_value=15, min_factor=3, max_factor=4, size=20, seed=42)
|
||||||
min_value=5,
|
|
||||||
max_value=15,
|
|
||||||
min_factor=3,
|
|
||||||
max_factor=4,
|
|
||||||
size=20,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = FractionSimplificationDataset(config)
|
dataset = FractionSimplificationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
metadata = item["metadata"]
|
metadata = item["metadata"]
|
||||||
factor = metadata["reduction_factor"]
|
factor = metadata["reduction_factor"]
|
||||||
|
|
||||||
# Check factor is within bounds
|
# Check factor is within bounds
|
||||||
assert 3 <= factor <= 4
|
assert 3 <= factor <= 4
|
||||||
|
|
||||||
# Check simplified values are within bounds
|
# Check simplified values are within bounds
|
||||||
assert 5 <= metadata["simplified_numerator"] <= 15
|
assert 5 <= metadata["simplified_numerator"] <= 15
|
||||||
assert 5 <= metadata["simplified_denominator"] <= 15
|
assert 5 <= metadata["simplified_denominator"] <= 15
|
||||||
|
|
@ -106,17 +94,17 @@ def test_fraction_iteration():
|
||||||
"""Test that iteration works correctly"""
|
"""Test that iteration works correctly"""
|
||||||
config = FractionSimplificationConfig(size=5, seed=42)
|
config = FractionSimplificationConfig(size=5, seed=42)
|
||||||
dataset = FractionSimplificationDataset(config)
|
dataset = FractionSimplificationDataset(config)
|
||||||
|
|
||||||
# Test manual iteration
|
# Test manual iteration
|
||||||
items = []
|
items = []
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
items.append(item)
|
items.append(item)
|
||||||
assert len(items) == config.size
|
assert len(items) == config.size
|
||||||
|
|
||||||
# Test list conversion
|
# Test list conversion
|
||||||
items = list(dataset)
|
items = list(dataset)
|
||||||
assert len(items) == config.size
|
assert len(items) == config.size
|
||||||
|
|
||||||
# Test multiple iterations yield same results
|
# Test multiple iterations yield same results
|
||||||
first_items = list(dataset)
|
first_items = list(dataset)
|
||||||
second_items = list(dataset)
|
second_items = list(dataset)
|
||||||
|
|
@ -125,24 +113,19 @@ def test_fraction_iteration():
|
||||||
|
|
||||||
def test_fraction_numerator_smaller():
|
def test_fraction_numerator_smaller():
|
||||||
"""Test that numerators are always smaller than denominators"""
|
"""Test that numerators are always smaller than denominators"""
|
||||||
config = FractionSimplificationConfig(
|
config = FractionSimplificationConfig(min_value=1, max_value=100, min_factor=2, max_factor=5, size=50, seed=42)
|
||||||
min_value=1,
|
|
||||||
max_value=100,
|
|
||||||
min_factor=2,
|
|
||||||
max_factor=5,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = FractionSimplificationDataset(config)
|
dataset = FractionSimplificationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
metadata = item["metadata"]
|
metadata = item["metadata"]
|
||||||
|
|
||||||
# Check original fraction
|
# Check original fraction
|
||||||
assert metadata["numerator"] <= metadata["denominator"], \
|
assert (
|
||||||
f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}"
|
metadata["numerator"] <= metadata["denominator"]
|
||||||
|
), f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}"
|
||||||
|
|
||||||
# Check simplified fraction
|
# Check simplified fraction
|
||||||
assert metadata["simplified_numerator"] <= metadata["simplified_denominator"], \
|
assert (
|
||||||
f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"
|
metadata["simplified_numerator"] <= metadata["simplified_denominator"]
|
||||||
|
), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import pytest
|
|
||||||
from math import gcd
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from reasoning_gym.arithmetic import GCDDataset, GCDConfig
|
from math import gcd
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic import GCDConfig, GCDDataset
|
||||||
|
|
||||||
|
|
||||||
def test_gcd_config_validation():
|
def test_gcd_config_validation():
|
||||||
|
|
@ -9,15 +11,15 @@ def test_gcd_config_validation():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = GCDConfig(min_numbers=1) # Should be >= 2
|
config = GCDConfig(min_numbers=1) # Should be >= 2
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = GCDConfig(min_numbers=3, max_numbers=2) # max should be >= min
|
config = GCDConfig(min_numbers=3, max_numbers=2) # max should be >= min
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = GCDConfig(min_value=0) # Should be positive
|
config = GCDConfig(min_value=0) # Should be positive
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = GCDConfig(min_value=100, max_value=50) # max should be > min
|
config = GCDConfig(min_value=100, max_value=50) # max should be > min
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -28,40 +30,33 @@ def test_gcd_deterministic():
|
||||||
config = GCDConfig(seed=42, size=10)
|
config = GCDConfig(seed=42, size=10)
|
||||||
dataset1 = GCDDataset(config)
|
dataset1 = GCDDataset(config)
|
||||||
dataset2 = GCDDataset(config)
|
dataset2 = GCDDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset1)):
|
for i in range(len(dataset1)):
|
||||||
assert dataset1[i] == dataset2[i]
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
def test_gcd_items():
|
def test_gcd_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = GCDConfig(
|
config = GCDConfig(min_numbers=2, max_numbers=4, min_value=1, max_value=100, size=50, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=4,
|
|
||||||
min_value=1,
|
|
||||||
max_value=100,
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = GCDDataset(config)
|
dataset = GCDDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
assert isinstance(item, dict)
|
assert isinstance(item, dict)
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Verify the numbers and result are in metadata
|
# Verify the numbers and result are in metadata
|
||||||
metadata = item["metadata"]
|
metadata = item["metadata"]
|
||||||
assert "numbers" in metadata
|
assert "numbers" in metadata
|
||||||
assert "result" in metadata
|
assert "result" in metadata
|
||||||
|
|
||||||
# Verify the numbers are within configured range
|
# Verify the numbers are within configured range
|
||||||
numbers = metadata["numbers"]
|
numbers = metadata["numbers"]
|
||||||
assert all(config.min_value <= n <= config.max_value for n in numbers)
|
assert all(config.min_value <= n <= config.max_value for n in numbers)
|
||||||
assert config.min_numbers <= len(numbers) <= config.max_numbers
|
assert config.min_numbers <= len(numbers) <= config.max_numbers
|
||||||
|
|
||||||
# Verify the GCD calculation is correct
|
# Verify the GCD calculation is correct
|
||||||
result = metadata["result"]
|
result = metadata["result"]
|
||||||
assert str(result) == item["answer"]
|
assert str(result) == item["answer"]
|
||||||
|
|
@ -70,16 +65,9 @@ def test_gcd_items():
|
||||||
|
|
||||||
def test_gcd_number_ranges():
|
def test_gcd_number_ranges():
|
||||||
"""Test that generated numbers respect value constraints"""
|
"""Test that generated numbers respect value constraints"""
|
||||||
config = GCDConfig(
|
config = GCDConfig(min_numbers=2, max_numbers=2, min_value=50, max_value=100, size=20, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=2,
|
|
||||||
min_value=50,
|
|
||||||
max_value=100,
|
|
||||||
size=20,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = GCDDataset(config)
|
dataset = GCDDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
numbers = item["metadata"]["numbers"]
|
numbers = item["metadata"]["numbers"]
|
||||||
|
|
@ -90,17 +78,17 @@ def test_gcd_iteration():
|
||||||
"""Test that iteration works correctly"""
|
"""Test that iteration works correctly"""
|
||||||
config = GCDConfig(size=5, seed=42)
|
config = GCDConfig(size=5, seed=42)
|
||||||
dataset = GCDDataset(config)
|
dataset = GCDDataset(config)
|
||||||
|
|
||||||
# Test manual iteration
|
# Test manual iteration
|
||||||
items = []
|
items = []
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
items.append(item)
|
items.append(item)
|
||||||
assert len(items) == config.size
|
assert len(items) == config.size
|
||||||
|
|
||||||
# Test list conversion
|
# Test list conversion
|
||||||
items = list(dataset)
|
items = list(dataset)
|
||||||
assert len(items) == config.size
|
assert len(items) == config.size
|
||||||
|
|
||||||
# Test multiple iterations yield same results
|
# Test multiple iterations yield same results
|
||||||
first_items = list(dataset)
|
first_items = list(dataset)
|
||||||
second_items = list(dataset)
|
second_items = list(dataset)
|
||||||
|
|
@ -109,20 +97,13 @@ def test_gcd_iteration():
|
||||||
|
|
||||||
def test_gcd_special_cases():
|
def test_gcd_special_cases():
|
||||||
"""Test some special GCD cases"""
|
"""Test some special GCD cases"""
|
||||||
config = GCDConfig(
|
config = GCDConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=100, size=100, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=2,
|
|
||||||
min_value=1,
|
|
||||||
max_value=100,
|
|
||||||
size=100,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = GCDDataset(config)
|
dataset = GCDDataset(config)
|
||||||
|
|
||||||
# Track if we see some interesting GCD cases
|
# Track if we see some interesting GCD cases
|
||||||
seen_gcd_1 = False # Coprime numbers
|
seen_gcd_1 = False # Coprime numbers
|
||||||
seen_large_gcd = False # GCD > 1
|
seen_large_gcd = False # GCD > 1
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
result = int(item["answer"])
|
result = int(item["answer"])
|
||||||
|
|
@ -130,7 +111,7 @@ def test_gcd_special_cases():
|
||||||
seen_gcd_1 = True
|
seen_gcd_1 = True
|
||||||
if result > 1:
|
if result > 1:
|
||||||
seen_large_gcd = True
|
seen_large_gcd = True
|
||||||
|
|
||||||
# With enough samples, we should see both coprime and non-coprime numbers
|
# With enough samples, we should see both coprime and non-coprime numbers
|
||||||
assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)"
|
assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)"
|
||||||
assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)"
|
assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import pytest
|
|
||||||
from math import lcm
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from reasoning_gym.arithmetic import LCMDataset, LCMConfig
|
from math import lcm
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic import LCMConfig, LCMDataset
|
||||||
|
|
||||||
|
|
||||||
def test_lcm_config_validation():
|
def test_lcm_config_validation():
|
||||||
|
|
@ -9,15 +11,15 @@ def test_lcm_config_validation():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = LCMConfig(min_numbers=1) # Should be >= 2
|
config = LCMConfig(min_numbers=1) # Should be >= 2
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = LCMConfig(min_numbers=3, max_numbers=2) # max should be >= min
|
config = LCMConfig(min_numbers=3, max_numbers=2) # max should be >= min
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = LCMConfig(min_value=0) # Should be positive
|
config = LCMConfig(min_value=0) # Should be positive
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = LCMConfig(min_value=100, max_value=50) # max should be > min
|
config = LCMConfig(min_value=100, max_value=50) # max should be > min
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -28,7 +30,7 @@ def test_lcm_deterministic():
|
||||||
config = LCMConfig(seed=42, size=10)
|
config = LCMConfig(seed=42, size=10)
|
||||||
dataset1 = LCMDataset(config)
|
dataset1 = LCMDataset(config)
|
||||||
dataset2 = LCMDataset(config)
|
dataset2 = LCMDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset1)):
|
for i in range(len(dataset1)):
|
||||||
assert dataset1[i] == dataset2[i]
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
@ -36,32 +38,27 @@ def test_lcm_deterministic():
|
||||||
def test_lcm_items():
|
def test_lcm_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = LCMConfig(
|
config = LCMConfig(
|
||||||
min_numbers=2,
|
min_numbers=2, max_numbers=4, min_value=1, max_value=20, size=50, seed=42 # Keep small for testing
|
||||||
max_numbers=4,
|
|
||||||
min_value=1,
|
|
||||||
max_value=20, # Keep small for testing
|
|
||||||
size=50,
|
|
||||||
seed=42
|
|
||||||
)
|
)
|
||||||
dataset = LCMDataset(config)
|
dataset = LCMDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
assert isinstance(item, dict)
|
assert isinstance(item, dict)
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Verify the numbers and result are in metadata
|
# Verify the numbers and result are in metadata
|
||||||
metadata = item["metadata"]
|
metadata = item["metadata"]
|
||||||
assert "numbers" in metadata
|
assert "numbers" in metadata
|
||||||
assert "result" in metadata
|
assert "result" in metadata
|
||||||
|
|
||||||
# Verify the numbers are within configured range
|
# Verify the numbers are within configured range
|
||||||
numbers = metadata["numbers"]
|
numbers = metadata["numbers"]
|
||||||
assert all(config.min_value <= n <= config.max_value for n in numbers)
|
assert all(config.min_value <= n <= config.max_value for n in numbers)
|
||||||
assert config.min_numbers <= len(numbers) <= config.max_numbers
|
assert config.min_numbers <= len(numbers) <= config.max_numbers
|
||||||
|
|
||||||
# Verify the LCM calculation is correct
|
# Verify the LCM calculation is correct
|
||||||
result = metadata["result"]
|
result = metadata["result"]
|
||||||
assert str(result) == item["answer"]
|
assert str(result) == item["answer"]
|
||||||
|
|
@ -70,16 +67,9 @@ def test_lcm_items():
|
||||||
|
|
||||||
def test_lcm_number_ranges():
|
def test_lcm_number_ranges():
|
||||||
"""Test that generated numbers respect value constraints"""
|
"""Test that generated numbers respect value constraints"""
|
||||||
config = LCMConfig(
|
config = LCMConfig(min_numbers=2, max_numbers=2, min_value=5, max_value=15, size=20, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=2,
|
|
||||||
min_value=5,
|
|
||||||
max_value=15,
|
|
||||||
size=20,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = LCMDataset(config)
|
dataset = LCMDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
numbers = item["metadata"]["numbers"]
|
numbers = item["metadata"]["numbers"]
|
||||||
|
|
@ -90,17 +80,17 @@ def test_lcm_iteration():
|
||||||
"""Test that iteration works correctly"""
|
"""Test that iteration works correctly"""
|
||||||
config = LCMConfig(size=5, seed=42)
|
config = LCMConfig(size=5, seed=42)
|
||||||
dataset = LCMDataset(config)
|
dataset = LCMDataset(config)
|
||||||
|
|
||||||
# Test manual iteration
|
# Test manual iteration
|
||||||
items = []
|
items = []
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
items.append(item)
|
items.append(item)
|
||||||
assert len(items) == config.size
|
assert len(items) == config.size
|
||||||
|
|
||||||
# Test list conversion
|
# Test list conversion
|
||||||
items = list(dataset)
|
items = list(dataset)
|
||||||
assert len(items) == config.size
|
assert len(items) == config.size
|
||||||
|
|
||||||
# Test multiple iterations yield same results
|
# Test multiple iterations yield same results
|
||||||
first_items = list(dataset)
|
first_items = list(dataset)
|
||||||
second_items = list(dataset)
|
second_items = list(dataset)
|
||||||
|
|
@ -109,31 +99,24 @@ def test_lcm_iteration():
|
||||||
|
|
||||||
def test_lcm_special_cases():
|
def test_lcm_special_cases():
|
||||||
"""Test some special LCM cases"""
|
"""Test some special LCM cases"""
|
||||||
config = LCMConfig(
|
config = LCMConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=20, size=100, seed=42)
|
||||||
min_numbers=2,
|
|
||||||
max_numbers=2,
|
|
||||||
min_value=1,
|
|
||||||
max_value=20,
|
|
||||||
size=100,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = LCMDataset(config)
|
dataset = LCMDataset(config)
|
||||||
|
|
||||||
# Track if we see some interesting LCM cases
|
# Track if we see some interesting LCM cases
|
||||||
seen_equal_to_product = False # When numbers are coprime
|
seen_equal_to_product = False # When numbers are coprime
|
||||||
seen_less_than_product = False # When numbers share factors
|
seen_less_than_product = False # When numbers share factors
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
numbers = item["metadata"]["numbers"]
|
numbers = item["metadata"]["numbers"]
|
||||||
result = int(item["answer"])
|
result = int(item["answer"])
|
||||||
product = reduce(lambda x, y: x * y, numbers)
|
product = reduce(lambda x, y: x * y, numbers)
|
||||||
|
|
||||||
if result == product:
|
if result == product:
|
||||||
seen_equal_to_product = True
|
seen_equal_to_product = True
|
||||||
if result < product:
|
if result < product:
|
||||||
seen_less_than_product = True
|
seen_less_than_product = True
|
||||||
|
|
||||||
# With enough samples, we should see both cases
|
# With enough samples, we should see both cases
|
||||||
assert seen_equal_to_product, "Expected to see some coprime numbers (LCM = product)"
|
assert seen_equal_to_product, "Expected to see some coprime numbers (LCM = product)"
|
||||||
assert seen_less_than_product, "Expected to see some numbers with common factors (LCM < product)"
|
assert seen_less_than_product, "Expected to see some numbers with common factors (LCM < product)"
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,8 @@
|
||||||
"""Tests for leg counting task generation"""
|
"""Tests for leg counting task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.leg_counting import (
|
from reasoning_gym.arithmetic.leg_counting import ANIMALS, LegCountingConfig, LegCountingDataset
|
||||||
LegCountingConfig,
|
|
||||||
LegCountingDataset,
|
|
||||||
ANIMALS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_leg_counting_config_validation():
|
def test_leg_counting_config_validation():
|
||||||
|
|
@ -35,13 +32,7 @@ def test_leg_counting_dataset_deterministic():
|
||||||
|
|
||||||
def test_leg_counting_dataset_items():
|
def test_leg_counting_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = LegCountingConfig(
|
config = LegCountingConfig(min_animals=2, max_animals=4, max_instances=2, size=10, seed=42)
|
||||||
min_animals=2,
|
|
||||||
max_animals=4,
|
|
||||||
max_instances=2,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = LegCountingDataset(config)
|
dataset = LegCountingDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -51,19 +42,19 @@ def test_leg_counting_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "animals" in item["metadata"]
|
assert "animals" in item["metadata"]
|
||||||
assert "total_legs" in item["metadata"]
|
assert "total_legs" in item["metadata"]
|
||||||
|
|
||||||
# Verify animal count constraints
|
# Verify animal count constraints
|
||||||
animals = item["metadata"]["animals"]
|
animals = item["metadata"]["animals"]
|
||||||
assert len(animals) >= config.min_animals
|
assert len(animals) >= config.min_animals
|
||||||
assert len(animals) <= config.max_animals
|
assert len(animals) <= config.max_animals
|
||||||
|
|
||||||
# Verify instance count constraints
|
# Verify instance count constraints
|
||||||
assert all(1 <= count <= config.max_instances for count in animals.values())
|
assert all(1 <= count <= config.max_instances for count in animals.values())
|
||||||
|
|
||||||
# Verify leg counting is correct
|
# Verify leg counting is correct
|
||||||
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
|
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
|
||||||
assert str(total_legs) == item["answer"]
|
assert str(total_legs) == item["answer"]
|
||||||
|
|
@ -86,7 +77,7 @@ def test_leg_counting_animal_validation():
|
||||||
"""Test that all animals have valid leg counts"""
|
"""Test that all animals have valid leg counts"""
|
||||||
# Verify all animals have non-negative leg counts
|
# Verify all animals have non-negative leg counts
|
||||||
assert all(legs >= 0 for legs in ANIMALS.values())
|
assert all(legs >= 0 for legs in ANIMALS.values())
|
||||||
|
|
||||||
# Verify common animals have expected leg counts
|
# Verify common animals have expected leg counts
|
||||||
assert ANIMALS["spider"] == 8
|
assert ANIMALS["spider"] == 8
|
||||||
assert ANIMALS["insect"] == 6
|
assert ANIMALS["insect"] == 6
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for letter counting task generation"""
|
"""Tests for letter counting task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.letter_counting import (
|
from reasoning_gym.algorithmic.letter_counting import LetterCountingConfig, LetterCountingDataset
|
||||||
LetterCountingConfig,
|
|
||||||
LetterCountingDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_letter_counting_config_validation():
|
def test_letter_counting_config_validation():
|
||||||
|
|
@ -30,12 +28,7 @@ def test_letter_counting_dataset_deterministic():
|
||||||
|
|
||||||
def test_letter_counting_dataset_items():
|
def test_letter_counting_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = LetterCountingConfig(
|
config = LetterCountingConfig(min_words=3, max_words=6, size=10, seed=42)
|
||||||
min_words=3,
|
|
||||||
max_words=6,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = LetterCountingDataset(config)
|
dataset = LetterCountingDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -45,17 +38,17 @@ def test_letter_counting_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "span_length" in item["metadata"]
|
assert "span_length" in item["metadata"]
|
||||||
assert "target_letter" in item["metadata"]
|
assert "target_letter" in item["metadata"]
|
||||||
assert "span" in item["metadata"]
|
assert "span" in item["metadata"]
|
||||||
|
|
||||||
# Verify span length constraints
|
# Verify span length constraints
|
||||||
span = item["metadata"]["span"]
|
span = item["metadata"]["span"]
|
||||||
assert len(span) >= config.min_words
|
assert len(span) >= config.min_words
|
||||||
assert len(span) <= config.max_words
|
assert len(span) <= config.max_words
|
||||||
|
|
||||||
# Verify letter counting
|
# Verify letter counting
|
||||||
target_letter = item["metadata"]["target_letter"]
|
target_letter = item["metadata"]["target_letter"]
|
||||||
count = sum(word.lower().count(target_letter) for word in span)
|
count = sum(word.lower().count(target_letter) for word in span)
|
||||||
|
|
@ -78,7 +71,7 @@ def test_letter_counting_text_preprocessing():
|
||||||
"""Test that text preprocessing handles edge cases"""
|
"""Test that text preprocessing handles edge cases"""
|
||||||
config = LetterCountingConfig(size=1, seed=42)
|
config = LetterCountingConfig(size=1, seed=42)
|
||||||
dataset = LetterCountingDataset(config)
|
dataset = LetterCountingDataset(config)
|
||||||
|
|
||||||
# Verify words were extracted from text
|
# Verify words were extracted from text
|
||||||
assert len(dataset.words) > 0
|
assert len(dataset.words) > 0
|
||||||
# Verify words contain only word characters
|
# Verify words contain only word characters
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for mini sudoku puzzle generation"""
|
"""Tests for mini sudoku puzzle generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.games.mini_sudoku import (
|
from reasoning_gym.games.mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
|
||||||
MiniSudokuConfig,
|
|
||||||
MiniSudokuDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mini_sudoku_config_validation():
|
def test_mini_sudoku_config_validation():
|
||||||
|
|
@ -34,12 +32,7 @@ def test_mini_sudoku_dataset_deterministic():
|
||||||
|
|
||||||
def test_mini_sudoku_dataset_items():
|
def test_mini_sudoku_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = MiniSudokuConfig(
|
config = MiniSudokuConfig(min_empty=8, max_empty=12, size=10, seed=42)
|
||||||
min_empty=8,
|
|
||||||
max_empty=12,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = MiniSudokuDataset(config)
|
dataset = MiniSudokuDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -49,30 +42,30 @@ def test_mini_sudoku_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "puzzle" in item["metadata"]
|
assert "puzzle" in item["metadata"]
|
||||||
assert "solution" in item["metadata"]
|
assert "solution" in item["metadata"]
|
||||||
assert "num_empty" in item["metadata"]
|
assert "num_empty" in item["metadata"]
|
||||||
|
|
||||||
puzzle = item["metadata"]["puzzle"]
|
puzzle = item["metadata"]["puzzle"]
|
||||||
solution = item["metadata"]["solution"]
|
solution = item["metadata"]["solution"]
|
||||||
num_empty = item["metadata"]["num_empty"]
|
num_empty = item["metadata"]["num_empty"]
|
||||||
|
|
||||||
# Verify board dimensions
|
# Verify board dimensions
|
||||||
assert len(puzzle) == 4
|
assert len(puzzle) == 4
|
||||||
assert all(len(row) == 4 for row in puzzle)
|
assert all(len(row) == 4 for row in puzzle)
|
||||||
assert len(solution) == 4
|
assert len(solution) == 4
|
||||||
assert all(len(row) == 4 for row in solution)
|
assert all(len(row) == 4 for row in solution)
|
||||||
|
|
||||||
# Verify empty cell count
|
# Verify empty cell count
|
||||||
empty_count = sum(1 for row in puzzle for cell in row if cell == 0)
|
empty_count = sum(1 for row in puzzle for cell in row if cell == 0)
|
||||||
assert config.min_empty <= empty_count <= config.max_empty
|
assert config.min_empty <= empty_count <= config.max_empty
|
||||||
assert empty_count == num_empty
|
assert empty_count == num_empty
|
||||||
|
|
||||||
# Verify solution validity
|
# Verify solution validity
|
||||||
assert is_valid_solution(solution)
|
assert is_valid_solution(solution)
|
||||||
|
|
||||||
# Verify puzzle matches solution where filled
|
# Verify puzzle matches solution where filled
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
for j in range(4):
|
for j in range(4):
|
||||||
|
|
@ -94,14 +87,9 @@ def test_mini_sudoku_dataset_iteration():
|
||||||
|
|
||||||
def test_mini_sudoku_board_generation():
|
def test_mini_sudoku_board_generation():
|
||||||
"""Test that generated boards are valid"""
|
"""Test that generated boards are valid"""
|
||||||
config = MiniSudokuConfig(
|
config = MiniSudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board
|
||||||
min_empty=0, # Force complete board
|
|
||||||
max_empty=0,
|
|
||||||
size=5,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = MiniSudokuDataset(config)
|
dataset = MiniSudokuDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
board = item["metadata"]["solution"]
|
board = item["metadata"]["solution"]
|
||||||
|
|
@ -114,21 +102,21 @@ def is_valid_solution(board: list[list[int]]) -> bool:
|
||||||
for row in board:
|
for row in board:
|
||||||
if set(row) != set(range(1, 5)):
|
if set(row) != set(range(1, 5)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check columns
|
# Check columns
|
||||||
for j in range(4):
|
for j in range(4):
|
||||||
column = [board[i][j] for i in range(4)]
|
column = [board[i][j] for i in range(4)]
|
||||||
if set(column) != set(range(1, 5)):
|
if set(column) != set(range(1, 5)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check 2x2 boxes
|
# Check 2x2 boxes
|
||||||
for box_i in range(2):
|
for box_i in range(2):
|
||||||
for box_j in range(2):
|
for box_j in range(2):
|
||||||
box = []
|
box = []
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
box.append(board[box_i*2 + i][box_j*2 + j])
|
box.append(board[box_i * 2 + i][box_j * 2 + j])
|
||||||
if set(box) != set(range(1, 5)):
|
if set(box) != set(range(1, 5)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for number filtering task generation"""
|
"""Tests for number filtering task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.number_filtering import (
|
from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||||
NumberFilteringConfig,
|
|
||||||
NumberFilteringDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_number_filtering_config_validation():
|
def test_number_filtering_config_validation():
|
||||||
|
|
@ -16,11 +14,11 @@ def test_number_filtering_config_validation():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = NumberFilteringConfig(min_numbers=10, max_numbers=5)
|
config = NumberFilteringConfig(min_numbers=10, max_numbers=5)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = NumberFilteringConfig(min_decimals=-1)
|
config = NumberFilteringConfig(min_decimals=-1)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = NumberFilteringConfig(min_value=100, max_value=0)
|
config = NumberFilteringConfig(min_value=100, max_value=0)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -39,14 +37,7 @@ def test_number_filtering_dataset_deterministic():
|
||||||
def test_number_filtering_dataset_items():
|
def test_number_filtering_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = NumberFilteringConfig(
|
config = NumberFilteringConfig(
|
||||||
min_numbers=3,
|
min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
|
||||||
max_numbers=6,
|
|
||||||
min_decimals=1,
|
|
||||||
max_decimals=3,
|
|
||||||
min_value=-10.0,
|
|
||||||
max_value=10.0,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
)
|
||||||
dataset = NumberFilteringDataset(config)
|
dataset = NumberFilteringDataset(config)
|
||||||
|
|
||||||
|
|
@ -57,34 +48,34 @@ def test_number_filtering_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "original_numbers" in item["metadata"]
|
assert "original_numbers" in item["metadata"]
|
||||||
assert "filter_value" in item["metadata"]
|
assert "filter_value" in item["metadata"]
|
||||||
assert "operation" in item["metadata"]
|
assert "operation" in item["metadata"]
|
||||||
assert "result" in item["metadata"]
|
assert "result" in item["metadata"]
|
||||||
|
|
||||||
# Verify number count constraints
|
# Verify number count constraints
|
||||||
numbers = item["metadata"]["original_numbers"]
|
numbers = item["metadata"]["original_numbers"]
|
||||||
assert len(numbers) >= config.min_numbers
|
assert len(numbers) >= config.min_numbers
|
||||||
assert len(numbers) <= config.max_numbers
|
assert len(numbers) <= config.max_numbers
|
||||||
|
|
||||||
# Verify decimal places
|
# Verify decimal places
|
||||||
for num in numbers:
|
for num in numbers:
|
||||||
decimal_places = len(num.split('.')[-1]) if '.' in num else 0
|
decimal_places = len(num.split(".")[-1]) if "." in num else 0
|
||||||
assert decimal_places >= config.min_decimals
|
assert decimal_places >= config.min_decimals
|
||||||
assert decimal_places <= config.max_decimals
|
assert decimal_places <= config.max_decimals
|
||||||
|
|
||||||
# Verify value range
|
# Verify value range
|
||||||
for num in numbers:
|
for num in numbers:
|
||||||
value = float(num)
|
value = float(num)
|
||||||
assert config.min_value <= value <= config.max_value
|
assert config.min_value <= value <= config.max_value
|
||||||
|
|
||||||
# Verify filtering operation
|
# Verify filtering operation
|
||||||
operation = item["metadata"]["operation"]
|
operation = item["metadata"]["operation"]
|
||||||
filter_value = float(item["metadata"]["filter_value"])
|
filter_value = float(item["metadata"]["filter_value"])
|
||||||
result = [float(x) for x in eval(item["answer"])] if item["answer"] != "[]" else []
|
result = [float(x) for x in eval(item["answer"])] if item["answer"] != "[]" else []
|
||||||
|
|
||||||
if operation == "keep_larger":
|
if operation == "keep_larger":
|
||||||
assert all(x > filter_value for x in result)
|
assert all(x > filter_value for x in result)
|
||||||
elif operation == "keep_smaller":
|
elif operation == "keep_smaller":
|
||||||
|
|
@ -117,11 +108,11 @@ def test_number_filtering_precision():
|
||||||
min_value=0.0,
|
min_value=0.0,
|
||||||
max_value=1.0,
|
max_value=1.0,
|
||||||
size=1,
|
size=1,
|
||||||
seed=42
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = NumberFilteringDataset(config)
|
dataset = NumberFilteringDataset(config)
|
||||||
item = dataset[0]
|
item = dataset[0]
|
||||||
|
|
||||||
# Check that string representations maintain precision
|
# Check that string representations maintain precision
|
||||||
for num in item["metadata"]["original_numbers"]:
|
for num in item["metadata"]["original_numbers"]:
|
||||||
assert len(num.split('.')[-1]) == 2
|
assert len(num.split(".")[-1]) == 2
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for number sorting task generation"""
|
"""Tests for number sorting task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.number_sorting import (
|
from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||||
NumberSortingConfig,
|
|
||||||
NumberSortingDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_number_sorting_config_validation():
|
def test_number_sorting_config_validation():
|
||||||
|
|
@ -16,11 +14,11 @@ def test_number_sorting_config_validation():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = NumberSortingConfig(min_numbers=10, max_numbers=5)
|
config = NumberSortingConfig(min_numbers=10, max_numbers=5)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = NumberSortingConfig(min_decimals=-1)
|
config = NumberSortingConfig(min_decimals=-1)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = NumberSortingConfig(min_value=100, max_value=0)
|
config = NumberSortingConfig(min_value=100, max_value=0)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -39,14 +37,7 @@ def test_number_sorting_dataset_deterministic():
|
||||||
def test_number_sorting_dataset_items():
|
def test_number_sorting_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = NumberSortingConfig(
|
config = NumberSortingConfig(
|
||||||
min_numbers=3,
|
min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
|
||||||
max_numbers=6,
|
|
||||||
min_decimals=1,
|
|
||||||
max_decimals=3,
|
|
||||||
min_value=-10.0,
|
|
||||||
max_value=10.0,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
)
|
||||||
dataset = NumberSortingDataset(config)
|
dataset = NumberSortingDataset(config)
|
||||||
|
|
||||||
|
|
@ -57,28 +48,28 @@ def test_number_sorting_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "original_numbers" in item["metadata"]
|
assert "original_numbers" in item["metadata"]
|
||||||
assert "direction" in item["metadata"]
|
assert "direction" in item["metadata"]
|
||||||
assert "sorted_numbers" in item["metadata"]
|
assert "sorted_numbers" in item["metadata"]
|
||||||
|
|
||||||
# Verify number count constraints
|
# Verify number count constraints
|
||||||
numbers = item["metadata"]["original_numbers"]
|
numbers = item["metadata"]["original_numbers"]
|
||||||
assert len(numbers) >= config.min_numbers
|
assert len(numbers) >= config.min_numbers
|
||||||
assert len(numbers) <= config.max_numbers
|
assert len(numbers) <= config.max_numbers
|
||||||
|
|
||||||
# Verify decimal places
|
# Verify decimal places
|
||||||
for num in numbers:
|
for num in numbers:
|
||||||
decimal_places = len(num.split('.')[-1]) if '.' in num else 0
|
decimal_places = len(num.split(".")[-1]) if "." in num else 0
|
||||||
assert decimal_places >= config.min_decimals
|
assert decimal_places >= config.min_decimals
|
||||||
assert decimal_places <= config.max_decimals
|
assert decimal_places <= config.max_decimals
|
||||||
|
|
||||||
# Verify value range
|
# Verify value range
|
||||||
for num in numbers:
|
for num in numbers:
|
||||||
value = float(num)
|
value = float(num)
|
||||||
assert config.min_value <= value <= config.max_value
|
assert config.min_value <= value <= config.max_value
|
||||||
|
|
||||||
# Verify sorting
|
# Verify sorting
|
||||||
direction = item["metadata"]["direction"]
|
direction = item["metadata"]["direction"]
|
||||||
sorted_numbers = [float(x) for x in eval(item["answer"])]
|
sorted_numbers = [float(x) for x in eval(item["answer"])]
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for prime factorization task generation"""
|
"""Tests for prime factorization task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.prime_factorization import (
|
from reasoning_gym.arithmetic.prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
||||||
PrimeFactorizationConfig,
|
|
||||||
PrimeFactorizationDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prime_factorization_config_validation():
|
def test_prime_factorization_config_validation():
|
||||||
|
|
@ -30,12 +28,7 @@ def test_prime_factorization_dataset_deterministic():
|
||||||
|
|
||||||
def test_prime_factorization_dataset_items():
|
def test_prime_factorization_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = PrimeFactorizationConfig(
|
config = PrimeFactorizationConfig(min_value=2, max_value=100, size=10, seed=42)
|
||||||
min_value=2,
|
|
||||||
max_value=100,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = PrimeFactorizationDataset(config)
|
dataset = PrimeFactorizationDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -45,26 +38,26 @@ def test_prime_factorization_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "number" in item["metadata"]
|
assert "number" in item["metadata"]
|
||||||
assert "factors" in item["metadata"]
|
assert "factors" in item["metadata"]
|
||||||
|
|
||||||
# Verify value range
|
# Verify value range
|
||||||
number = item["metadata"]["number"]
|
number = item["metadata"]["number"]
|
||||||
assert config.min_value <= number <= config.max_value
|
assert config.min_value <= number <= config.max_value
|
||||||
|
|
||||||
# Verify factorization is correct
|
# Verify factorization is correct
|
||||||
factors = item["metadata"]["factors"]
|
factors = item["metadata"]["factors"]
|
||||||
product = 1
|
product = 1
|
||||||
for factor in factors:
|
for factor in factors:
|
||||||
product *= factor
|
product *= factor
|
||||||
assert product == number
|
assert product == number
|
||||||
|
|
||||||
# Verify factors are prime
|
# Verify factors are prime
|
||||||
for factor in factors:
|
for factor in factors:
|
||||||
assert is_prime(factor), f"{factor} is not prime"
|
assert is_prime(factor), f"{factor} is not prime"
|
||||||
|
|
||||||
# Verify answer format
|
# Verify answer format
|
||||||
assert item["answer"] == " × ".join(map(str, factors))
|
assert item["answer"] == " × ".join(map(str, factors))
|
||||||
|
|
||||||
|
|
@ -83,15 +76,10 @@ def test_prime_factorization_dataset_iteration():
|
||||||
|
|
||||||
def test_prime_factorization_known_values():
|
def test_prime_factorization_known_values():
|
||||||
"""Test factorization of known values"""
|
"""Test factorization of known values"""
|
||||||
config = PrimeFactorizationConfig(
|
config = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42) # Force specific number
|
||||||
min_value=12,
|
|
||||||
max_value=12, # Force specific number
|
|
||||||
size=1,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = PrimeFactorizationDataset(config)
|
dataset = PrimeFactorizationDataset(config)
|
||||||
item = dataset[0]
|
item = dataset[0]
|
||||||
|
|
||||||
assert item["metadata"]["number"] == 12
|
assert item["metadata"]["number"] == 12
|
||||||
assert item["metadata"]["factors"] == [2, 2, 3]
|
assert item["metadata"]["factors"] == [2, 2, 3]
|
||||||
assert item["answer"] == "2 × 2 × 3"
|
assert item["answer"] == "2 × 2 × 3"
|
||||||
|
|
@ -101,7 +89,7 @@ def is_prime(n: int) -> bool:
|
||||||
"""Helper function to check if a number is prime"""
|
"""Helper function to check if a number is prime"""
|
||||||
if n < 2:
|
if n < 2:
|
||||||
return False
|
return False
|
||||||
for i in range(2, int(n ** 0.5) + 1):
|
for i in range(2, int(n**0.5) + 1):
|
||||||
if n % i == 0:
|
if n % i == 0:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -23,14 +23,14 @@ def test_pattern_rule():
|
||||||
# Test simple addition
|
# Test simple addition
|
||||||
rule = PatternRule([Operation.ADD], [2])
|
rule = PatternRule([Operation.ADD], [2])
|
||||||
assert rule.apply([1, 3], 1) == 5
|
assert rule.apply([1, 3], 1) == 5
|
||||||
|
|
||||||
# Test composition
|
# Test composition
|
||||||
rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3])
|
rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3])
|
||||||
assert rule.apply([1, 4], 1) == 11 # (4 * 2) + 3
|
assert rule.apply([1, 4], 1) == 11 # (4 * 2) + 3
|
||||||
|
|
||||||
# Test rule composition
|
# Test rule composition
|
||||||
rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number
|
rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number
|
||||||
rule2 = PatternRule([Operation.ADD], [3]) # Add 3
|
rule2 = PatternRule([Operation.ADD], [3]) # Add 3
|
||||||
composed = PatternRule.compose([rule1, rule2])
|
composed = PatternRule.compose([rule1, rule2])
|
||||||
assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3
|
assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for sudoku puzzle generation"""
|
"""Tests for sudoku puzzle generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.games.sudoku import (
|
from reasoning_gym.games.sudoku import SudokuConfig, SudokuDataset
|
||||||
SudokuConfig,
|
|
||||||
SudokuDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sudoku_config_validation():
|
def test_sudoku_config_validation():
|
||||||
|
|
@ -34,12 +32,7 @@ def test_sudoku_dataset_deterministic():
|
||||||
|
|
||||||
def test_sudoku_dataset_items():
|
def test_sudoku_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = SudokuConfig(
|
config = SudokuConfig(min_empty=30, max_empty=40, size=10, seed=42)
|
||||||
min_empty=30,
|
|
||||||
max_empty=40,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = SudokuDataset(config)
|
dataset = SudokuDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -49,30 +42,30 @@ def test_sudoku_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "puzzle" in item["metadata"]
|
assert "puzzle" in item["metadata"]
|
||||||
assert "solution" in item["metadata"]
|
assert "solution" in item["metadata"]
|
||||||
assert "num_empty" in item["metadata"]
|
assert "num_empty" in item["metadata"]
|
||||||
|
|
||||||
puzzle = item["metadata"]["puzzle"]
|
puzzle = item["metadata"]["puzzle"]
|
||||||
solution = item["metadata"]["solution"]
|
solution = item["metadata"]["solution"]
|
||||||
num_empty = item["metadata"]["num_empty"]
|
num_empty = item["metadata"]["num_empty"]
|
||||||
|
|
||||||
# Verify board dimensions
|
# Verify board dimensions
|
||||||
assert len(puzzle) == 9
|
assert len(puzzle) == 9
|
||||||
assert all(len(row) == 9 for row in puzzle)
|
assert all(len(row) == 9 for row in puzzle)
|
||||||
assert len(solution) == 9
|
assert len(solution) == 9
|
||||||
assert all(len(row) == 9 for row in solution)
|
assert all(len(row) == 9 for row in solution)
|
||||||
|
|
||||||
# Verify empty cell count
|
# Verify empty cell count
|
||||||
empty_count = sum(1 for row in puzzle for cell in row if cell == 0)
|
empty_count = sum(1 for row in puzzle for cell in row if cell == 0)
|
||||||
assert config.min_empty <= empty_count <= config.max_empty
|
assert config.min_empty <= empty_count <= config.max_empty
|
||||||
assert empty_count == num_empty
|
assert empty_count == num_empty
|
||||||
|
|
||||||
# Verify solution validity
|
# Verify solution validity
|
||||||
assert is_valid_solution(solution)
|
assert is_valid_solution(solution)
|
||||||
|
|
||||||
# Verify puzzle matches solution where filled
|
# Verify puzzle matches solution where filled
|
||||||
for i in range(9):
|
for i in range(9):
|
||||||
for j in range(9):
|
for j in range(9):
|
||||||
|
|
@ -94,14 +87,9 @@ def test_sudoku_dataset_iteration():
|
||||||
|
|
||||||
def test_sudoku_board_generation():
|
def test_sudoku_board_generation():
|
||||||
"""Test that generated boards are valid"""
|
"""Test that generated boards are valid"""
|
||||||
config = SudokuConfig(
|
config = SudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board
|
||||||
min_empty=0, # Force complete board
|
|
||||||
max_empty=0,
|
|
||||||
size=5,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = SudokuDataset(config)
|
dataset = SudokuDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
board = item["metadata"]["solution"]
|
board = item["metadata"]["solution"]
|
||||||
|
|
@ -114,21 +102,21 @@ def is_valid_solution(board: list[list[int]]) -> bool:
|
||||||
for row in board:
|
for row in board:
|
||||||
if set(row) != set(range(1, 10)):
|
if set(row) != set(range(1, 10)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check columns
|
# Check columns
|
||||||
for j in range(9):
|
for j in range(9):
|
||||||
column = [board[i][j] for i in range(9)]
|
column = [board[i][j] for i in range(9)]
|
||||||
if set(column) != set(range(1, 10)):
|
if set(column) != set(range(1, 10)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check 3x3 boxes
|
# Check 3x3 boxes
|
||||||
for box_i in range(3):
|
for box_i in range(3):
|
||||||
for box_j in range(3):
|
for box_j in range(3):
|
||||||
box = []
|
box = []
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
for j in range(3):
|
for j in range(3):
|
||||||
box.append(board[box_i*3 + i][box_j*3 + j])
|
box.append(board[box_i * 3 + i][box_j * 3 + j])
|
||||||
if set(box) != set(range(1, 10)):
|
if set(box) != set(range(1, 10)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Tests for word reversal task generation"""
|
"""Tests for word reversal task generation"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.word_reversal import (
|
from reasoning_gym.algorithmic.word_reversal import WordReversalConfig, WordReversalDataset
|
||||||
WordReversalConfig,
|
|
||||||
WordReversalDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_word_reversal_config_validation():
|
def test_word_reversal_config_validation():
|
||||||
|
|
@ -30,12 +28,7 @@ def test_word_reversal_dataset_deterministic():
|
||||||
|
|
||||||
def test_word_reversal_dataset_items():
|
def test_word_reversal_dataset_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = WordReversalConfig(
|
config = WordReversalConfig(min_words=3, max_words=6, size=10, seed=42)
|
||||||
min_words=3,
|
|
||||||
max_words=6,
|
|
||||||
size=10,
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
dataset = WordReversalDataset(config)
|
dataset = WordReversalDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
|
|
@ -45,16 +38,16 @@ def test_word_reversal_dataset_items():
|
||||||
assert "question" in item
|
assert "question" in item
|
||||||
assert "answer" in item
|
assert "answer" in item
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert "num_words" in item["metadata"]
|
assert "num_words" in item["metadata"]
|
||||||
assert "words" in item["metadata"]
|
assert "words" in item["metadata"]
|
||||||
|
|
||||||
# Verify word count constraints
|
# Verify word count constraints
|
||||||
words = item["metadata"]["words"]
|
words = item["metadata"]["words"]
|
||||||
assert len(words) >= config.min_words
|
assert len(words) >= config.min_words
|
||||||
assert len(words) <= config.max_words
|
assert len(words) <= config.max_words
|
||||||
|
|
||||||
# Verify reversal is correct
|
# Verify reversal is correct
|
||||||
question_words = [w.strip() for w in item["question"].split(":")[1].strip().split(",")]
|
question_words = [w.strip() for w in item["question"].split(":")[1].strip().split(",")]
|
||||||
answer_words = item["answer"].split(", ")
|
answer_words = item["answer"].split(", ")
|
||||||
|
|
@ -77,7 +70,7 @@ def test_word_reversal_text_preprocessing():
|
||||||
"""Test that text preprocessing handles edge cases"""
|
"""Test that text preprocessing handles edge cases"""
|
||||||
config = WordReversalConfig(size=1, seed=42)
|
config = WordReversalConfig(size=1, seed=42)
|
||||||
dataset = WordReversalDataset(config)
|
dataset = WordReversalDataset(config)
|
||||||
|
|
||||||
# Verify words were extracted from text
|
# Verify words were extracted from text
|
||||||
assert len(dataset.words) > 0
|
assert len(dataset.words) > 0
|
||||||
# Verify words contain only alphanumeric characters
|
# Verify words contain only alphanumeric characters
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue