mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Fix(reasoning_gym/games/countdown): Resolve SymPy parsing conflict for 10+ input numbers (#514)
* Refactor expression generation and substitution logic: updated symbol naming and added safe replacement for expressions.
* Add expr_str to return values in countdown.py: modified the return statement to include the modified expression string.
* Implement test for min_numbers exceeding 10: add a test for CountdownDataset with more than 10 numbers.
* Remove trailing whitespace.
* Improve readability of CountdownDataset initialization: refactor CountdownDataset initialization for readability.
This commit is contained in:
parent
de2e89d21d
commit
7d68a6cc70
2 changed files with 29 additions and 4 deletions
|
|
@ -127,7 +127,7 @@ class CountdownDataset(ProceduralDataset):
|
|||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)]
|
||||
|
||||
# Create symbols for building expression
|
||||
syms = symbols(f"x:{num_terms}")
|
||||
syms = symbols(f"x_{0}:{num_terms}")
|
||||
|
||||
# Build random expression
|
||||
expr = syms[0]
|
||||
|
|
@ -162,7 +162,23 @@ class CountdownDataset(ProceduralDataset):
|
|||
# Fallback to addition for zero
|
||||
expr = expr + syms[i]
|
||||
|
||||
return expr, numbers, syms
|
||||
# Safely replace symbols with numbers (to avoid name conflicts)
|
||||
expr_str = str(expr)
|
||||
|
||||
# Create a list of replacements: [(symbol_name, number_string), ...]
|
||||
replacements = []
|
||||
for i, sym in enumerate(syms):
|
||||
sym_name = str(sym)
|
||||
replacements.append((sym_name, str(numbers[i])))
|
||||
|
||||
# Sort by symbol name length in descending order (replace longer names first)
|
||||
replacements.sort(key=lambda x: len(x[0]), reverse=True)
|
||||
|
||||
# Perform the safe replacement
|
||||
for sym_name, num_str in replacements:
|
||||
expr_str = expr_str.replace(sym_name, num_str)
|
||||
|
||||
return expr, numbers, syms, expr_str
|
||||
|
||||
def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]:
|
||||
"""Generate a valid expression and its result
|
||||
|
|
@ -175,14 +191,13 @@ class CountdownDataset(ProceduralDataset):
|
|||
max_attempts = 100
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
expr, numbers, syms = self._generate_candidate_expression(rng, num_terms)
|
||||
expr, numbers, syms, expr_str = self._generate_candidate_expression(rng, num_terms)
|
||||
|
||||
# Substitute actual numbers to get target
|
||||
subs = {sym: num for sym, num in zip(syms, numbers)}
|
||||
target = int(expr.subs(subs))
|
||||
|
||||
# Convert to string expression
|
||||
expr_str = str(expr)
|
||||
for i, sym in enumerate(syms):
|
||||
expr_str = expr_str.replace(str(sym), str(numbers[i]))
|
||||
|
||||
|
|
|
|||
|
|
@ -114,6 +114,16 @@ def test_edge_cases_2():
|
|||
assert dataset.score_answer(answer=answer, entry=item) != 1.0
|
||||
|
||||
|
||||
def test_countdown_more_numbers():
    """Check that target and expression agree when more than 10 numbers are used.

    Every generated item must carry an expression string that evaluates to
    the target recorded alongside it in the item's metadata.
    """
    # Pin both bounds to 11 so every item is built from exactly 11 numbers
    config = CountdownConfig(min_numbers=11, max_numbers=11, shuffle=False, size=5, seed=42)
    dataset = CountdownDataset(config)

    for entry in dataset:
        metadata = entry["metadata"]
        expected_target = metadata["target"]
        # The expression string is produced by the dataset itself, not by external input
        evaluated = int(eval(metadata["expression"]))
        assert expected_target == evaluated
|
||||
|
||||
|
||||
def test_countdown_game_randomization():
|
||||
"""Test number randomization configuration"""
|
||||
config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue