mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
Fix(reasoning_gym/games/countdown): Resolve SymPy parsing conflict for 10+ input numbers (#514)
* Refactor expression generation and substitution logic Updated symbol naming and added safe replacement for expressions. * Add expr_str to return values in countdown.py Modified return statement to include the modified expression string. * Implement test for min_numbers exceeding 10 Add test for CountdownDataset with more than 10 numbers * Remove trailing-whitespace * Improve readability of CountdownDataset initialization Refactor CountdownDataset initialization for readability.
This commit is contained in:
parent
de2e89d21d
commit
7d68a6cc70
2 changed files with 29 additions and 4 deletions
|
|
@ -127,7 +127,7 @@ class CountdownDataset(ProceduralDataset):
|
|||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)]
|
||||
|
||||
# Create symbols for building expression
|
||||
syms = symbols(f"x:{num_terms}")
|
||||
syms = symbols(f"x_{0}:{num_terms}")
|
||||
|
||||
# Build random expression
|
||||
expr = syms[0]
|
||||
|
|
@ -162,7 +162,23 @@ class CountdownDataset(ProceduralDataset):
|
|||
# Fallback to addition for zero
|
||||
expr = expr + syms[i]
|
||||
|
||||
return expr, numbers, syms
|
||||
# Safely replace symbols with numbers (to avoid name conflicts)
|
||||
expr_str = str(expr)
|
||||
|
||||
# Create a list of replacements: [(symbol_name, number_string), ...]
|
||||
replacements = []
|
||||
for i, sym in enumerate(syms):
|
||||
sym_name = str(sym)
|
||||
replacements.append((sym_name, str(numbers[i])))
|
||||
|
||||
# Sort by symbol name length in descending order (replace longer names first)
|
||||
replacements.sort(key=lambda x: len(x[0]), reverse=True)
|
||||
|
||||
# Perform the safe replacement
|
||||
for sym_name, num_str in replacements:
|
||||
expr_str = expr_str.replace(sym_name, num_str)
|
||||
|
||||
return expr, numbers, syms, expr_str
|
||||
|
||||
def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]:
|
||||
"""Generate a valid expression and its result
|
||||
|
|
@ -175,14 +191,13 @@ class CountdownDataset(ProceduralDataset):
|
|||
max_attempts = 100
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
expr, numbers, syms = self._generate_candidate_expression(rng, num_terms)
|
||||
expr, numbers, syms, expr_str = self._generate_candidate_expression(rng, num_terms)
|
||||
|
||||
# Substitute actual numbers to get target
|
||||
subs = {sym: num for sym, num in zip(syms, numbers)}
|
||||
target = int(expr.subs(subs))
|
||||
|
||||
# Convert to string expression
|
||||
expr_str = str(expr)
|
||||
for i, sym in enumerate(syms):
|
||||
expr_str = expr_str.replace(str(sym), str(numbers[i]))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue