Refactor BaseConversion

This commit is contained in:
EduardDurech 2025-02-09 02:11:59 +00:00
parent 7dce30324b
commit c4f2f6386d
6 changed files with 743 additions and 264 deletions

View file

@ -1,165 +1,422 @@
"""Tests for base conversion task generation"""
"""Unit tests for the base conversion exercise."""
import pytest
from enum import verify
from reasoning_gym.curricula.algorithmic.base_conversion_curriculum import BaseConversionCurriculum
from reasoning_gym.exercises.algorithmic.base_conversion import BaseConversionExercise
import unittest
import random
from collections import defaultdict
from reasoning_gym.algorithmic.base_conversion import BaseConversionConfig, BaseConversionDataset
class TestBaseConversionParsing(unittest.TestCase):
"""Test parsing of base conversion metadata"""
def setUp(self):
self.exercise = BaseConversionExercise()
def test_base_conversion_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = BaseConversionConfig(min_base=1) # Too small
config.validate()
def test_parse_expression_basic(self):
"""Test parsing of basic base conversion metadata"""
test_metadata = {
"source_value": {"val": "1010"},
"source_base": {"base": "binary"},
"target_base": {"base": "hexadecimal", "hint": ""}
}
parsed = self.exercise._parse_expression(test_metadata)
self.assertEqual(parsed["source_value"], "1010")
self.assertEqual(parsed["source_base"], 2)
self.assertEqual(parsed["target_base"], 16)
with pytest.raises(AssertionError):
config = BaseConversionConfig(min_base=37) # Too large
config.validate()
def test_parse_base_names(self):
"""Test parsing of different base names"""
test_cases = [
({"base": "binary"}, 2),
({"base": "octal"}, 8),
({"base": "decimal"}, 10),
({"base": "hexadecimal"}, 16),
({"base": "base-3"}, 3),
({"base": "base-36"}, 36)
]
for base_dict, expected in test_cases:
metadata = {
"source_value": {"val": "0"},
"source_base": base_dict,
"target_base": {"base": "decimal", "hint": ""}
}
parsed = self.exercise._parse_expression(metadata)
self.assertEqual(parsed["source_base"], expected)
with pytest.raises(AssertionError):
config = BaseConversionConfig(min_base=10, max_base=5) # max < min
config.validate()
def test_invalid_base_name(self):
"""Test handling of invalid base names"""
metadata = {
"source_value": {"val": "0"},
"source_base": {"base": "invalid"},
"target_base": {"base": "decimal", "hint": ""}
}
with self.assertRaises(ValueError):
self.exercise._parse_expression(metadata)
with pytest.raises(AssertionError):
config = BaseConversionConfig(min_value=-1) # Negative not allowed
config.validate()
def test_parse_with_hints(self):
"""Test parsing with different hint configurations"""
test_cases = [
({"hint": ""}, ""),
({"hint": " (use lowercase letters a-z for digits above 9)"}, " (use lowercase letters a-z for digits above 9)"),
({"hint": " (hint: convert to decimal first)"}, " (hint: convert to decimal first)")
]
for hint_dict, expected in test_cases:
metadata = {
"source_value": {"val": "0"},
"source_base": {"base": "binary"},
"target_base": {"base": "hexadecimal", "hint": hint_dict["hint"]}
}
parsed = self.exercise._parse_expression(metadata)
self.assertEqual(parsed["source_base"], 2)
self.assertEqual(parsed["target_base"], 16)
class TestBaseConversionEvaluation(unittest.TestCase):
"""Test evaluation of base conversion problems"""
def test_base_conversion_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
config = BaseConversionConfig(seed=42, size=10)
dataset1 = BaseConversionDataset(config)
dataset2 = BaseConversionDataset(config)
def setUp(self):
self.exercise = BaseConversionExercise()
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_binary_to_decimal(self):
"""Test binary to decimal conversion"""
test_cases = [
("1010", "10"), # 10 in decimal
("1111", "15"), # 15 in decimal
("10000", "16"), # 16 in decimal
("0", "0"), # 0 in any base is 0
("1", "1") # 1 in any base is 1
]
for binary, expected in test_cases:
parsed = {
"source_value": binary,
"source_base": 2,
"target_base": 10
}
result = self.exercise._evaluate_expression(parsed)
self.assertEqual(result, expected)
def test_decimal_to_hex(self):
"""Test decimal to hexadecimal conversion"""
test_cases = [
("255", "ff"), # Max 8-bit value
("16", "10"), # Power of 16
("10", "a"), # Single hex digit
("0", "0"), # Zero
("4096", "1000") # Power of 16
]
for decimal, expected in test_cases:
parsed = {
"source_value": decimal,
"source_base": 10,
"target_base": 16
}
result = self.exercise._evaluate_expression(parsed)
self.assertEqual(result, expected)
def test_base_conversion_dataset_items():
"""Test basic properties of generated items"""
config = BaseConversionConfig(min_base=2, max_base=16, min_value=0, max_value=1000, size=10, seed=42)
dataset = BaseConversionDataset(config)
def test_hex_to_octal(self):
"""Test hexadecimal to octal conversion"""
test_cases = [
("ff", "377"), # Max 8-bit value
("10", "20"), # Simple conversion
("a5", "245"), # Mixed digits and letters
("0", "0"), # Zero
("100", "400") # Power of 16
]
for hex_val, expected in test_cases:
parsed = {
"source_value": hex_val,
"source_base": 16,
"target_base": 8
}
result = self.exercise._evaluate_expression(parsed)
self.assertEqual(result, expected)
for i in range(len(dataset)):
item = dataset[i]
# Check item structure
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
def test_zero_value(self):
"""Test conversion of zero in any base"""
bases = [2, 3, 8, 10, 16, 36] # Test more bases
for source_base in bases:
for target_base in bases:
parsed = {
"source_value": "0",
"source_base": source_base,
"target_base": target_base
}
result = self.exercise._evaluate_expression(parsed)
self.assertEqual(result, "0")
# Check metadata
assert "decimal_value" in item["metadata"]
assert "source_base" in item["metadata"]
assert "target_base" in item["metadata"]
assert "source_repr" in item["metadata"]
assert "target_repr" in item["metadata"]
def test_invalid_digits(self):
"""Test handling of invalid digits for given base"""
test_cases = [
("123", 2), # Invalid binary
("9", 8), # Invalid octal
("g", 16), # Invalid hex
("z", 35) # Invalid for base-35
]
for value, base in test_cases:
parsed = {
"source_value": value,
"source_base": base,
"target_base": 10
}
result = self.exercise._evaluate_expression(parsed)
self.assertTrue(result.startswith("Error"))
# Verify value range
assert config.min_value <= item["metadata"]["decimal_value"] <= config.max_value
def test_edge_cases(self):
"""Test edge cases and boundary values"""
test_cases = [
# Max values for different bases
("11111111", 2, 16, "ff"), # Max 8-bit binary to hex
("77777777", 8, 16, "ffffff"), # Large octal to hex
("ffffff", 16, 2, "111111111111111111111111"), # Large hex to binary
# Single digits
("1", 2, 36, "1"),
("z", 36, 2, "100011"), # Corrected: 'z' in base-36 is 35, which is 100011 in binary
# Alternating patterns
("101010", 2, 8, "52"),
("aaaaaa", 16, 10, "11184810")
]
for value, source_base, target_base, expected in test_cases:
parsed = {
"source_value": value,
"source_base": source_base,
"target_base": target_base
}
result = self.exercise._evaluate_expression(parsed)
self.assertEqual(result, expected)
# Verify base range
assert config.min_base <= item["metadata"]["source_base"] <= config.max_base
assert config.min_base <= item["metadata"]["target_base"] <= config.max_base
assert item["metadata"]["source_base"] != item["metadata"]["target_base"]
class TestBaseConversionGeneration(unittest.TestCase):
"""Test problem generation"""
# Verify conversion correctness
decimal_value = item["metadata"]["decimal_value"]
target_base = item["metadata"]["target_base"]
def setUp(self):
self.curriculum = BaseConversionCurriculum()
self.exercise = BaseConversionExercise()
self.rng = random.Random(42)
self.curriculum.rng = self.rng
# Use same conversion logic as implementation
if target_base == 16:
expected = format(decimal_value, "x")
elif target_base == 2:
expected = format(decimal_value, "b")
else:
# Manual conversion for other bases
n = decimal_value
digits = []
while n:
digits.append(int(n % target_base))
n //= target_base
expected = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0])
assert item["answer"] == expected
def test_problem_structure(self):
"""Test that generated problems have the correct structure"""
problem = self.exercise.generate(self.curriculum)
# Check basic structure
self.assertIn("question", problem)
self.assertIn("answer", problem)
self.assertIn("metadata", problem)
def test_base_conversion_dataset_iteration():
"""Test that iteration respects dataset size"""
config = BaseConversionConfig(size=5, seed=42)
dataset = BaseConversionDataset(config)
# Check metadata structure
metadata = problem["metadata"]
self.assertEqual(metadata["type"], "direct")
self.assertIn("executed_parts", metadata)
executed_parts = metadata["executed_parts"]
self.assertIn("source_value", executed_parts)
self.assertIn("source_base", executed_parts)
self.assertIn("target_base", executed_parts)
items = list(dataset)
assert len(items) == config.size
def test_value_ranges(self):
"""Test that generated values are within expected ranges"""
# Test all value levels
level_max_values = {0: 100, 1: 1000, 2: 10000}
for level, max_value in level_max_values.items():
self.curriculum.set_attr_level("value", level)
problem = self.exercise.generate(self.curriculum)
decimal_val = int(problem["metadata"]["executed_parts"]["source_value"],
problem["metadata"]["executed_parts"]["source_base"])
self.assertLessEqual(decimal_val, max_value)
# Test multiple iterations yield same items
assert items == list(dataset)
def test_base_ranges(self):
"""Test that bases are within expected ranges"""
# Test all base range levels
level_max_bases = {0: 16, 1: 26, 2: 36}
for level, max_base in level_max_bases.items():
self.curriculum.set_attr_level("base_range", level)
problem = self.exercise.generate(self.curriculum)
source_base = problem["metadata"]["executed_parts"]["source_base"]
target_base = problem["metadata"]["executed_parts"]["target_base"]
self.assertLessEqual(source_base, max_base)
self.assertLessEqual(target_base, max_base)
self.assertGreaterEqual(source_base, 2)
self.assertGreaterEqual(target_base, 2)
def test_template_variation(self):
"""Test that different templates are used"""
templates_seen = set()
num_samples = 100
def test_base_conversion_validity():
"""Test that generated numbers are valid for their bases"""
config = BaseConversionConfig(min_base=2, max_base=36, min_value=0, max_value=1000, size=100, seed=42)
dataset = BaseConversionDataset(config)
for _ in range(num_samples):
problem = self.exercise.generate(self.curriculum)
templates_seen.add(problem["question"].split(":")[0]) # Get the question pattern
def is_valid_for_base(num_str: str, base: int) -> bool:
valid_chars = "0123456789abcdefghijklmnopqrstuvwxyz"[:base]
return all(c in valid_chars for c in num_str.lower())
self.assertGreater(len(templates_seen), 1, "Not enough template variation")
for i in range(len(dataset)):
item = dataset[i]
assert is_valid_for_base(
item["metadata"]["source_repr"], item["metadata"]["source_base"]
), f"Invalid source number {item['metadata']['source_repr']} for base {item['metadata']['source_base']}"
assert is_valid_for_base(
item["metadata"]["target_repr"], item["metadata"]["target_base"]
), f"Invalid target number {item['metadata']['target_repr']} for base {item['metadata']['target_base']}"
class TestBaseConversionComprehensive(unittest.TestCase):
"""Comprehensive tests for base conversion"""
def setUp(self):
self.curriculum = BaseConversionCurriculum()
self.exercise = BaseConversionExercise()
self.rng = random.Random(42)
self.curriculum.rng = self.rng
def test_base_conversion_special_bases():
"""Test conversion between special bases (binary, hex)"""
config = BaseConversionConfig(
min_base=2,
max_base=16,
min_value=0,
max_value=255, # Use small range for predictable results
size=100,
seed=42,
)
dataset = BaseConversionDataset(config)
def _extract_base(self, text):
"""Helper method to extract base from problem text."""
if "binary" in text.lower():
return 2
if "octal" in text.lower():
return 8
if "decimal" in text.lower():
return 10
if "hexadecimal" in text.lower():
return 16
binary_found = False
hex_found = False
# Try to find base-N pattern
import re
match = re.search(r'base-(\d+)', text.lower())
if match:
return int(match.group(1))
return None
for i in range(len(dataset)):
item = dataset[i]
if item["metadata"]["target_base"] == 2:
binary_found = True
# Verify binary format
assert all(c in "01" for c in item["answer"])
elif item["metadata"]["target_base"] == 16:
hex_found = True
# Verify hex format
assert all(c in "0123456789abcdef" for c in item["answer"])
def test_all_base_combinations(self):
"""Test conversion between all possible base combinations"""
bases = [2, 8, 10, 16, 36] # Test common bases
test_values = ["10", "ff", "xyz", "777", "42"] # Test values
assert binary_found, "No binary conversion tasks generated"
assert hex_found, "No hexadecimal conversion tasks generated"
for source_base in bases:
for target_base in bases:
for value in test_values:
try:
# Skip if value is invalid for source base
int(value, min(source_base, 36))
except ValueError:
continue
parsed = {
"source_value": value,
"source_base": source_base,
"target_base": target_base
}
result = self.exercise._evaluate_expression(parsed)
def test_base_conversion_formatting():
"""Test number formatting in different bases"""
config = BaseConversionConfig(
min_base=11, # Force bases that use letters
max_base=36,
min_value=10, # Ensure multi-digit numbers
max_value=1000,
size=10,
seed=42,
)
dataset = BaseConversionDataset(config)
# Verify result by converting back
try:
decimal = int(result, target_base)
original = int(value, source_base)
self.assertEqual(decimal, original)
except ValueError:
self.fail(f"Invalid conversion: {value} from base {source_base} to base {target_base}")
for i in range(len(dataset)):
item = dataset[i]
# Verify lowercase letters are used
assert item["answer"] == item["answer"].lower()
# Verify no whitespace in answer
assert item["answer"].strip() == item["answer"]
# Verify hint is included for bases > 10
assert "use lowercase letters" in item["question"]
def test_hint_inclusion(self):
"""Test that hints are included appropriately"""
# Test with hints enabled
self.curriculum.set_attr_level("hint", 0)
problem = self.exercise.generate(self.curriculum)
if problem["metadata"]["executed_parts"]["target_base"] > 10:
self.assertIn("use lowercase letters", problem["question"].lower())
# Test with hints disabled
self.curriculum.set_attr_level("hint", 1)
problem = self.exercise.generate(self.curriculum)
self.assertNotIn("use lowercase letters", problem["question"].lower())
def test_base_names(self):
"""Test that base names are used correctly"""
# Test with basic names
self.curriculum.set_attr_level("base_names", 0)
problem = self.exercise.generate(self.curriculum)
question = problem["question"].lower()
self.assertTrue(any(name in question for name in ["binary", "hexadecimal", "base-"]))
# Test with extended names
self.curriculum.set_attr_level("base_names", 1)
problem = self.exercise.generate(self.curriculum)
question = problem["question"].lower()
self.assertTrue(any(name in question for name in ["octal", "decimal", "base-"]))
def test_comprehensive_random_evaluation(self):
"""Test random evaluation with all base combinations and track statistics."""
self.rng = random.Random(42) # Fixed seed for reproducibility
self.curriculum.rng = self.rng
# Track statistics
base_name_usage = defaultdict(int)
source_bases = defaultdict(int)
target_bases = defaultdict(int)
values = []
hint_count = 0
total_samples = 1000
# Generate test cases
for _ in range(total_samples):
# Set random attribute levels
for attr in ["value", "base_range"]:
self.curriculum.set_attr_level(attr, self.rng.randint(0, 2))
for attr in ["base_names", "hint"]:
self.curriculum.set_attr_level(attr, self.rng.randint(0, 1))
# Generate and evaluate a random problem
problem = self.exercise.generate(self.curriculum)
# Track statistics
if "binary" in problem["question"].lower():
base_name_usage["binary"] += 1
elif "octal" in problem["question"].lower():
base_name_usage["octal"] += 1
elif "hexadecimal" in problem["question"].lower():
base_name_usage["hexadecimal"] += 1
elif "decimal" in problem["question"].lower():
base_name_usage["decimal"] += 1
else:
base_name_usage["other"] += 1
# Track source and target bases
metadata = problem["metadata"]["executed_parts"]
source_base = metadata["source_base"]
target_base = metadata["target_base"]
if source_base:
source_bases[source_base] += 1
if target_base:
target_bases[target_base] += 1
# Track if hints are included
if "(use lowercase letters a-z for digits above 9)" in problem["question"]:
hint_count += 1
# Track value statistics
try:
value = int(metadata["source_value"], source_base)
values.append(value)
except ValueError:
pass
# Print statistics
print("\nBase name usage:")
for name, count in base_name_usage.items():
print(f" {name}: {count}")
print("\nSource bases used (35 bases):")
for base in range(2, 37):
if source_bases[base] > 0:
print(f" base-{base}: {source_bases[base]}")
print("\nTarget bases used (35 bases):")
for base in range(2, 37):
if target_bases[base] > 0:
print(f" base-{base}: {target_bases[base]}")
print("\nValue statistics:")
if values:
print(f" Min value: {min(values)}")
print(f" Max value: {max(values)}")
print(f" Average value: {sum(values) / len(values):.2f}")
print(f" Total samples with hints: {hint_count} / {total_samples}")
# verify statistics
self.assertTrue(base_name_usage["hexadecimal"] >= 4, "Hexadecimal base name was not used enough")
self.assertTrue(len(source_bases) >= 10, "Not enough different source bases used")
self.assertTrue(len(target_bases) >= 10, "Not enough different target bases used")
self.assertTrue(hint_count > 0, "No hints were included")
self.assertTrue(hint_count < total_samples, "Too many hints were included")
if __name__ == '__main__':
unittest.main()