diff --git a/reasoning_gym/exercises/algorithmic/__init__.py b/reasoning_gym/exercises/algorithmic/__init__.py new file mode 100644 index 00000000..5034eecd --- /dev/null +++ b/reasoning_gym/exercises/algorithmic/__init__.py @@ -0,0 +1,34 @@ +""" +Algorithmic tasks for training reasoning capabilities: +- Text processing +- Counting +- Sorting +- Pattern matching +""" + +from .base_conversion import BaseConversionExercise +# from .caesar_cipher import CaesarCipherExercise +# from .letter_counting import LetterCountingExercise +# from .letter_jumble import LetterJumbleExercise +# from .number_filtering import NumberFilteringExercise +# from .number_sorting import NumberSortingExercise +# from .sentence_reordering import SentenceReorderingExercise +# from .spell_backward import SpellBackwardExercise +# from .word_ladder import WordLadderExercise +# from .word_sequence_reversal import WordSequenceReversalExercise +# from .word_sorting import TextTransformation, WordSortingExercise + +__all__ = [ + # "SpellBackwardDataset", + "BaseConversionExercise", + # "CaesarCipherDataset", + # "LetterCountingDataset", + # "LetterJumbleDataset", + # "NumberFilteringDataset", + # "NumberSortingDataset", + # "SentenceReorderingDataset", + # "WordSequenceReversalDataset", + # "WordSortingDataset", + # "TextTransformation", + # "WordLadderDataset", +] diff --git a/reasoning_gym/exercises/algorithmic/base_conversion.py b/reasoning_gym/exercises/algorithmic/base_conversion.py new file mode 100644 index 00000000..036860a2 --- /dev/null +++ b/reasoning_gym/exercises/algorithmic/base_conversion.py @@ -0,0 +1,104 @@ +"""Base conversion exercise that converts numbers between different bases.""" + +from typing import Dict, Any + +class BaseConversionExercise: + """Exercise generator for base conversion problems.""" + + def __init__(self): + self.curriculum = None + + def generate(self, curriculum: Any) -> Dict[str, Any]: + """ + Generate a base conversion problem using the curriculum. + + Returns: + Dict containing: + - question: str (e.g. "Convert the binary number 1010 to hexadecimal") + - answer: str (the converted number in target base) + - metadata: dict with details (value, source_base, target_base, etc.) + """ + self.curriculum = curriculum + template = curriculum.get_template(curriculum.rng) + return template.eval(self, curriculum.rng) + + def _parse_expression(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """ + Parse the template metadata into structured data. + + The metadata structure from the curriculum: + { + "source_value": {"val": str}, # e.g. "1010" or "a5" + "source_base": {"base": str}, # e.g. "binary" or "base-3" + "target_base": {"base": str, "hint": str}, # e.g. "hexadecimal" or "base-8" with optional hint + } + + Returns: + Dictionary containing: + - source_value: str (value to convert) + - source_base: int (base to convert from) + - target_base: int (base to convert to) + """ + def parse_base_name(name: str) -> int: + """Convert base name to numeric value.""" + name = name.lower() + if name == "binary": + return 2 + elif name == "octal": + return 8 + elif name == "decimal": + return 10 + elif name == "hexadecimal": + return 16 + elif name.startswith("base-"): + return int(name[5:]) + raise ValueError(f"Unknown base name: {name}") + + return { + "source_value": metadata["source_value"]["val"], + "source_base": parse_base_name(metadata["source_base"]["base"]), + "target_base": parse_base_name(metadata["target_base"]["base"]) + } + + def _evaluate_expression(self, parsed: Dict[str, Any]) -> str: + """ + Convert the number between bases. + + Args: + parsed: Dictionary containing: + - source_base: int (base to convert from) + - target_base: int (base to convert to) + - source_value: str (value to convert) + Returns: + String representation of the number in target base + """ + try: + # Convert source value to decimal, handling letter digits + source_value = parsed["source_value"].lower() + decimal_value = 0 + for digit in source_value: + if digit.isdigit(): + digit_val = int(digit) + else: + digit_val = ord(digit) - ord('a') + 10 + if digit_val >= parsed["source_base"]: + raise ValueError(f"Digit {digit} is invalid for base {parsed['source_base']}") + decimal_value = decimal_value * parsed["source_base"] + digit_val + + # Convert decimal to target base + if decimal_value == 0: + return "0" + + # Manual conversion for all bases + digits = [] + n = decimal_value + while n: + digits.append(int(n % parsed["target_base"])) + n //= parsed["target_base"] + # Convert to string with letters for digits > 9 + result = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) + for d in reversed(digits)) + return result + + except ValueError as e: + return f"Error converting number: {str(e)}" diff --git a/tests/exercises/algorithmic/test_base_conversion.py b/tests/exercises/algorithmic/test_base_conversion.py new file mode 100644 index 00000000..a3dc307a --- /dev/null +++ b/tests/exercises/algorithmic/test_base_conversion.py @@ -0,0 +1,422 @@ +"""Unit tests for the base conversion exercise.""" + +from enum import verify +from reasoning_gym.curricula.algorithmic.base_conversion_curriculum import BaseConversionCurriculum +from reasoning_gym.exercises.algorithmic.base_conversion import BaseConversionExercise +import unittest +import random +from collections import defaultdict + +class TestBaseConversionParsing(unittest.TestCase): + """Test parsing of base conversion metadata""" + + def setUp(self): + self.exercise = BaseConversionExercise() + + def test_parse_expression_basic(self): + """Test parsing of basic base conversion metadata""" + test_metadata = { + "source_value": {"val": "1010"}, + "source_base": {"base": "binary"}, + "target_base": {"base": "hexadecimal", "hint": ""} + } + parsed = self.exercise._parse_expression(test_metadata) + self.assertEqual(parsed["source_value"], "1010") + self.assertEqual(parsed["source_base"], 2) + self.assertEqual(parsed["target_base"], 16) + + def test_parse_base_names(self): + """Test parsing of different base names""" + test_cases = [ + ({"base": "binary"}, 2), + ({"base": "octal"}, 8), + ({"base": "decimal"}, 10), + ({"base": "hexadecimal"}, 16), + ({"base": "base-3"}, 3), + ({"base": "base-36"}, 36) + ] + for base_dict, expected in test_cases: + metadata = { + "source_value": {"val": "0"}, + "source_base": base_dict, + "target_base": {"base": "decimal", "hint": ""} + } + parsed = self.exercise._parse_expression(metadata) + self.assertEqual(parsed["source_base"], expected) + + def test_invalid_base_name(self): + """Test handling of invalid base names""" + metadata = { + "source_value": {"val": "0"}, + "source_base": {"base": "invalid"}, + "target_base": {"base": "decimal", "hint": ""} + } + with self.assertRaises(ValueError): + self.exercise._parse_expression(metadata) + + def test_parse_with_hints(self): + """Test parsing with different hint configurations""" + test_cases = [ + ({"hint": ""}, ""), + ({"hint": " (use lowercase letters a-z for digits above 9)"}, " (use lowercase letters a-z for digits above 9)"), + ({"hint": " (hint: convert to decimal first)"}, " (hint: convert to decimal first)") + ] + for hint_dict, expected in test_cases: + metadata = { + "source_value": {"val": "0"}, + "source_base": {"base": "binary"}, + "target_base": {"base": "hexadecimal", "hint": hint_dict["hint"]} + } + parsed = self.exercise._parse_expression(metadata) + self.assertEqual(parsed["source_base"], 2) + self.assertEqual(parsed["target_base"], 16) + +class TestBaseConversionEvaluation(unittest.TestCase): + """Test evaluation of base conversion problems""" + + def setUp(self): + self.exercise = BaseConversionExercise() + + def test_binary_to_decimal(self): + """Test binary to decimal conversion""" + test_cases = [ + ("1010", "10"), # 10 in decimal + ("1111", "15"), # 15 in decimal + ("10000", "16"), # 16 in decimal + ("0", "0"), # 0 in any base is 0 + ("1", "1") # 1 in any base is 1 + ] + for binary, expected in test_cases: + parsed = { + "source_value": binary, + "source_base": 2, + "target_base": 10 + } + result = self.exercise._evaluate_expression(parsed) + self.assertEqual(result, expected) + + def test_decimal_to_hex(self): + """Test decimal to hexadecimal conversion""" + test_cases = [ + ("255", "ff"), # Max 8-bit value + ("16", "10"), # Power of 16 + ("10", "a"), # Single hex digit + ("0", "0"), # Zero + ("4096", "1000") # Power of 16 + ] + for decimal, expected in test_cases: + parsed = { + "source_value": decimal, + "source_base": 10, + "target_base": 16 + } + result = self.exercise._evaluate_expression(parsed) + self.assertEqual(result, expected) + + def test_hex_to_octal(self): + """Test hexadecimal to octal conversion""" + test_cases = [ + ("ff", "377"), # Max 8-bit value + ("10", "20"), # Simple conversion + ("a5", "245"), # Mixed digits and letters + ("0", "0"), # Zero + ("100", "400") # Power of 16 + ] + for hex_val, expected in test_cases: + parsed = { + "source_value": hex_val, + "source_base": 16, + "target_base": 8 + } + result = self.exercise._evaluate_expression(parsed) + self.assertEqual(result, expected) + + def test_zero_value(self): + """Test conversion of zero in any base""" + bases = [2, 3, 8, 10, 16, 36] # Test more bases + for source_base in bases: + for target_base in bases: + parsed = { + "source_value": "0", + "source_base": source_base, + "target_base": target_base + } + result = self.exercise._evaluate_expression(parsed) + self.assertEqual(result, "0") + + def test_invalid_digits(self): + """Test handling of invalid digits for given base""" + test_cases = [ + ("123", 2), # Invalid binary + ("9", 8), # Invalid octal + ("g", 16), # Invalid hex + ("z", 35) # Invalid for base-35 + ] + for value, base in test_cases: + parsed = { + "source_value": value, + "source_base": base, + "target_base": 10 + } + result = self.exercise._evaluate_expression(parsed) + self.assertTrue(result.startswith("Error")) + + def test_edge_cases(self): + """Test edge cases and boundary values""" + test_cases = [ + # Max values for different bases + ("11111111", 2, 16, "ff"), # Max 8-bit binary to hex + ("77777777", 8, 16, "ffffff"), # Large octal to hex + ("ffffff", 16, 2, "111111111111111111111111"), # Large hex to binary + # Single digits + ("1", 2, 36, "1"), + ("z", 36, 2, "100011"), # Corrected: 'z' in base-36 is 35, which is 100011 in binary + # Alternating patterns + ("101010", 2, 8, "52"), + ("aaaaaa", 16, 10, "11184810") + ] + for value, source_base, target_base, expected in test_cases: + parsed = { + "source_value": value, + "source_base": source_base, + "target_base": target_base + } + result = self.exercise._evaluate_expression(parsed) + self.assertEqual(result, expected) + +class TestBaseConversionGeneration(unittest.TestCase): + """Test problem generation""" + + def setUp(self): + self.curriculum = BaseConversionCurriculum() + self.exercise = BaseConversionExercise() + self.rng = random.Random(42) + self.curriculum.rng = self.rng + + def test_problem_structure(self): + """Test that generated problems have the correct structure""" + problem = self.exercise.generate(self.curriculum) + + # Check basic structure + self.assertIn("question", problem) + self.assertIn("answer", problem) + self.assertIn("metadata", problem) + + # Check metadata structure + metadata = problem["metadata"] + self.assertEqual(metadata["type"], "direct") + self.assertIn("executed_parts", metadata) + executed_parts = metadata["executed_parts"] + self.assertIn("source_value", executed_parts) + self.assertIn("source_base", executed_parts) + self.assertIn("target_base", executed_parts) + + def test_value_ranges(self): + """Test that generated values are within expected ranges""" + # Test all value levels + level_max_values = {0: 100, 1: 1000, 2: 10000} + + for level, max_value in level_max_values.items(): + self.curriculum.set_attr_level("value", level) + problem = self.exercise.generate(self.curriculum) + decimal_val = int(problem["metadata"]["executed_parts"]["source_value"], + problem["metadata"]["executed_parts"]["source_base"]) + self.assertLessEqual(decimal_val, max_value) + + def test_base_ranges(self): + """Test that bases are within expected ranges""" + # Test all base range levels + level_max_bases = {0: 16, 1: 26, 2: 36} + + for level, max_base in level_max_bases.items(): + self.curriculum.set_attr_level("base_range", level) + problem = self.exercise.generate(self.curriculum) + source_base = problem["metadata"]["executed_parts"]["source_base"] + target_base = problem["metadata"]["executed_parts"]["target_base"] + self.assertLessEqual(source_base, max_base) + self.assertLessEqual(target_base, max_base) + self.assertGreaterEqual(source_base, 2) + self.assertGreaterEqual(target_base, 2) + + def test_template_variation(self): + """Test that different templates are used""" + templates_seen = set() + num_samples = 100 + + for _ in range(num_samples): + problem = self.exercise.generate(self.curriculum) + templates_seen.add(problem["question"].split(":")[0]) # Get the question pattern + + self.assertGreater(len(templates_seen), 1, "Not enough template variation") + +class TestBaseConversionComprehensive(unittest.TestCase): + """Comprehensive tests for base conversion""" + + def setUp(self): + self.curriculum = BaseConversionCurriculum() + self.exercise = BaseConversionExercise() + self.rng = random.Random(42) + self.curriculum.rng = self.rng + + def _extract_base(self, text): + """Helper method to extract base from problem text.""" + if "binary" in text.lower(): + return 2 + if "octal" in text.lower(): + return 8 + if "decimal" in text.lower(): + return 10 + if "hexadecimal" in text.lower(): + return 16 + + # Try to find base-N pattern + import re + match = re.search(r'base-(\d+)', text.lower()) + if match: + return int(match.group(1)) + return None + + def test_all_base_combinations(self): + """Test conversion between all possible base combinations""" + bases = [2, 8, 10, 16, 36] # Test common bases + test_values = ["10", "ff", "xyz", "777", "42"] # Test values + + for source_base in bases: + for target_base in bases: + for value in test_values: + try: + # Skip if value is invalid for source base + int(value, min(source_base, 36)) + except ValueError: + continue + + parsed = { + "source_value": value, + "source_base": source_base, + "target_base": target_base + } + result = self.exercise._evaluate_expression(parsed) + + # Verify result by converting back + try: + decimal = int(result, target_base) + original = int(value, source_base) + self.assertEqual(decimal, original) + except ValueError: + self.fail(f"Invalid conversion: {value} from base {source_base} to base {target_base}") + + def test_hint_inclusion(self): + """Test that hints are included appropriately""" + # Test with hints enabled + self.curriculum.set_attr_level("hint", 0) + problem = self.exercise.generate(self.curriculum) + if problem["metadata"]["executed_parts"]["target_base"] > 10: + self.assertIn("use lowercase letters", problem["question"].lower()) + + # Test with hints disabled + self.curriculum.set_attr_level("hint", 1) + problem = self.exercise.generate(self.curriculum) + self.assertNotIn("use lowercase letters", problem["question"].lower()) + + def test_base_names(self): + """Test that base names are used correctly""" + # Test with basic names + self.curriculum.set_attr_level("base_names", 0) + problem = self.exercise.generate(self.curriculum) + question = problem["question"].lower() + self.assertTrue(any(name in question for name in ["binary", "hexadecimal", "base-"])) + + # Test with extended names + self.curriculum.set_attr_level("base_names", 1) + problem = self.exercise.generate(self.curriculum) + question = problem["question"].lower() + self.assertTrue(any(name in question for name in ["octal", "decimal", "base-"])) + + def test_comprehensive_random_evaluation(self): + """Test random evaluation with all base combinations and track statistics.""" + self.rng = random.Random(42) # Fixed seed for reproducibility + self.curriculum.rng = self.rng + + # Track statistics + base_name_usage = defaultdict(int) + source_bases = defaultdict(int) + target_bases = defaultdict(int) + values = [] + hint_count = 0 + total_samples = 1000 + + # Generate test cases + for _ in range(total_samples): + # Set random attribute levels + for attr in ["value", "base_range"]: + self.curriculum.set_attr_level(attr, self.rng.randint(0, 2)) + for attr in ["base_names", "hint"]: + self.curriculum.set_attr_level(attr, self.rng.randint(0, 1)) + + # Generate and evaluate a random problem + problem = self.exercise.generate(self.curriculum) + + # Track statistics + if "binary" in problem["question"].lower(): + base_name_usage["binary"] += 1 + elif "octal" in problem["question"].lower(): + base_name_usage["octal"] += 1 + elif "hexadecimal" in problem["question"].lower(): + base_name_usage["hexadecimal"] += 1 + elif "decimal" in problem["question"].lower(): + base_name_usage["decimal"] += 1 + else: + base_name_usage["other"] += 1 + + # Track source and target bases + metadata = problem["metadata"]["executed_parts"] + source_base = metadata["source_base"] + target_base = metadata["target_base"] + + if source_base: + source_bases[source_base] += 1 + if target_base: + target_bases[target_base] += 1 + + # Track if hints are included + if "(use lowercase letters a-z for digits above 9)" in problem["question"]: + hint_count += 1 + + # Track value statistics + try: + value = int(metadata["source_value"], source_base) + values.append(value) + except ValueError: + pass + + # Print statistics + print("\nBase name usage:") + for name, count in base_name_usage.items(): + print(f" {name}: {count}") + + print("\nSource bases used (35 bases):") + for base in range(2, 37): + if source_bases[base] > 0: + print(f" base-{base}: {source_bases[base]}") + + print("\nTarget bases used (35 bases):") + for base in range(2, 37): + if target_bases[base] > 0: + print(f" base-{base}: {target_bases[base]}") + + print("\nValue statistics:") + if values: + print(f" Min value: {min(values)}") + print(f" Max value: {max(values)}") + print(f" Average value: {sum(values) / len(values):.2f}") + print(f" Total samples with hints: {hint_count} / {total_samples}") + + # verify statistics + self.assertTrue(base_name_usage["hexadecimal"] >= 4, "Hexadecimal base name was not used enough") + self.assertTrue(len(source_bases) >= 10, "Not enough different source bases used") + self.assertTrue(len(target_bases) >= 10, "Not enough different target bases used") + self.assertTrue(hint_count > 0, "No hints were included") + self.assertTrue(hint_count < total_samples, "Too many hints were included") + +if __name__ == '__main__': + unittest.main()