diff --git a/eval/eval.py b/eval/eval.py index 7ced20f0..bbfc585d 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -54,6 +54,7 @@ class AsyncOpenRouterEvaluator: "question": entry["question"], "expected_answer": entry["answer"], "model_answer": answer, + "full_model_response": response, "score": score, "metadata": entry["metadata"], } diff --git a/eval/r1/eval.py b/eval/r1/eval.py index 6f7a4a73..202646ac 100644 --- a/eval/r1/eval.py +++ b/eval/r1/eval.py @@ -104,6 +104,7 @@ class OpenRouterEvaluator: "question": entry["question"], "expected_answer": str(entry["answer"]), "model_answer": model_answer, + "full_model_response": response, "score": score, "metadata": str(entry["metadata"]), } diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 9f525011..ad9cccd3 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -8,6 +8,7 @@ Algorithmic tasks for training reasoning capabilities: from .ab import ABConfig, ABDataset from .base_conversion import BaseConversionConfig, BaseConversionDataset +from .binary_alternation import BinaryAlternationConfig, BinaryAlternationDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset from .count_primes import CountPrimesConfig, CountPrimesDataset @@ -105,4 +106,6 @@ __all__ = [ "RottenOrangesDataset", "JugsConfig", "JugsDataset", + "BinaryAlternationConfig", + "BinaryAlternationDataset", ] diff --git a/reasoning_gym/algorithmic/binary_alternation.py b/reasoning_gym/algorithmic/binary_alternation.py new file mode 100644 index 00000000..ca204c6d --- /dev/null +++ b/reasoning_gym/algorithmic/binary_alternation.py @@ -0,0 +1,114 @@ +"""Minimum number of swaps to make a binary string alternating + +https://leetcode.com/problems/minimum-number-of-swaps-to-make-the-binary-string-alternating/description/ +""" + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """Given a binary string, return the minimum number of character swaps to make it alternating, or -1 if it is impossible. + +The string is called alternating if no two adjacent characters are equal. For example, the strings "010" and "1010" are alternating, while the string "0100" is not. + +Any two characters may be swapped, even if they are not adjacent. + +Example: +- Input: Determine the minimum number of swaps to make the following binary string alternating: 111000 +- Output: 1 + +Now, determine the minimum number of swaps to make the following binary string alternating: {string} +""" + + +@dataclass +class BinaryAlternationConfig: + """Configuration for Count Bits dataset generation""" + + min_n: int = 10 # Minimum number of bits in the binary string + max_n: int = 30 # Maximum number of bits in the binary string + p_solvable: float = 0.8 # Probability of generating a solvable sample + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 1 <= self.min_n, "Minimum number of bits must be at least 1" + assert self.min_n <= self.max_n, "Minimum number of bits must be <= maximum number of bits" + assert 0 <= self.p_solvable <= 1, "Probability of generating a 1 must be in [0, 1]" + + +class BinaryAlternationDataset(ProceduralDataset): + """Generates Binary Alternation exercises with configurable difficulty""" + + def __init__(self, config: BinaryAlternationConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def _get_binary_string(self, rng: Random, solvable: bool) -> str: + n = rng.randint(self.config.min_n, self.config.max_n) + ones, zeros = n // 2, n // 2 + + # Check if we need to add an extra bit + if n % 2 == 1: + if rng.random() < 0.5: + ones += 1 + else: + zeros += 1 + + if not solvable: + if ones > zeros: + ones += 1 + elif ones < zeros: + zeros += 1 + else: + # Randomly add 2 bits of the same type + if rng.random() < 0.5: + ones += 2 + else: + zeros += 2 + + # Generate the string + string = ["1"] * ones + ["0"] * zeros + rng.shuffle(string) + return "".join(string) + + def _get_answer(self, string: str) -> int: + """Calculate the minimum number of swaps to make the string alternating""" + + def get_num_swaps(expected): + incorrect = 0 + for c in string: + if c != expected: + incorrect += 1 + expected = "1" if expected == "0" else "0" + return incorrect // 2 # number of swaps is half of incorrect positions + + ones, zeros = string.count("1"), string.count("0") + if abs(ones - zeros) > 1: + return -1 # impossible to make alternating + if ones > zeros: + return get_num_swaps("1") + elif ones < zeros: + return get_num_swaps("0") + else: + return min(get_num_swaps("0"), get_num_swaps("1")) + + def __getitem__(self, idx: int) -> dict: + """Generate a single Count Bits question""" + rng = Random(self.seed + idx) + + solvable = rng.random() < self.config.p_solvable + string = self._get_binary_string(rng, solvable) + answer = self._get_answer(string) + + return { + "question": QUESTION_TEMPLATE.format(string=string), + "answer": str(answer), + "metadata": {"string": string, "solution": answer, "solvable": solvable}, + } + + +register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig) diff --git a/tests/test_binary_alternation.py b/tests/test_binary_alternation.py new file mode 100644 index 00000000..81a581e6 --- /dev/null +++ b/tests/test_binary_alternation.py @@ -0,0 +1,104 @@ +"""Tests for Binary Alternation questions generation""" + +import pytest + +from reasoning_gym.algorithmic.binary_alternation import BinaryAlternationConfig, BinaryAlternationDataset + + +def test_binary_alternation_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = BinaryAlternationConfig(max_n=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = BinaryAlternationConfig(max_n=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = BinaryAlternationConfig(min_n=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = BinaryAlternationConfig(min_n=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = BinaryAlternationConfig(p_solvable=-0.01) # < 0 not allowed + config.validate() + + with pytest.raises(AssertionError): + config = BinaryAlternationConfig(p_solvable=1.01) # > 0 not allowed + config.validate() + + +def test_binary_alternation_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = BinaryAlternationConfig(seed=42, size=10) + dataset1 = BinaryAlternationDataset(config) + dataset2 = BinaryAlternationDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_binary_alternation_dataset_items(): + """Test basic properties of generated items""" + config = BinaryAlternationConfig(max_n=10, size=10, seed=42) + dataset = BinaryAlternationDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "string" in item["metadata"] + assert "solution" in item["metadata"] + assert "solvable" in item["metadata"] + + solution = item["metadata"]["solution"] + string = item["metadata"]["string"] + solvable = item["metadata"]["solvable"] + + # Verify values + assert set(string) <= {"0", "1"} + if solution == -1: + assert not solvable + assert abs(string.count("1") - string.count("0")) > 1 + else: + assert solvable + assert abs(string.count("1") - string.count("0")) <= 1 + + +def test_binary_alternation_dataset_iteration(): + """Test that iteration respects dataset size""" + config = BinaryAlternationConfig(size=5, seed=42) + dataset = BinaryAlternationDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_binary_alternation_answer(): + """Verify the number of 1 bits in the binary representation of a number""" + config = BinaryAlternationConfig(size=5, seed=42) + dataset = BinaryAlternationDataset(config) + + # Impossible + string = "1110" + assert dataset._get_answer(string) == -1 + + # Already alternating + string = "10101" + assert dataset._get_answer(string) == 0 + + # One shot example + string = "111000" + assert dataset._get_answer(string) == 1