Fix(reasoning_gym/games/countdown): Resolve SymPy parsing conflict for 10+ input numbers (#514)

* Refactor expression generation and substitution logic

Updated symbol naming and added safe replacement for expressions.

* Add expr_str to return values in countdown.py

Modified return statement to include the modified expression string.

* Implement test for min_numbers exceeding 10

Add test for CountdownDataset with more than 10 numbers

* Remove trailing-whitespace

* Improve readability of CountdownDataset initialization

Refactor CountdownDataset initialization for readability.
This commit is contained in:
SII-Whereby 2025-12-15 19:05:38 +08:00 committed by GitHub
parent de2e89d21d
commit 7d68a6cc70
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 29 additions and 4 deletions

View file

@ -127,7 +127,7 @@ class CountdownDataset(ProceduralDataset):
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)] numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)]
# Create symbols for building expression # Create symbols for building expression
syms = symbols(f"x:{num_terms}") syms = symbols(f"x_{0}:{num_terms}")
# Build random expression # Build random expression
expr = syms[0] expr = syms[0]
@ -162,7 +162,23 @@ class CountdownDataset(ProceduralDataset):
# Fallback to addition for zero # Fallback to addition for zero
expr = expr + syms[i] expr = expr + syms[i]
return expr, numbers, syms # Safely replace symbols with numbers (to avoid name conflicts)
expr_str = str(expr)
# Create a list of replacements: [(symbol_name, number_string), ...]
replacements = []
for i, sym in enumerate(syms):
sym_name = str(sym)
replacements.append((sym_name, str(numbers[i])))
# Sort by symbol name length in descending order (replace longer names first)
replacements.sort(key=lambda x: len(x[0]), reverse=True)
# Perform the safe replacement
for sym_name, num_str in replacements:
expr_str = expr_str.replace(sym_name, num_str)
return expr, numbers, syms, expr_str
def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]: def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]:
"""Generate a valid expression and its result """Generate a valid expression and its result
@ -175,14 +191,13 @@ class CountdownDataset(ProceduralDataset):
max_attempts = 100 max_attempts = 100
for attempt in range(max_attempts): for attempt in range(max_attempts):
try: try:
expr, numbers, syms = self._generate_candidate_expression(rng, num_terms) expr, numbers, syms, expr_str = self._generate_candidate_expression(rng, num_terms)
# Substitute actual numbers to get target # Substitute actual numbers to get target
subs = {sym: num for sym, num in zip(syms, numbers)} subs = {sym: num for sym, num in zip(syms, numbers)}
target = int(expr.subs(subs)) target = int(expr.subs(subs))
# Convert to string expression # Convert to string expression
expr_str = str(expr)
for i, sym in enumerate(syms): for i, sym in enumerate(syms):
expr_str = expr_str.replace(str(sym), str(numbers[i])) expr_str = expr_str.replace(str(sym), str(numbers[i]))

View file

@ -114,6 +114,16 @@ def test_edge_cases_2():
assert dataset.score_answer(answer=answer, entry=item) != 1.0 assert dataset.score_answer(answer=answer, entry=item) != 1.0
def test_countdown_more_numbers():
"""Test when min_numbers exceed 10"""
dataset = CountdownDataset(
CountdownConfig(min_numbers=11, max_numbers=11, shuffle=False, size=5, seed=42)
) # Set 11 engaged numbers for testing
for item in dataset:
assert item["metadata"]["target"] == int(eval(item["metadata"]["expression"]))
def test_countdown_game_randomization(): def test_countdown_game_randomization():
"""Test number randomization configuration""" """Test number randomization configuration"""
config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing