Refactor BaseConversion

This commit is contained in:
EduardDurech 2025-02-09 02:11:59 +00:00
parent 7dce30324b
commit c4f2f6386d
6 changed files with 743 additions and 264 deletions

View file

@ -1,109 +1,104 @@
"""Base conversion task generator"""
"""Base conversion exercise that converts numbers between different bases."""
from dataclasses import dataclass
from random import Random
from typing import Optional, Tuple
from typing import Dict, Any
from ..factory import ProceduralDataset, register_dataset
class BaseConversionExercise:
"""Exercise generator for base conversion problems."""
def __init__(self) -> None:
    """Create an exercise that has no curriculum attached yet."""
    # Placeholder only; generate() stores the curriculum it is called with.
    self.curriculum = None
@dataclass
class BaseConversionConfig:
"""Configuration for base conversion task generation"""
def generate(self, curriculum: Any) -> Dict[str, Any]:
"""
Generate a base conversion problem using the curriculum.
min_base: int = 2 # Minimum base (2=binary)
max_base: int = 16 # Maximum base (16=hex)
min_value: int = 0 # Minimum decimal value to convert
max_value: int = 1000 # Maximum decimal value to convert
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
Returns:
Dict containing:
- question: str (e.g. "Convert the binary number 1010 to hexadecimal")
- answer: str (the converted number in target base)
- metadata: dict with details (value, source_base, target_base, etc.)
"""
self.curriculum = curriculum
template = curriculum.get_template(curriculum.rng)
return template.eval(self, curriculum.rng)
def validate(self) -> None:
    """Validate configuration parameters.

    The checks are raised explicitly instead of using ``assert`` statements
    so that they still run when Python is started with ``-O`` (which strips
    asserts). ``AssertionError`` and the original messages are kept so that
    existing callers see the same failure.

    Raises:
        AssertionError: if ``min_base``/``max_base`` fall outside 2..36 or
            are out of order, if ``min_value`` is negative, or if
            ``max_value`` is not strictly greater than ``min_value``.
    """
    if not 2 <= self.min_base <= 36:
        raise AssertionError("min_base must be between 2 and 36")
    if not self.min_base <= self.max_base <= 36:
        raise AssertionError("max_base must be between min_base and 36")
    if not self.min_value >= 0:
        raise AssertionError("min_value must be non-negative")
    if not self.max_value > self.min_value:
        raise AssertionError("max_value must be > min_value")
def _parse_expression(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Parse the template metadata into structured data.
class BaseConversionDataset(ProceduralDataset):
"""Generates base conversion tasks"""
def __init__(self, config: BaseConversionConfig):
    """Initialise the dataset from ``config``.

    The config object itself, its ``seed`` and its ``size`` are forwarded
    to the parent class initialiser.
    """
    super().__init__(config=config, seed=config.seed, size=config.size)
def _format_base_name(self, base: int) -> str:
"""Get human-readable name for common bases"""
if base == 2:
return "binary"
elif base == 16:
return "hexadecimal"
else:
return f"base-{base}"
def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]:
"""Generate random value and source/target bases"""
value = rng.randint(self.config.min_value, self.config.max_value)
# Choose source and target bases
source_base = rng.randint(self.config.min_base, self.config.max_base)
target_base = rng.randint(self.config.min_base, self.config.max_base)
while target_base == source_base: # Ensure different bases
target_base = rng.randint(self.config.min_base, self.config.max_base)
return value, source_base, target_base
def __getitem__(self, idx: int) -> dict:
"""Generate a single base conversion task"""
rng = Random(self.seed + idx)
value, source_base, target_base = self._generate_conversion(rng)
# Convert decimal to source base representation
if source_base == 16:
source_repr = format(value, "x")
elif source_base == 2:
source_repr = format(value, "b")
else:
# Manual conversion for other bases
n = value
digits = []
while n:
digits.append(int(n % source_base))
n //= source_base
source_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0])
# Convert decimal to target base for answer
if target_base == 16:
target_repr = format(value, "x")
elif target_base == 2:
target_repr = format(value, "b")
else:
# Manual conversion for other bases
n = value
digits = []
while n:
digits.append(int(n % target_base))
n //= target_base
target_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0])
source_name = self._format_base_name(source_base)
target_name = self._format_base_name(target_base)
# Add hint for bases > 10 about using lowercase letters
hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else ""
return {
"question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}",
"answer": target_repr,
"metadata": {
"decimal_value": value,
"source_base": source_base,
"target_base": target_base,
"source_repr": source_repr,
"target_repr": target_repr,
},
The metadata structure from the curriculum:
{
"source_value": {"val": str}, # e.g. "1010" or "a5"
"source_base": {"base": str}, # e.g. "binary" or "base-3"
"target_base": {"base": str, "hint": str}, # e.g. "hexadecimal" or "base-8" with optional hint
}
Returns:
Dictionary containing:
- source_value: str (value to convert)
- source_base: int (base to convert from)
- target_base: int (base to convert to)
"""
def parse_base_name(name: str) -> int:
    """Convert a base name to its numeric value.

    Recognises "binary", "octal", "decimal" and "hexadecimal"
    (case-insensitive) as well as the generic ``base-<n>`` form.

    Args:
        name: Base name as it appears in the template metadata.

    Returns:
        The numeric base, always in the range 2..36.

    Raises:
        ValueError: if the name is not recognised, or if a ``base-<n>``
            name carries a suffix that is not an integer in 2..36. Bases
            below 2 cannot be converted at all and bases above 36 have no
            0-9/a-z digit to represent them, so they are rejected here
            rather than failing later during conversion.
    """
    named_bases = {"binary": 2, "octal": 8, "decimal": 10, "hexadecimal": 16}
    name = name.lower()
    if name in named_bases:
        return named_bases[name]
    if name.startswith("base-"):
        try:
            base = int(name[5:])
        except ValueError:
            base = None
        if base is not None and 2 <= base <= 36:
            return base
        raise ValueError(f"Unsupported base in name: {name} (expected base-2 to base-36)")
    raise ValueError(f"Unknown base name: {name}")
register_dataset("base_conversion", BaseConversionDataset, BaseConversionConfig)
return {
"source_value": metadata["source_value"]["val"],
"source_base": parse_base_name(metadata["source_base"]["base"]),
"target_base": parse_base_name(metadata["target_base"]["base"])
}
def _evaluate_expression(self, parsed: Dict[str, Any]) -> str:
"""
Convert the number between bases.
Args:
parsed: Dictionary containing:
- source_base: int (base to convert from)
- target_base: int (base to convert to)
- source_value: str (value to convert)
Returns:
String representation of the number in target base
"""
try:
# Convert source value to decimal, handling letter digits
source_value = parsed["source_value"].lower()
decimal_value = 0
for digit in source_value:
if digit.isdigit():
digit_val = int(digit)
else:
digit_val = ord(digit) - ord('a') + 10
if digit_val >= parsed["source_base"]:
raise ValueError(f"Digit {digit} is invalid for base {parsed['source_base']}")
decimal_value = decimal_value * parsed["source_base"] + digit_val
# Convert decimal to target base
if decimal_value == 0:
return "0"
# Manual conversion for all bases
digits = []
n = decimal_value
while n:
digits.append(int(n % parsed["target_base"]))
n //= parsed["target_base"]
# Convert to string with letters for digits > 9
result = "".join(str(d) if d < 10 else chr(ord("a") + d - 10)
for d in reversed(digits))
return result
except ValueError as e:
return f"Error converting number: {str(e)}"