diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py
index 495a79c5..6087c289 100644
--- a/reasoning_gym/arithmetic/__init__.py
+++ b/reasoning_gym/arithmetic/__init__.py
@@ -12,6 +12,7 @@ from .gcd import GCDConfig, GCDDataset
 from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
 from .lcm import LCMConfig, LCMDataset
 from .leg_counting import LegCountingConfig, LegCountingDataset
+from .number_format import NumberFormatConfig, NumberFormatDataset
 from .power_function import PowerFunctionConfig, PowerFunctionDataset
 from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
 from .products import ProductsConfig, ProductsDataset
@@ -46,4 +47,6 @@ __all__ = [
     "CountBitsDataset",
     "DiceConfig",
     "DiceDataset",
+    "NumberFormatConfig",
+    "NumberFormatDataset",
 ]
diff --git a/reasoning_gym/arithmetic/number_format.py b/reasoning_gym/arithmetic/number_format.py
new file mode 100644
index 00000000..e03d2bdc
--- /dev/null
+++ b/reasoning_gym/arithmetic/number_format.py
@@ -0,0 +1,106 @@
+"""Choose largest number out of several represented in various formats."""
+
+from dataclasses import dataclass
+from random import Random
+from typing import Any, Dict, Optional
+
+from ..factory import ProceduralDataset, register_dataset
+
+QUESTION_TEMPLATE = """Your task is to pick the largest/smallest number out of several options.
+
+Example
+- Input: Pick the largest number of the following candidates: 857575.23 8.975554e+05 887,555.62
+- Output: 8.975554e+05
+- Explanation:
+ - Sorting the numbers written in various notations we get: 857575.23 < 887,555.62 < 8.975554e+05
+ - Therefore, the largest number is 8.975554e+05
+
+Now, pick the {size} number of the following candidates: {numbers}
+"""
+
+
+@dataclass
+class NumberFormatConfig:
+    """Configuration for Number Format dataset generation"""
+
+    max_num_candidates: int = 5  # Maximum number of candidates
+    min_n: float = 1_000  # Lower bound for the numbers
+    max_n: float = 1_000_000_000  # Upper bound for the numbers
+    max_delta: int = 1_000  # Max +/- offset of each extra candidate from the base number
+
+    size: int = 500  # Virtual dataset size
+    seed: Optional[int] = None
+
+    def validate(self):
+        """Validate configuration parameters"""
+        assert 2 <= self.max_num_candidates, "max_num_candidates must be at least 2"
+        assert 1 <= self.min_n, "min_n must be at least 1"
+        assert self.min_n < self.max_n, "min_n must be less than max_n"
+        assert 1 <= self.max_delta, "max_delta must be at least 1"
+
+
+class NumberFormatDataset(ProceduralDataset):
+    """Generates Number Format exercises with configurable difficulty"""
+
+    def __init__(self, config: NumberFormatConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def _get_candidates(self, rng: Random, num_candidates: int) -> list[float]:
+        """Generate a base number plus close-by candidates (within +/- max_delta)"""
+        base = round(rng.uniform(self.config.min_n, self.config.max_n), 3)
+        candidates = [base]
+        for _ in range(num_candidates - 1):
+            delta = round(rng.uniform(-self.config.max_delta, self.config.max_delta), 3)
+            candidates.append(base + delta)
+        return candidates
+
+    def _transform_candidates(self, rng: Random, candidates: list[float]) -> list[str]:
+        """Randomly apply different number formats to the candidates"""
+        output = []
+        for candidate in candidates:
+            format_type = rng.choice(["standard", "english", "scientific"])
+            if format_type == "standard":
+                output.append(f"{candidate:f}")
+            elif format_type == "english":
+                output.append(f"{candidate:,}")
+            elif format_type == "scientific":
+                output.append(f"{candidate:.15e}")
+        return output
+
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
+        """Score an answer: 1.0 if within 1e-2 of the solution, 0.01 if parseable but wrong, else 0.0"""
+        oracle_answer = entry["metadata"]["solution"]
+        if answer is not None and len(answer) > 0:
+            try:
+                # Strip thousands separators so English-formatted answers parse too
+                parsed = float(answer.strip().replace(",", ""))
+                if abs(parsed - oracle_answer) < 1e-2:
+                    return 1.0
+                return 0.01
+            except (ValueError, AttributeError):
+                return 0.0
+        return 0.0
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single Number Format question"""
+        rng = Random(self.seed + idx)
+
+        num_candidates = rng.randint(2, self.config.max_num_candidates)
+        candidates = self._get_candidates(rng, num_candidates)
+        formatted_candidates = self._transform_candidates(rng, candidates)
+
+        size = rng.choice(["largest", "smallest"])
+        answer = max(candidates) if size == "largest" else min(candidates)
+
+        return {
+            "question": QUESTION_TEMPLATE.format(numbers=" ".join(formatted_candidates), size=size),
+            "answer": str(answer),
+            "metadata": {
+                "candidates": candidates,
+                "solution": answer,
+                "formatted_candidates": formatted_candidates,
+                "size": size,
+            },
+        }
+
+
+register_dataset("number_format", NumberFormatDataset, NumberFormatConfig)
diff --git a/tests/test_number_format.py b/tests/test_number_format.py
new file mode 100644
index 00000000..882f38aa
--- /dev/null
+++ b/tests/test_number_format.py
@@ -0,0 +1,121 @@
+"""Tests for Number Format questions generation"""
+
+import pytest
+
+from reasoning_gym.arithmetic.number_format import NumberFormatConfig, NumberFormatDataset
+
+
+def test_number_format_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+    with pytest.raises(AssertionError):
+        config = NumberFormatConfig(max_num_candidates=0)  # Zero not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = NumberFormatConfig(max_num_candidates=1)  # One not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = NumberFormatConfig(min_n=-1)  # Negative not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = NumberFormatConfig(min_n=0)  # Zero not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = NumberFormatConfig(min_n=10, max_n=5)  # min > max
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = NumberFormatConfig(max_delta=-1)  # Negative not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = NumberFormatConfig(max_delta=0)  # Zero not allowed
+        config.validate()
+
+
+def test_number_format_dataset_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = NumberFormatConfig(seed=42, size=10)
+    dataset1 = NumberFormatDataset(config)
+    dataset2 = NumberFormatDataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_number_format_dataset_items():
+    """Test basic properties of generated items"""
+    config = NumberFormatConfig(min_n=1_000, max_n=10_000, max_delta=1, size=10, seed=42)
+    dataset = NumberFormatDataset(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        # Check item structure
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata
+        assert "candidates" in item["metadata"]
+        assert "formatted_candidates" in item["metadata"]
+        assert "size" in item["metadata"]
+        assert "solution" in item["metadata"]
+
+        candidates = item["metadata"]["candidates"]
+        formatted_candidates = item["metadata"]["formatted_candidates"]
+        size = item["metadata"]["size"]
+        solution = item["metadata"]["solution"]
+
+        # Verify values
+        assert len(candidates) >= 2
+        assert all(999 <= c <= 10_001 for c in candidates)  # boundaries +- delta
+        assert len(candidates) == len(formatted_candidates)
+        assert size in ["largest", "smallest"]
+        assert solution in candidates
+
+
+def test_number_format_dataset_iteration():
+    """Test that iteration respects dataset size"""
+    config = NumberFormatConfig(size=5, seed=42)
+    dataset = NumberFormatDataset(config)
+
+    items = list(dataset)
+    assert len(items) == config.size
+
+    # Test multiple iterations yield same items
+    assert items == list(dataset)
+
+
+def test_number_format_answer():
+    """Verify the solution scoring"""
+    config = NumberFormatConfig(size=5, seed=42)
+    dataset = NumberFormatDataset(config)
+
+    entry = {"metadata": {"solution": 54245.32}}
+
+    # Correct answer (plain)
+    model_answer = "54245.32"
+    assert dataset.score_answer(model_answer, entry) == 1.0
+
+    # Correct answer (English)
+    model_answer = "54,245.32"
+    assert dataset.score_answer(model_answer, entry) == 1.0
+
+    # Correct answer (scientific)
+    assert dataset.score_answer("5.424532e+04", entry) == 1.0
+
+    # Incorrect answer (diff larger than 1e-2)
+    model_answer = "54245.9"
+    assert dataset.score_answer(model_answer, entry) == 0.01
+
+    # Answer is null
+    model_answer = None
+    assert dataset.score_answer(model_answer, entry) == 0.0
+
+    # Answer is unparsable
+    model_answer = "test"
+    assert dataset.score_answer(model_answer, entry) == 0.0