diff --git a/reasoning_gym/algorithmic/base_conversion.py b/reasoning_gym/algorithmic/base_conversion.py
index eb0978bd..c1f62654 100644
--- a/reasoning_gym/algorithmic/base_conversion.py
+++ b/reasoning_gym/algorithmic/base_conversion.py
@@ -60,14 +60,39 @@ class BaseConversionDataset(ProceduralDataset):
         value, source_base, target_base = self._generate_conversion(rng)
 
         # Convert decimal to source base representation
-        source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip()
-        if source_base not in (2, 16):
-            source_repr = format(value, f"{source_base}x").lower().strip()
+        if source_base == 16:
+            source_repr = format(value, 'x')
+        elif source_base == 2:
+            source_repr = format(value, 'b')
+        else:
+            # Manual conversion for other bases: peel off the least
+            # significant digit until nothing is left.
+            n = value
+            digits = []
+            while n:
+                digits.append(int(n % source_base))
+                n //= source_base
+            # "digits or [0]" must sit inside reversed(): a reversed()
+            # iterator is always truthy, so a fallback placed outside it
+            # never fires and value == 0 would be rendered as "".
+            source_repr = ''.join(str(d) if d < 10 else chr(ord('a') + d - 10)
+                                  for d in reversed(digits or [0]))
 
         # Convert decimal to target base for answer
-        target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip()
-        if target_base not in (2, 16):
-            target_repr = format(value, f"{target_base}x").lower().strip()
+        if target_base == 16:
+            target_repr = format(value, 'x')
+        elif target_base == 2:
+            target_repr = format(value, 'b')
+        else:
+            # Manual conversion for other bases (same scheme as above)
+            n = value
+            digits = []
+            while n:
+                digits.append(int(n % target_base))
+                n //= target_base
+            # Fallback inside reversed() so that value == 0 yields "0"
+            target_repr = ''.join(str(d) if d < 10 else chr(ord('a') + d - 10)
+                                  for d in reversed(digits or [0]))
 
         source_name = self._format_base_name(source_base)
         target_name = self._format_base_name(target_base)
diff --git a/tests/test_base_conversion.py b/tests/test_base_conversion.py
index 7c8edf1e..dced77c4 100644
--- a/tests/test_base_conversion.py
+++ b/tests/test_base_conversion.py
@@ -83,6 +83,34 @@ def test_base_conversion_dataset_iteration():
     assert items == list(dataset)
 
 
+def test_base_conversion_validity():
+    """Test that generated numbers are valid for their bases"""
+    config = BaseConversionConfig(
+        min_base=2,
+        max_base=36,
+        min_value=0,
+        max_value=1000,
+        size=100,
+        seed=42
+    )
+    dataset = BaseConversionDataset(config)
+
+    def is_valid_for_base(num_str: str, base: int) -> bool:
+        # An empty string has no digits at all, so it is not a valid number
+        # (all() alone would accept it vacuously).
+        valid_chars = "0123456789abcdefghijklmnopqrstuvwxyz"[:base]
+        return bool(num_str) and all(c in valid_chars for c in num_str.lower())
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        assert is_valid_for_base(item["metadata"]["source_repr"],
+                                 item["metadata"]["source_base"]), \
+            f"Invalid source number {item['metadata']['source_repr']} for base {item['metadata']['source_base']}"
+        assert is_valid_for_base(item["metadata"]["target_repr"],
+                                 item["metadata"]["target_base"]), \
+            f"Invalid target number {item['metadata']['target_repr']} for base {item['metadata']['target_base']}"
+
+
 def test_base_conversion_special_bases():
     """Test conversion between special bases (binary, hex)"""
     config = BaseConversionConfig(