diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index 44a92ec6..480586b9 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -5,14 +5,17 @@ Game tasks for training reasoning capabilities: - Strategy games """ +from .countdown_game import CountdownGameConfig, CountdownGameDataset from .maze import MazeConfig, MazeDataset from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset from .sudoku import SudokuConfig, SudokuDataset __all__ = [ + "CountdownGameConfig", + "CountdownGameDataset", "MiniSudokuConfig", "MiniSudokuDataset", - "SudokuConfig", + "SudokuConfig", "SudokuDataset", "MazeConfig", "MazeDataset", diff --git a/reasoning_gym/games/countdown_game.py b/reasoning_gym/games/countdown_game.py new file mode 100644 index 00000000..7411ba03 --- /dev/null +++ b/reasoning_gym/games/countdown_game.py @@ -0,0 +1,139 @@ +from dataclasses import dataclass +from random import Random +from typing import List, Optional, Tuple + +import sympy +from sympy import Symbol, symbols + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class CountdownGameConfig: + """Configuration for Countdown Number Game task generation""" + + min_numbers: int = 4 # Minimum numbers to provide + max_numbers: int = 6 # Maximum numbers to provide + min_value: int = 1 # Minimum value for source numbers + max_value: int = 100 # Maximum value for source numbers + min_target: int = 100 # Minimum target value + max_target: int = 999 # Maximum target value + operators: tuple = ("+", "-", "*", "/") # Allowed operators + randomize_numbers: bool = True # Whether to randomize the order of source numbers + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.min_numbers > 1, "min_numbers must be greater than 1" + assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers" + assert self.min_value > 0, "min_value must be positive" + assert self.max_value >= self.min_value, "max_value must be >= min_value" + assert self.min_target > 0, "min_target must be positive" + assert self.max_target >= self.min_target, "max_target must be >= min_target" + assert len(self.operators) > 0, "must specify at least one operator" + assert all(op in ("+", "-", "*", "/") for op in self.operators), "invalid operator specified" + + +class CountdownGameDataset(ProceduralDataset): + """Generates Countdown Number Game tasks""" + + def __init__(self, config: CountdownGameConfig): + self._prompt_templates = [ + "Using the numbers {numbers}, create an expression that equals {target}.\nYou can only use each number once.", + "Find a way to make {target} using some or all of these numbers: {numbers}.\nEach number can only be used once.", + "Calculate {target} using the numbers {numbers}.\nEach number may be used at most once.", + ] + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single Countdown Game task + + Returns: + dict with keys: + - question: str, the task description with numbers and target + - answer: str, one possible solution expression + - metadata: dict with generation parameters + """ + rng = Random(self.seed + idx) + + # Generate a valid expression and its result + expression, numbers, target = self._generate_expression(rng) + + # Optionally randomize the order of numbers + if self.config.randomize_numbers: + rng.shuffle(numbers) + + numbers_str = ", ".join(map(str, numbers)) + + return { + "question": rng.choice(self._prompt_templates).format( + numbers=numbers_str, + target=target + ), + "answer": expression, + "metadata": { + "numbers": numbers, + "target": target, + "expression": expression, + }, + } + + def _generate_expression(self, rng: Random) -> Tuple[str, List[int], int]: + """Generate a valid expression and its result + + Returns: + Tuple of (expression string, list of numbers used, target value) + """ + num_terms = rng.randint(self.config.min_numbers, self.config.max_numbers) + + # Generate random numbers + numbers = [rng.randint(self.config.min_value, self.config.max_value) + for _ in range(num_terms)] + + # Create symbols for building expression + syms = symbols(f"x:{num_terms}") + + # Build random expression + expr = syms[0] + used_nums = [numbers[0]] + + for i in range(1, num_terms): + op = rng.choice(self.config.operators) + if op == "+": + expr = expr + syms[i] + elif op == "-": + expr = expr - syms[i] + elif op == "*": + expr = expr * syms[i] + else: # division + # Ensure division results in integer + if numbers[i] != 0: # Avoid division by zero + # Try to find a number that divides evenly + remaining = [n for n in numbers[i:] if n != 0] + if remaining: + div = rng.choice(remaining) + numbers[i] = div + expr = expr / syms[i] + else: + # Fallback to multiplication + expr = expr * syms[i] + else: + # Fallback to multiplication + expr = expr * syms[i] + used_nums.append(numbers[i]) + + # Substitute actual numbers to get target + subs = {sym: num for sym, num in zip(syms, numbers)} + target = int(expr.subs(subs)) + + # Convert to string expression + expr_str = str(expr) + for i, sym in enumerate(syms): + expr_str = expr_str.replace(str(sym), str(numbers[i])) + + return expr_str, numbers, target + + +# Register the dataset +register_dataset("countdown_game", CountdownGameDataset, CountdownGameConfig) diff --git a/tests/test_countdown_game.py b/tests/test_countdown_game.py new file mode 100644 index 00000000..a7ae5ed1 --- /dev/null +++ b/tests/test_countdown_game.py @@ -0,0 +1,87 @@ +import pytest + +from reasoning_gym.games.countdown_game import CountdownGameConfig, CountdownGameDataset + + +def test_countdown_game_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = CountdownGameConfig(min_numbers=1) # Too few numbers + config.validate() + + with pytest.raises(AssertionError): + config = CountdownGameConfig(min_numbers=4, max_numbers=3) # Invalid range + config.validate() + + with pytest.raises(AssertionError): + config = CountdownGameConfig(operators=["^"]) # Invalid operator + config.validate() + + +def test_countdown_game_deterministic(): + """Test that dataset generates same items with same seed""" + config = CountdownGameConfig(seed=42, size=10) + dataset1 = CountdownGameDataset(config) + dataset2 = CountdownGameDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_countdown_game_items(): + """Test basic properties of generated items""" + config = CountdownGameConfig( + min_numbers=3, + max_numbers=5, + min_value=1, + max_value=20, # Small numbers for testing + min_target=10, + max_target=100, + size=50, + seed=42 + ) + dataset = CountdownGameDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata contains required fields + assert "numbers" in item["metadata"] + assert "target" in item["metadata"] + assert "expression" in item["metadata"] + + # Verify number of source numbers is within config range + assert config.min_numbers <= len(item["metadata"]["numbers"]) <= config.max_numbers + + # Verify target is within config range + assert config.min_target <= item["metadata"]["target"] <= config.max_target + + # Verify all numbers are within config range + assert all(config.min_value <= n <= config.max_value for n in item["metadata"]["numbers"]) + + +def test_countdown_game_randomization(): + """Test number randomization configuration""" + config = CountdownGameConfig( + min_numbers=4, + max_numbers=4, # Fixed size for testing + randomize_numbers=False, + size=10, + seed=42 + ) + + # Without randomization, numbers should appear in same order + dataset = CountdownGameDataset(config) + first_item = dataset[0] + expr_nums = [int(n) for n in first_item["metadata"]["expression"].replace("(","").replace(")","").split(" ") if n.isdigit()] + assert expr_nums == first_item["metadata"]["numbers"] + + # With randomization, numbers might appear in different order + config.randomize_numbers = True + dataset = CountdownGameDataset(config) + first_item = dataset[0] + expr_nums = [int(n) for n in first_item["metadata"]["expression"].replace("(","").replace(")","").split(" ") if n.isdigit()] + assert sorted(expr_nums) == sorted(first_item["metadata"]["numbers"])