fix unit tests, lower python dependency to 3.9

This commit is contained in:
Andreas Koepf 2025-01-26 16:55:17 +01:00
parent 98e9c7e55f
commit ecbb155184
11 changed files with 66 additions and 56 deletions

View file

@@ -10,7 +10,7 @@ authors = [
]
description = "A library of procedural dataset generators for training reasoning models"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.9"
dependencies = ["sympy>=1.13.1"]
classifiers = [
"Programming Language :: Python :: 3",

View file

@@ -15,7 +15,7 @@ from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
from .word_sorting import WordSortingConfig, WordSortingDataset, TextTransformation
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
__all__ = [
"SpellBackwardConfig",

View file

@@ -8,9 +8,11 @@ from typing import List, Optional
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@dataclass
class SentenceReorderingConfig:
"""Configuration for sentence reordering task generation"""
min_words_in_sentence: int = 3
max_words_in_sentence: int = 20
seed: Optional[int] = None
@@ -19,7 +21,12 @@ class SentenceReorderingConfig:
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_words_in_sentence > 0, "min_words_in_sentence must be positive"
assert self.max_words_in_sentence >= self.min_words_in_sentence, "max_words_in_sentence must be >= min_words_in_sentence"
assert (
self.max_words_in_sentence >= self.min_words_in_sentence
), "max_words_in_sentence must be >= min_words_in_sentence"
assert (
self.max_words_in_sentence >= self.min_words_in_sentence
), "max_words_in_sentence must be >= min_words_in_sentence"
class SentenceReorderingDataset(ProceduralDataset):
@@ -35,7 +42,9 @@ class SentenceReorderingDataset(ProceduralDataset):
self.sentences = [
sentence.strip()
for sentence in re.findall(r"[^.!?]+[.!?]", text) # Changed pattern to include the ending punctuation
if self.config.min_words_in_sentence <= len(re.findall(r"\b\w+\b", sentence)) <= self.config.max_words_in_sentence
if self.config.min_words_in_sentence
<= len(re.findall(r"\b\w+\b", sentence))
<= self.config.max_words_in_sentence
]
def _generate_sentence_dataset(self, sentence: str, seed: int, idx: int, shuffle=True):
@@ -66,22 +75,22 @@ class SentenceReorderingDataset(ProceduralDataset):
sentence_dataset = self._generate_sentence_dataset(rng.choice(self.sentences), self.seed, idx)
# Ensure only 'input' and 'goal' keys are present
if set(sentence_dataset.keys()) != {'input', 'goal'}:
if set(sentence_dataset.keys()) != {"input", "goal"}:
raise KeyError("The dictionary must contain only 'input' and 'goal' keys")
# Solve the task by sorting words to match the goal sentence
input_words = sentence_dataset['input'].split()
input_words = sentence_dataset["input"].split()
question = " ".join(input_words)
goal_words = sentence_dataset['goal'].split()
goal_words = sentence_dataset["goal"].split()
solved_sentence = " ".join(sorted(input_words, key=lambda word: goal_words.index(word)))
# Check for length of alphanumeric characters in the solved sentence
word_count = len(re.findall(r"\b\w+\b", solved_sentence))
return {
"question": f"Restore the correct order of words in the following sentence: {question}",
"answer": solved_sentence,
"metadata": {"word_count": word_count},
}
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)

View file

@@ -12,8 +12,9 @@ from ..factory import ProceduralDataset, register_dataset
class TextTransformation(str, Enum):
"""Text transformation options"""
LOWERCASE = "lowercase"
UPPERCASE = "uppercase"
UPPERCASE = "uppercase"
ORIGINAL = "original"
RANDOMCASE = "randomcase"
@@ -21,6 +22,7 @@ class TextTransformation(str, Enum):
@dataclass
class WordSortingConfig:
"""Configuration for word sorting task generation"""
min_words: int = 3 # Minimum words to sort
max_words: int = 10 # Maximum words to sort
min_word_length: int = 3 # Minimum word length
@@ -43,14 +45,17 @@ class WordSortingDataset(ProceduralDataset):
def __init__(self, config: WordSortingConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
# Load and preprocess text
text = read_data_file("in_the_year_2889.txt")
# Extract unique words within length constraints
self.words = list(set(
word for word in re.findall(r'\b\w+\b', text)
if self.config.min_word_length <= len(word) <= self.config.max_word_length
))
self.words = list(
set(
word
for word in re.findall(r"\b\w+\b", text)
if self.config.min_word_length <= len(word) <= self.config.max_word_length
)
)
def _transform_word(self, word: str, rng: Random) -> str:
"""Apply configured transformation to word"""
@@ -59,19 +64,18 @@ class WordSortingDataset(ProceduralDataset):
elif self.config.transformation == TextTransformation.UPPERCASE:
return word.upper()
elif self.config.transformation == TextTransformation.RANDOMCASE:
return ''.join(c.upper() if rng.choice([True, False]) else c.lower()
for c in word)
return "".join(c.upper() if rng.choice([True, False]) else c.lower() for c in word)
return word # ORIGINAL case
def _generate_words(self, rng: Random) -> Tuple[List[str], List[str]]:
"""Generate list of words and their transformed versions"""
count = rng.randint(self.config.min_words, self.config.max_words)
# Select random words
selected_words = rng.sample(self.words, count)
# Apply transformation
transformed_words = [self._transform_word(word, rng) for word in selected_words]
return selected_words, transformed_words
def __getitem__(self, idx: int) -> dict:
@@ -97,7 +101,7 @@ class WordSortingDataset(ProceduralDataset):
"transformed_words": transformed_words,
"direction": direction,
"transformation": self.config.transformation,
"sorted_words": answer
"sorted_words": answer,
},
}

