diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index 480586b9..cf083ba4 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -5,17 +5,17 @@ Game tasks for training reasoning capabilities: - Strategy games """ -from .countdown_game import CountdownGameConfig, CountdownGameDataset +from .countdown import CountdownConfig, CountdownDataset from .maze import MazeConfig, MazeDataset from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset from .sudoku import SudokuConfig, SudokuDataset __all__ = [ - "CountdownGameConfig", - "CountdownGameDataset", + "CountdownConfig", + "CountdownDataset", "MiniSudokuConfig", "MiniSudokuDataset", - "SudokuConfig", + "SudokuConfig", "SudokuDataset", "MazeConfig", "MazeDataset", diff --git a/reasoning_gym/games/countdown_game.py b/reasoning_gym/games/countdown.py similarity index 85% rename from reasoning_gym/games/countdown_game.py rename to reasoning_gym/games/countdown.py index 7411ba03..edec6c38 100644 --- a/reasoning_gym/games/countdown_game.py +++ b/reasoning_gym/games/countdown.py @@ -9,7 +9,7 @@ from ..factory import ProceduralDataset, register_dataset @dataclass -class CountdownGameConfig: +class CountdownConfig: """Configuration for Countdown Number Game task generation""" min_numbers: int = 4 # Minimum numbers to provide @@ -19,7 +19,7 @@ class CountdownGameConfig: min_target: int = 100 # Minimum target value max_target: int = 999 # Maximum target value operators: tuple = ("+", "-", "*", "/") # Allowed operators - randomize_numbers: bool = True # Whether to randomize the order of source numbers + shuffle: bool = True # Whether to shuffle the order of source numbers seed: Optional[int] = None size: int = 500 @@ -35,10 +35,10 @@ class CountdownGameConfig: assert all(op in ("+", "-", "*", "/") for op in self.operators), "invalid operator specified" -class CountdownGameDataset(ProceduralDataset): +class CountdownDataset(ProceduralDataset): """Generates Countdown Number Game tasks""" - def __init__(self, config: CountdownGameConfig): + def __init__(self, config: CountdownConfig): self._prompt_templates = [ "Using the numbers {numbers}, create an expression that equals {target}.\nYou can only use each number once.", "Find a way to make {target} using some or all of these numbers: {numbers}.\nEach number can only be used once.", @@ -56,21 +56,18 @@ class CountdownGameDataset(ProceduralDataset): - metadata: dict with generation parameters """ rng = Random(self.seed + idx) - + # Generate a valid expression and its result expression, numbers, target = self._generate_expression(rng) - + # Optionally randomize the order of numbers - if self.config.randomize_numbers: + if self.config.shuffle: rng.shuffle(numbers) - + numbers_str = ", ".join(map(str, numbers)) - + return { - "question": rng.choice(self._prompt_templates).format( - numbers=numbers_str, - target=target - ), + "question": rng.choice(self._prompt_templates).format(numbers=numbers_str, target=target), "answer": expression, "metadata": { "numbers": numbers, @@ -81,23 +78,22 @@ class CountdownGameDataset(ProceduralDataset): def _generate_expression(self, rng: Random) -> Tuple[str, List[int], int]: """Generate a valid expression and its result - + Returns: Tuple of (expression string, list of numbers used, target value) """ num_terms = rng.randint(self.config.min_numbers, self.config.max_numbers) - + # Generate random numbers - numbers = [rng.randint(self.config.min_value, self.config.max_value) - for _ in range(num_terms)] - + numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)] + # Create symbols for building expression syms = symbols(f"x:{num_terms}") - + # Build random expression expr = syms[0] used_nums = [numbers[0]] - + for i in range(1, num_terms): op = rng.choice(self.config.operators) if op == "+": @@ -116,24 +112,24 @@ class CountdownGameDataset(ProceduralDataset): numbers[i] = div expr = expr / syms[i] else: - # Fallback to multiplication - expr = expr * syms[i] + # Fallback to subtraction + expr = expr - syms[i] else: - # Fallback to multiplication - expr = expr * syms[i] + # Fallback to addition + expr = expr + syms[i] used_nums.append(numbers[i]) - + # Substitute actual numbers to get target subs = {sym: num for sym, num in zip(syms, numbers)} target = int(expr.subs(subs)) - + # Convert to string expression expr_str = str(expr) for i, sym in enumerate(syms): expr_str = expr_str.replace(str(sym), str(numbers[i])) - + return expr_str, numbers, target # Register the dataset -register_dataset("countdown_game", CountdownGameDataset, CountdownGameConfig) +register_dataset("countdown", CountdownDataset, CountdownConfig) diff --git a/tests/test_countdown_game.py b/tests/test_countdown.py similarity index 63% rename from tests/test_countdown_game.py rename to tests/test_countdown.py index a7ae5ed1..d015d143 100644 --- a/tests/test_countdown_game.py +++ b/tests/test_countdown.py @@ -1,28 +1,28 @@ import pytest -from reasoning_gym.games.countdown_game import CountdownGameConfig, CountdownGameDataset +from reasoning_gym.games.countdown import CountdownConfig, CountdownDataset def test_countdown_game_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = CountdownGameConfig(min_numbers=1) # Too few numbers + config = CountdownConfig(min_numbers=1) # Too few numbers config.validate() with pytest.raises(AssertionError): - config = CountdownGameConfig(min_numbers=4, max_numbers=3) # Invalid range + config = CountdownConfig(min_numbers=4, max_numbers=3) # Invalid range config.validate() with pytest.raises(AssertionError): - config = CountdownGameConfig(operators=["^"]) # Invalid operator + config = CountdownConfig(operators=["^"]) # Invalid operator config.validate() def test_countdown_game_deterministic(): """Test that dataset generates same items with same seed""" - config = CountdownGameConfig(seed=42, size=10) - dataset1 = CountdownGameDataset(config) - dataset2 = CountdownGameDataset(config) + config = CountdownConfig(seed=42, size=10) + dataset1 = CountdownDataset(config) + dataset2 = CountdownDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -30,7 +30,7 @@ def test_countdown_game_deterministic(): def test_countdown_game_items(): """Test basic properties of generated items""" - config = CountdownGameConfig( + config = CountdownConfig( min_numbers=3, max_numbers=5, min_value=1, @@ -38,50 +38,48 @@ def test_countdown_game_items(): min_target=10, max_target=100, size=50, - seed=42 + seed=42, ) - dataset = CountdownGameDataset(config) + dataset = CountdownDataset(config) for item in dataset: assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata contains required fields assert "numbers" in item["metadata"] assert "target" in item["metadata"] assert "expression" in item["metadata"] - + # Verify number of source numbers is within config range assert config.min_numbers <= len(item["metadata"]["numbers"]) <= config.max_numbers - + # Verify target is within config range assert config.min_target <= item["metadata"]["target"] <= config.max_target - + # Verify all numbers are within config range assert all(config.min_value <= n <= config.max_value for n in item["metadata"]["numbers"]) def test_countdown_game_randomization(): """Test number randomization configuration""" - config = CountdownGameConfig( - min_numbers=4, - max_numbers=4, # Fixed size for testing - randomize_numbers=False, - size=10, - seed=42 - ) - + config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing + # Without randomization, numbers should appear in same order - dataset = CountdownGameDataset(config) + dataset = CountdownDataset(config) first_item = dataset[0] - expr_nums = [int(n) for n in first_item["metadata"]["expression"].replace("(","").replace(")","").split(" ") if n.isdigit()] + expr_nums = [ + int(n) for n in first_item["metadata"]["expression"].replace("(", "").replace(")", "").split(" ") if n.isdigit() + ] assert expr_nums == first_item["metadata"]["numbers"] - + # With randomization, numbers might appear in different order - config.randomize_numbers = True - dataset = CountdownGameDataset(config) + config.shuffle = True + dataset = CountdownDataset(config) first_item = dataset[0] - expr_nums = [int(n) for n in first_item["metadata"]["expression"].replace("(","").replace(")","").split(" ") if n.isdigit()] + expr_nums = [ + int(n) for n in first_item["metadata"]["expression"].replace("(", "").replace(")", "").split(" ") if n.isdigit() + ] assert sorted(expr_nums) == sorted(first_item["metadata"]["numbers"])