diff --git a/pyproject.toml b/pyproject.toml index 9b5c9252..71139bdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] description = "A library of procedural dataset generators for training reasoning models" readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.9" dependencies = ["sympy>=1.13.1"] classifiers = [ "Programming Language :: Python :: 3", diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 16bcac1b..d2872cf6 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -15,7 +15,7 @@ from .number_sorting import NumberSortingConfig, NumberSortingDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset -from .word_sorting import WordSortingConfig, WordSortingDataset, TextTransformation +from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset __all__ = [ "SpellBackwardConfig", diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py index 4028c425..bc509b8c 100644 --- a/reasoning_gym/algorithmic/sentence_reordering.py +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -8,9 +8,11 @@ from typing import List, Optional from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset + @dataclass class SentenceReorderingConfig: """Configuration for sentence reordering task generation""" + min_words_in_sentence: int = 3 max_words_in_sentence: int = 20 seed: Optional[int] = None @@ -19,7 +21,12 @@ class SentenceReorderingConfig: def validate(self) -> None: """Validate configuration parameters""" assert self.min_words_in_sentence > 0, "min_words_in_sentence must be positive" - assert self.max_words_in_sentence >= self.min_words_in_sentence, "max_words_in_sentence must be >= min_words_in_sentence" + assert ( + self.max_words_in_sentence >= self.min_words_in_sentence + ), "max_words_in_sentence must be >= min_words_in_sentence" + assert ( + self.max_words_in_sentence >= self.min_words_in_sentence + ), "max_words_in_sentence must be >= min_words_in_sentence" class SentenceReorderingDataset(ProceduralDataset): @@ -35,7 +42,9 @@ class SentenceReorderingDataset(ProceduralDataset): self.sentences = [ sentence.strip() for sentence in re.findall(r"[^.!?]+[.!?]", text) # Changed pattern to include the ending punctuation - if self.config.min_words_in_sentence <= len(re.findall(r"\b\w+\b", sentence)) <= self.config.max_words_in_sentence + if self.config.min_words_in_sentence + <= len(re.findall(r"\b\w+\b", sentence)) + <= self.config.max_words_in_sentence ] def _generate_sentence_dataset(self, sentence: str, seed: int, idx: int, shuffle=True): @@ -66,22 +75,22 @@ class SentenceReorderingDataset(ProceduralDataset): sentence_dataset = self._generate_sentence_dataset(rng.choice(self.sentences), self.seed, idx) # Ensure only 'input' and 'goal' keys are present - if set(sentence_dataset.keys()) != {'input', 'goal'}: + if set(sentence_dataset.keys()) != {"input", "goal"}: raise KeyError("The dictionary must contain only 'input' and 'goal' keys") - + # Solve the task by sorting words to match the goal sentence - input_words = sentence_dataset['input'].split() + input_words = sentence_dataset["input"].split() question = " ".join(input_words) - goal_words = sentence_dataset['goal'].split() + goal_words = sentence_dataset["goal"].split() solved_sentence = " ".join(sorted(input_words, key=lambda word: goal_words.index(word))) # Check for length of alphanumeric characters in the solved sentence word_count = len(re.findall(r"\b\w+\b", solved_sentence)) - return { "question": f"Restore the correct order of words in the following sentence: {question}", "answer": solved_sentence, "metadata": {"word_count": word_count}, } - + + register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig) diff --git a/reasoning_gym/algorithmic/word_sorting.py b/reasoning_gym/algorithmic/word_sorting.py index 5e32d031..e8905e33 100644 --- a/reasoning_gym/algorithmic/word_sorting.py +++ b/reasoning_gym/algorithmic/word_sorting.py @@ -12,8 +12,9 @@ from ..factory import ProceduralDataset, register_dataset class TextTransformation(str, Enum): """Text transformation options""" + LOWERCASE = "lowercase" - UPPERCASE = "uppercase" + UPPERCASE = "uppercase" ORIGINAL = "original" RANDOMCASE = "randomcase" @@ -21,6 +22,7 @@ class TextTransformation(str, Enum): @dataclass class WordSortingConfig: """Configuration for word sorting task generation""" + min_words: int = 3 # Minimum words to sort max_words: int = 10 # Maximum words to sort min_word_length: int = 3 # Minimum word length @@ -43,14 +45,17 @@ class WordSortingDataset(ProceduralDataset): def __init__(self, config: WordSortingConfig): super().__init__(config=config, seed=config.seed, size=config.size) - + # Load and preprocess text text = read_data_file("in_the_year_2889.txt") # Extract unique words within length constraints - self.words = list(set( - word for word in re.findall(r'\b\w+\b', text) - if self.config.min_word_length <= len(word) <= self.config.max_word_length - )) + self.words = list( + set( + word + for word in re.findall(r"\b\w+\b", text) + if self.config.min_word_length <= len(word) <= self.config.max_word_length + ) + ) def _transform_word(self, word: str, rng: Random) -> str: """Apply configured transformation to word""" @@ -59,19 +64,18 @@ class WordSortingDataset(ProceduralDataset): elif self.config.transformation == TextTransformation.UPPERCASE: return word.upper() elif self.config.transformation == TextTransformation.RANDOMCASE: - return ''.join(c.upper() if rng.choice([True, False]) else c.lower() - for c in word) + return "".join(c.upper() if rng.choice([True, False]) else c.lower() for c in word) return word # ORIGINAL case def _generate_words(self, rng: Random) -> Tuple[List[str], List[str]]: """Generate list of words and their transformed versions""" count = rng.randint(self.config.min_words, self.config.max_words) - + # Select random words selected_words = rng.sample(self.words, count) # Apply transformation transformed_words = [self._transform_word(word, rng) for word in selected_words] - + return selected_words, transformed_words def __getitem__(self, idx: int) -> dict: @@ -97,7 +101,7 @@ class WordSortingDataset(ProceduralDataset): "transformed_words": transformed_words, "direction": direction, "transformation": self.config.transformation, - "sorted_words": answer + "sorted_words": answer, }, } diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index cfefc422..a4766ebc 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -82,17 +82,16 @@ class FractionSimplificationDataset(ProceduralDataset): def _format_fraction(self, num: int, den: int, style: str = "plain") -> str: """Format a fraction in various styles""" - match style: - case "plain": - return f"{num}/{den}" - case "latex_inline": - return f"${num}/{den}$" - case "latex_frac": - return f"$\\frac{{{num}}}{{{den}}}$" - case "latex_dfrac": - return f"$\\dfrac{{{num}}}{{{den}}}$" - case _: - raise ValueError(f"Unknown fraction style: {style}") + if style == "plain": + return f"{num}/{den}" + elif style == "latex_inline": + return f"${num}/{den}$" + elif style == "latex_frac": + return f"$\\frac{{{num}}}{{{den}}}$" + elif style == "latex_dfrac": + return f"$\\dfrac{{{num}}}{{{den}}}$" + else: + raise ValueError(f"Unknown fraction style: {style}") def __getitem__(self, idx: int) -> dict: """Generate a single fraction simplification task""" diff --git a/reasoning_gym/cognition/color_cube_rotation.py b/reasoning_gym/cognition/color_cube_rotation.py index 42069423..063a8639 100644 --- a/reasoning_gym/cognition/color_cube_rotation.py +++ b/reasoning_gym/cognition/color_cube_rotation.py @@ -1,12 +1,12 @@ import random from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from typing import Dict, List, Optional, Tuple from ..factory import ProceduralDataset, register_dataset -class Color(StrEnum): +class Color(str, Enum): RED = "red" GREEN = "green" BLUE = "blue" @@ -25,7 +25,7 @@ class Color(StrEnum): VIOLET = "violet" -class Side(StrEnum): +class Side(Enum): TOP = "top" RIGHT = "right" FRONT = "front" diff --git a/reasoning_gym/cognition/number_sequences.py b/reasoning_gym/cognition/number_sequences.py index bac6a18a..1ef2a658 100644 --- a/reasoning_gym/cognition/number_sequences.py +++ b/reasoning_gym/cognition/number_sequences.py @@ -1,12 +1,12 @@ from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from random import Random from typing import List, Optional from ..factory import ProceduralDataset, register_dataset -class Operation(StrEnum): +class Operation(Enum): """Basic mathematical operations that can be composed""" ADD = "+" diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index e2c10911..66138ce5 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -1,18 +1,18 @@ import random from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from itertools import count from typing import Dict, List, Optional, Set, Tuple from ..factory import ProceduralDataset, register_dataset -class Gender(StrEnum): +class Gender(Enum): MALE = "male" FEMALE = "female" -class Relationship(StrEnum): +class Relationship(Enum): MOTHER = "mother" FATHER = "father" SISTER = "sister" diff --git a/reasoning_gym/logic/propositional_logic.py b/reasoning_gym/logic/propositional_logic.py index 395c919f..03479e73 100644 --- a/reasoning_gym/logic/propositional_logic.py +++ b/reasoning_gym/logic/propositional_logic.py @@ -1,14 +1,14 @@ """Propositional logic task generator""" from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from random import Random from typing import Any, List, Optional, Set from ..factory import ProceduralDataset, register_dataset -class Operator(StrEnum): +class Operator(Enum): """Basic logical operators""" AND = "∧" diff --git a/tests/test_sentence_reordering.py b/tests/test_sentence_reordering.py index 18d057a8..9ed5b4da 100644 --- a/tests/test_sentence_reordering.py +++ b/tests/test_sentence_reordering.py @@ -1,17 +1,18 @@ import pytest -from reasoning_gym.algorithmic.sentence_reordering import ( - SentenceReorderingConfig, - SentenceReorderingDataset, -) + +from reasoning_gym.algorithmic.sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset + @pytest.fixture def config(): return SentenceReorderingConfig(min_words_in_sentence=5, max_words_in_sentence=5, seed=42, size=10) + @pytest.fixture def dataset(config): return SentenceReorderingDataset(config=config) + def test_config_validation(config): # Test that the config validation does not raise any exceptions try: @@ -19,6 +20,7 @@ def test_config_validation(config): except Exception as e: pytest.fail(f"Config validation raised an exception: {e}") + def test_generate_sentence_dataset(dataset): sentence = "This is a test sentence for reordering" result = dataset._generate_sentence_dataset(sentence, seed=42, idx=0, shuffle=True) @@ -27,12 +29,15 @@ def test_generate_sentence_dataset(dataset): assert result["input"] != result["goal"] assert sorted(result["input"].split()) == sorted(result["goal"].split()) + def test_getitem(dataset, config): item = dataset[0] assert "question" in item assert "answer" in item assert "metadata" in item assert item["metadata"]["word_count"] >= config.min_words_in_sentence + assert item["metadata"]["word_count"] <= config.max_words_in_sentence + def test_key_error_in_getitem(dataset): # Modify the dataset to include an incorrect key diff --git a/tests/test_word_sorting.py b/tests/test_word_sorting.py index f039c0ed..bb673883 100644 --- a/tests/test_word_sorting.py +++ b/tests/test_word_sorting.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.algorithmic.word_sorting import WordSortingConfig, WordSortingDataset, TextTransformation +from reasoning_gym.algorithmic.word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset def test_word_sorting_config_validation(): @@ -38,7 +38,7 @@ def test_word_sorting_transformations(): """Test different text transformations""" seed = 42 size = 5 - + # Test LOWERCASE config = WordSortingConfig(transformation=TextTransformation.LOWERCASE, seed=seed, size=size) dataset = WordSortingDataset(config) @@ -64,14 +64,7 @@ def test_word_sorting_transformations(): def test_word_sorting_dataset_items(): """Test basic properties of generated items""" - config = WordSortingConfig( - min_words=3, - max_words=6, - min_word_length=3, - max_word_length=8, - size=10, - seed=42 - ) + config = WordSortingConfig(min_words=3, max_words=6, min_word_length=3, max_word_length=8, size=10, seed=42) dataset = WordSortingDataset(config) for i in range(len(dataset)):