From 17088e9b42cc1e5aa515687ff32aec0f0ecb78e7 Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Fri, 21 Feb 2025 12:02:41 +0100 Subject: [PATCH] add bitwise arithmetic --- .../arithmetic/bitwise_arithmetic.py | 153 ++++++++++++++++++ tests/test_bitwise_arithmetic.py | 54 +++++++ 2 files changed, 207 insertions(+) create mode 100644 reasoning_gym/arithmetic/bitwise_arithmetic.py create mode 100644 tests/test_bitwise_arithmetic.py diff --git a/reasoning_gym/arithmetic/bitwise_arithmetic.py b/reasoning_gym/arithmetic/bitwise_arithmetic.py new file mode 100644 index 00000000..1d90b0f5 --- /dev/null +++ b/reasoning_gym/arithmetic/bitwise_arithmetic.py @@ -0,0 +1,153 @@ +import ast +from dataclasses import dataclass +from random import Random +from typing import Any, Dict, List, Optional + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class BitwiseArithmeticConfig: + """Configuration for Bitwise arithmetic dataset generation""" + + difficulty: int = 2 + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert 0 < self.difficulty, "difficulty must be gt 0" + assert 10 >= self.difficulty, "difficulty must be lte 10" + + +def generate_expression(rng, max_depth): + """ + Recursively generate a random arithmetic expression that includes + standard arithmetic (+, -, *) and bitwise shifting (<<, >>) operators. + All numbers are represented in hexadecimal format as multi-byte values. + + Parameters: + max_depth (int): Maximum depth of nested expressions. + + Returns: + str: A string representing the generated expression. + """ + # Base case: return a random multi-byte number in hex (0x100 to 0xFFFF). + if max_depth <= 0: + return hex(rng.randint(0x100, 0xFFFF)) + + # Occasionally return a simple hex number even if max_depth > 0. + if rng.random() < 0.01: + return hex(rng.randint(0x100, 0xFFFF)) + + # Choose a random operator. + operators = ["+", "-", "*", "<<", ">>"] + op = rng.choice(operators) + + # Generate left and right subexpressions. + left_expr = generate_expression(rng, max_depth - 1) + right_expr = generate_expression(rng, max_depth - 1) + + # For bitwise shift operations, keep the right operand small (in hex). + if op in ["<<", ">>"]: + right_expr = hex(rng.randint(0, 3)) + + return f"({left_expr} {op} {right_expr})" + + +def generate_problem(rng, difficulty=1): + """ + Generate a random arithmetic problem involving multi-byte hexadecimal numbers. + + The 'difficulty' parameter controls the complexity: + - Lower difficulty produces a shallower expression. + - Higher difficulty produces a more deeply nested expression. + + Parameters: + difficulty (int): The difficulty level (1 = simplest; higher values = more complex). + + Returns: + tuple: (problem_str, correct_answer) + - problem_str (str): The generated arithmetic expression (with hex numbers). + - correct_answer (str): The evaluated result, formatted as a hex string. + """ + max_depth = max(1, difficulty) + problem_str = generate_expression(rng, max_depth) + correct_value = eval(problem_str) + correct_answer = hex(correct_value) + + return problem_str, correct_answer + + +def verify_solution(problem, user_solution): + """ + Verify if the provided solution is correct for the given problem. + + Parameters: + problem (str): The arithmetic expression (with hex numbers). + user_solution (str or int): The user's answer, either as a hex string (e.g., "0xa") + or an integer. + + Returns: + bool: True if the user's answer matches the evaluated result, else False. + """ + try: + correct_value = eval(problem) + user_value = int(str(user_solution), 0) + except Exception as e: + return False + + return correct_value == user_value + + +class BitwiseArithmeticDataset(ProceduralDataset): + """Dataset that generates basic tasks using bitwise arithmetic and proper operator precedence.""" + + def __init__(self, config: BitwiseArithmeticConfig) -> None: + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + """ + Generate a single arithmetic task. + + Returns: + dict: Contains: + - 'question': The formatted arithmetic expression as a string. + - 'answer': The computed hexidecimal result. + - 'metadata': Additional metadata. + """ + # Create a deterministic RNG from base seed and index. + rng: Random = Random(self.seed + idx if self.seed is not None else None) + + problem, answer = generate_problem( + rng, + self.config.difficulty, + ) + problem_str = f"Please solve this problem. Reply only with the final hexidecimal value.\n" + problem + + return {"question": problem_str, "answer": answer, "metadata": {"problem": problem}} + + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + """ + Compares the user's answer (converted to Bitwise) with the correct answer. + Instead of requiring exact equality, we allow an error up to one unit in the + least significant digit as determined by the level of precision (max_num_Bitwise_places). + + Returns: + float: 1.0 if the user's answer is within tolerance; otherwise, 0.01. + """ + if answer is None: + return 0.0 + + try: + solved = verify_solution(entry["metadata"]["problem"], answer) + if solved: + return 1.0 + except Exception: + return 0.01 + + return 0.01 + + +# Register the dataset with the factory. +register_dataset("Bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig) diff --git a/tests/test_bitwise_arithmetic.py b/tests/test_bitwise_arithmetic.py new file mode 100644 index 00000000..2a742100 --- /dev/null +++ b/tests/test_bitwise_arithmetic.py @@ -0,0 +1,54 @@ +import pytest + +from reasoning_gym.arithmetic.bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset + + +def test_bitwise_arithmetic(): + """Test basic properties and solution of generated items""" + + # Easy + config = BitwiseArithmeticConfig( + seed=42, + size=2000, + difficulty=1, + ) + dataset = BitwiseArithmeticDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + assert dataset.score_answer(answer="kitty cat", entry=item) == 0.01 + + config = BitwiseArithmeticConfig( + seed=42, + size=100, + difficulty=5, + ) + dataset = BitwiseArithmeticDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + + config = BitwiseArithmeticConfig( + seed=42, + size=100, + difficulty=10, + ) + dataset = BitwiseArithmeticDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0