View file

@@ -82,17 +82,16 @@ class FractionSimplificationDataset(ProceduralDataset):
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
"""Format a fraction in various styles"""
match style:
case "plain":
return f"{num}/{den}"
case "latex_inline":
return f"${num}/{den}$"
case "latex_frac":
return f"$\\frac{{{num}}}{{{den}}}$"
case "latex_dfrac":
return f"$\\dfrac{{{num}}}{{{den}}}$"
case _:
raise ValueError(f"Unknown fraction style: {style}")
if style == "plain":
return f"{num}/{den}"
elif style == "latex_inline":
return f"${num}/{den}$"
elif style == "latex_frac":
return f"$\\frac{{{num}}}{{{den}}}$"
elif style == "latex_dfrac":
return f"$\\dfrac{{{num}}}{{{den}}}$"
else:
raise ValueError(f"Unknown fraction style: {style}")
def __getitem__(self, idx: int) -> dict:
"""Generate a single fraction simplification task"""

View file

@@ -1,12 +1,12 @@
import random
from dataclasses import dataclass
from enum import StrEnum
from enum import Enum
from typing import Dict, List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
class Color(StrEnum):
class Color(str, Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
@@ -25,7 +25,7 @@ class Color(StrEnum):
VIOLET = "violet"
class Side(StrEnum):
class Side(Enum):
TOP = "top"
RIGHT = "right"
FRONT = "front"

View file

@@ -1,12 +1,12 @@
from dataclasses import dataclass
from enum import StrEnum
from enum import Enum
from random import Random
from typing import List, Optional
from ..factory import ProceduralDataset, register_dataset
class Operation(StrEnum):
class Operation(Enum):
"""Basic mathematical operations that can be composed"""
ADD = "+"

View file

@@ -1,18 +1,18 @@
import random
from dataclasses import dataclass
from enum import StrEnum
from enum import Enum
from itertools import count
from typing import Dict, List, Optional, Set, Tuple
from ..factory import ProceduralDataset, register_dataset
class Gender(StrEnum):
class Gender(Enum):
MALE = "male"
FEMALE = "female"
class Relationship(StrEnum):
class Relationship(Enum):
MOTHER = "mother"
FATHER = "father"
SISTER = "sister"

View file

@@ -1,14 +1,14 @@
"""Propositional logic task generator"""
from dataclasses import dataclass
from enum import StrEnum
from enum import Enum
from random import Random
from typing import Any, List, Optional, Set
from ..factory import ProceduralDataset, register_dataset
class Operator(StrEnum):
class Operator(Enum):
"""Basic logical operators"""
AND = "∧"

View file

@@ -1,17 +1,18 @@
import pytest
from reasoning_gym.algorithmic.sentence_reordering import (
SentenceReorderingConfig,
SentenceReorderingDataset,
)
from reasoning_gym.algorithmic.sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
@pytest.fixture
def config():
return SentenceReorderingConfig(min_words_in_sentence=5, max_words_in_sentence=5, seed=42, size=10)
@pytest.fixture
def dataset(config):
return SentenceReorderingDataset(config=config)
def test_config_validation(config):
# Test that the config validation does not raise any exceptions
try:
@@ -19,6 +20,7 @@ def test_config_validation(config):
except Exception as e:
pytest.fail(f"Config validation raised an exception: {e}")
def test_generate_sentence_dataset(dataset):
sentence = "This is a test sentence for reordering"
result = dataset._generate_sentence_dataset(sentence, seed=42, idx=0, shuffle=True)
@@ -27,12 +29,15 @@ def test_generate_sentence_dataset(dataset):
assert result["input"] != result["goal"]
assert sorted(result["input"].split()) == sorted(result["goal"].split())
def test_getitem(dataset, config):
item = dataset[0]
assert "question" in item
assert "answer" in item
assert "metadata" in item
assert item["metadata"]["word_count"] >= config.min_words_in_sentence
assert item["metadata"]["word_count"] <= config.max_words_in_sentence
def test_key_error_in_getitem(dataset):
# Modify the dataset to include an incorrect key

View file

@@ -2,7 +2,7 @@
import pytest
from reasoning_gym.algorithmic.word_sorting import WordSortingConfig, WordSortingDataset, TextTransformation
from reasoning_gym.algorithmic.word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
def test_word_sorting_config_validation():
@@ -38,7 +38,7 @@ def test_word_sorting_transformations():
"""Test different text transformations"""
seed = 42
size = 5
# Test LOWERCASE
config = WordSortingConfig(transformation=TextTransformation.LOWERCASE, seed=seed, size=size)
dataset = WordSortingDataset(config)
@@ -64,14 +64,7 @@ def test_word_sorting_transformations():
def test_word_sorting_dataset_items():
"""Test basic properties of generated items"""
config = WordSortingConfig(
min_words=3,
max_words=6,
min_word_length=3,
max_word_length=8,
size=10,
seed=42
)
config = WordSortingConfig(min_words=3, max_words=6, min_word_length=3, max_word_length=8, size=10, seed=42)
dataset = WordSortingDataset(config)
for i in range(len(dataset)):