diff --git a/reasoning_gym/algorithmic/base_conversion.py b/reasoning_gym/algorithmic/base_conversion.py index c1f62654..afa6200a 100644 --- a/reasoning_gym/algorithmic/base_conversion.py +++ b/reasoning_gym/algorithmic/base_conversion.py @@ -61,9 +61,9 @@ class BaseConversionDataset(ProceduralDataset): # Convert decimal to source base representation if source_base == 16: - source_repr = format(value, 'x') + source_repr = format(value, "x") elif source_base == 2: - source_repr = format(value, 'b') + source_repr = format(value, "b") else: # Manual conversion for other bases n = value @@ -71,14 +71,13 @@ class BaseConversionDataset(ProceduralDataset): while n: digits.append(int(n % source_base)) n //= source_base - source_repr = ''.join(str(d) if d < 10 else chr(ord('a') + d - 10) - for d in reversed(digits) or [0]) + source_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0]) # Convert decimal to target base for answer if target_base == 16: - target_repr = format(value, 'x') + target_repr = format(value, "x") elif target_base == 2: - target_repr = format(value, 'b') + target_repr = format(value, "b") else: # Manual conversion for other bases n = value @@ -86,8 +85,7 @@ class BaseConversionDataset(ProceduralDataset): while n: digits.append(int(n % target_base)) n //= target_base - target_repr = ''.join(str(d) if d < 10 else chr(ord('a') + d - 10) - for d in reversed(digits) or [0]) + target_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0]) source_name = self._format_base_name(source_base) target_name = self._format_base_name(target_base) diff --git a/tests/test_base_conversion.py b/tests/test_base_conversion.py index 2f12eaa9..8017d74a 100644 --- a/tests/test_base_conversion.py +++ b/tests/test_base_conversion.py @@ -65,12 +65,12 @@ def test_base_conversion_dataset_items(): # Verify conversion correctness decimal_value = item["metadata"]["decimal_value"] target_base = item["metadata"]["target_base"] - + # Use same conversion logic as implementation if target_base == 16: - expected = format(decimal_value, 'x') + expected = format(decimal_value, "x") elif target_base == 2: - expected = format(decimal_value, 'b') + expected = format(decimal_value, "b") else: # Manual conversion for other bases n = decimal_value @@ -78,8 +78,7 @@ def test_base_conversion_dataset_items(): while n: digits.append(int(n % target_base)) n //= target_base - expected = ''.join(str(d) if d < 10 else chr(ord('a') + d - 10) - for d in reversed(digits) or [0]) + expected = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0]) assert item["answer"] == expected @@ -97,14 +96,7 @@ def test_base_conversion_dataset_iteration(): def test_base_conversion_validity(): """Test that generated numbers are valid for their bases""" - config = BaseConversionConfig( - min_base=2, - max_base=36, - min_value=0, - max_value=1000, - size=100, - seed=42 - ) + config = BaseConversionConfig(min_base=2, max_base=36, min_value=0, max_value=1000, size=100, seed=42) dataset = BaseConversionDataset(config) def is_valid_for_base(num_str: str, base: int) -> bool: @@ -113,12 +105,12 @@ def test_base_conversion_validity(): for i in range(len(dataset)): item = dataset[i] - assert is_valid_for_base(item["metadata"]["source_repr"], - item["metadata"]["source_base"]), \ - f"Invalid source number {item['metadata']['source_repr']} for base {item['metadata']['source_base']}" - assert is_valid_for_base(item["metadata"]["target_repr"], - item["metadata"]["target_base"]), \ - f"Invalid target number {item['metadata']['target_repr']} for base {item['metadata']['target_base']}" + assert is_valid_for_base( + item["metadata"]["source_repr"], item["metadata"]["source_base"] + ), f"Invalid source number {item['metadata']['source_repr']} for base {item['metadata']['source_base']}" + assert is_valid_for_base( + item["metadata"]["target_repr"], item["metadata"]["target_base"] + ), f"Invalid target number {item['metadata']['target_repr']} for base {item['metadata']['target_base']}" def test_base_conversion_special_bases():