Minor question template & score_answer improvements (#261)

* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
This commit is contained in:
Andreas Köpf 2025-03-04 21:55:09 +01:00 committed by GitHub
parent bf24999bb0
commit b2904ccab9
106 changed files with 403 additions and 507 deletions

View file

@ -103,6 +103,10 @@ class ComplexArithmeticDataset(ProceduralDataset):
# Normalize the answer string by removing spaces and converting to lowercase
answer = answer.replace(" ", "").lower()
# remove brackets
while len(answer) > 1 and answer[0] == "(" and answer[-1] == ")":
answer = answer[1:-1]
# Convert mathematical notation 'i' to Python's 'j' for complex numbers
answer = answer.replace("i", "j")

View file

@ -77,9 +77,10 @@ class IntermediateIntegrationDataset(ProceduralDataset):
"Evaluate the indefinite integral: ∫ {integrand} dx",
]
self.added_instruction = """
In addition, when doing calculation, use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
When performing calculations, please follow these guidelines:
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
3. Use `exp(x)` or `E**(x)` for the exponential function (i.e. use capital E for Euler's number).
"""
def _get_outer_constant(self, rng: random.Random) -> int:
@ -245,7 +246,7 @@ In addition, when doing calculation, use the following instructions together wit
"""Determine if the solution provided solves the problem"""
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
if isinstance(answer, str):
try:
var = metadata["variable"]
x = sympy.Symbol(var)
@ -258,12 +259,8 @@ In addition, when doing calculation, use the following instructions together wit
# Check mathematical equivalence through simplification
if sympy.simplify(derivative - integrand) == 0:
reward = 1.0
elif answer.strip():
reward = 0.05
else:
reward = 0.01
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -27,8 +27,9 @@ class PolynomialEquationsConfig:
seed: Optional[int] = None
size: int = 500
# reward function hyperparameters
penalty_missing_factor = 0.1
penalty_extra_factor = 0.05
penalty_missing_factor = 0.5
penalty_extra_factor = 0.5
exp_distance_factor = -10.0
def validate(self) -> None:
"""Validate configuration parameters."""
@ -62,12 +63,15 @@ class PolynomialEquationsDataset(ProceduralDataset):
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
]
self.added_instruction = """
In solving the equations, please abide by the following instruction:
## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc.
## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005".
## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "".
## 4. If there are 2 answers, resolve the answers as comma-separated floats of 2 numbers, if 3 answers, make it comma-separated floats of 3 numbers.
## 5. Resolve all numbers as floats in the string of comma-separated numbers. Round the floats higher than 4 decimal place(d.p) down to 4 d.p.
In solving equations, please follow these instructions:
1. Provide all answers as comma-separated decimal values. For example: "-0.3773, 0.4005"
2. For solutions that can be expressed in exact form (like "u = 2 + sqrt(4560)/172" and "u = 2 - sqrt(4560)/172"), convert them to decimal form in your final answer.
3. If there are no real values that satisfy the equation, report your answer as an empty string: ""
4. Format your answer based on the number of solutions:
- For 1 solution: a single decimal number
- For 2 solutions: two comma-separated decimal numbers
- For 3 or more solutions: all values as comma-separated decimal numbers
5. Round all decimal values to 4 decimal places (rounding down when the 5th decimal place is 5 or greater).
"""
super().__init__(config=config, seed=config.seed, size=config.size)
@ -238,7 +242,7 @@ In solving the equations, please abide by the following instruction:
# Remove matched oracle solution
oracle_solutions.pop(matched_distance_index)
# Exponential decay reward
total_reward += math.exp(-matched_distance)
total_reward += math.exp(matched_distance * self.config.exp_distance_factor)
else:
# Extra predicted solution
extra_solutions += 1

View file

@ -69,9 +69,9 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
"Calculate the following: {polynomial_expr}",
]
self.added_instruction = """
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers.
When performing calculations, please follow these guidelines:
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
"""
super().__init__(config=config, seed=config.seed, size=config.size)
@ -106,10 +106,9 @@ In addition, When doing calculation, Use the following instructions together wit
return {
"question": question,
"answer": product,
"answer": str(product),
"metadata": {
"polynomial_expr": str(polynomial_expr),
"result": str(product),
"variables": list(product.free_symbols),
},
}
@ -147,21 +146,16 @@ In addition, When doing calculation, Use the following instructions together wit
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
try:
predicted_poly = sp.parse_expr(answer)
target_poly = sp.parse_expr(metadata["result"])
target_poly = sp.parse_expr(entry["answer"])
# Check if the difference simplifies to zero (i.e. they are equivalent).
if predicted_poly == target_poly:
reward = 1.0
elif answer.strip():
reward = 0.05
else:
reward = 0.01
except Exception:
reward = 0.01
reward = 0.0
return reward

View file

@ -42,9 +42,9 @@ class SimpleIntegrationDataset(ProceduralDataset):
"Evaluate the indefinite integral: ∫ {integrand} dx",
]
self.added_instruction = """
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
When performing calculations, please follow these guidelines:
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
"""
super().__init__(config=config, seed=config.seed, size=config.size)
@ -103,12 +103,8 @@ In addition, When doing calculation, Use the following instructions together wit
# Check mathematical equivalence through simplification
if sympy.simplify(derivative - integrand) == 0:
reward = 1.0
elif answer.strip():
reward = 0.05
else:
reward = 0.01
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -130,12 +130,9 @@ Return the final state of the program.
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer != entry["answer"]:
return 0.01
else:
if answer == entry["answer"]:
return 1.0 # Yay
return 0.0
# Register the dataset

View file

@ -108,9 +108,9 @@ class BinaryMatrixDataset(ProceduralDataset):
# check if answer is python list of lists
answer = self._matrix_to_str(eval(answer))
if answer == oracle_answer:
return 0.5
except Exception as e:
return 0.01
return 0.1
except Exception:
return 0.0
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -200,7 +200,7 @@ class CryptarithmDataset(ProceduralDataset):
Returns:
float: The computed score between 0.0 and 1.0.
"""
if not answer:
if not isinstance(answer, str):
return 0.0
correct_mapping = {}

View file

@ -106,7 +106,7 @@ class GameOfLifeDataset(ProceduralDataset):
ans_arr = json.loads(answer)
correct_arr = json.loads(entry["answer"])
except Exception:
return 0.01
return 0.0
total_cells = 0
correct_cells = 0

View file

@ -228,12 +228,13 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
try:
danswer = json.loads(answer)
solved, failure = verify_graph_coloring_solution(entry["metadata"]["puzzle"], danswer)
if not solved:
return 0.01 # json was parsable but solution incorrect
else:
if solved:
return 1.0 # Yay
else:
return 0.01 # json parsable
except Exception:
return 0.0
pass
return 0.0
register_dataset("graph_color", GraphColorDataset, GraphColorConfig)

View file

@ -95,7 +95,7 @@ class GroupAnagramsDataset(ProceduralDataset):
if answer_str == oracle_str:
reward = 1.0
else:
reward = 0.01
reward = 0.01 # json parsable
except Exception:
reward = 0.0
return reward

View file

@ -303,11 +303,11 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
danswer = json.loads(answer)
valid, _ = verify_solution(entry["metadata"]["puzzle"], danswer)
if not valid:
return 0.01
return 0.01 # json parsable
else:
return 1.0 # Yay
except Exception as e:
return 0.01
return 0.0
register_dataset("jugs", JugsDataset, JugsConfig)

View file

@ -116,7 +116,7 @@ class LetterJumbleDataset(ProceduralDataset):
# Each word in the expected answer is worth an equal fraction of 1.0
total_words = len(expected_words)
score_per_word = 1.0 / total_words if total_words else 0
score_per_word = 1.0 / total_words if total_words > 0 else 0
# Calculate scores word by word
scores = []
@ -142,18 +142,16 @@ class LetterJumbleDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if not answer:
if not isinstance(answer, str):
return 0.0
oracle_answer = entry["answer"].strip().lower()
if answer:
answer = answer.strip().lower()
if answer == oracle_answer:
return 1.0 # Perfect score!
else:
partial_score = self.partial(oracle_answer, answer)
return partial_score
return 0.01
answer = answer.strip().lower()
if answer == oracle_answer:
return 1.0 # Perfect score!
else:
partial_score = self.partial(oracle_answer, answer)
return partial_score
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)

View file

@ -144,8 +144,6 @@ class ManipulateMatrixDataset(ProceduralDataset):
if oracle_answer in answer:
return len(oracle_answer) / len(answer)
else:
return 0.01
return 0.0

View file

@ -92,14 +92,14 @@ class PalindromeDataset(ProceduralDataset):
- Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0
- An answer that is a palindrome, but not with the same letters as provided, gives 0.05
- An answer that is a string, but not a palindrome gives 0.02
- An empty string gives 0.01.
- An empty string gives 0.0
- None gives 0.0.
"""
if answer is None or not isinstance(answer, str):
return 0.0 # No answer given
if answer == "":
return 0.01
return 0.0
metadata = entry["metadata"]
answer = answer.strip().lower()

View file

@ -95,9 +95,8 @@ class PalindromePartitioningDataset(ProceduralDataset):
oracle = self.to_set_of_tuples(entry["metadata"]["solution"])
if answer == oracle:
return 1.0
return 0.01
except Exception:
return 0.0
pass
return 0.0
def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]:

View file

@ -80,7 +80,7 @@ class PoolMatrixDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score the answer based on the metadata"""
if not answer:
if not isinstance(answer, str):
return 0.0
reward = 0.0
@ -91,8 +91,6 @@ class PoolMatrixDataset(ProceduralDataset):
reward = 1.0
elif oracle_answer.shape == answer.shape:
reward = 0.1
else:
reward = 0.01
except Exception:
pass
return reward

View file

@ -108,14 +108,12 @@ class RansomNoteDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if isinstance(answer, str):
s_answer = answer.strip()
if s_answer == str(entry["answer"]):
return 1.0
s_answer = answer.strip()
if not s_answer == str(entry["answer"]):
return 0.01
else:
return 1.0
return 0.0
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig)

View file

@ -110,7 +110,7 @@ class SentenceReorderingDataset(ProceduralDataset):
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -52,14 +52,14 @@ class SpellBackwardDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:
if isinstance(answer, str):
try:
if expected_answer.lower() == answer.lower():
reward = 1.0
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -126,11 +126,9 @@ class SpiralMatrixDataset(ProceduralDataset):
try:
answer = " ".join(str(item) for item in eval(answer))
if answer == oracle_answer:
return 0.5
else:
return 0.01
except Exception as e:
return 0.01
return 0.1
except Exception:
pass
return 0.0

View file

@ -75,7 +75,7 @@ class StringInsertionDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
if answer is not None:
if isinstance(answer, str):
if answer == oracle_answer:
return 1.0
else:
@ -83,9 +83,9 @@ class StringInsertionDataset(ProceduralDataset):
# check if answer is python list of characters
answer = "".join(eval(answer))
if answer == oracle_answer:
return 0.5
except Exception as e:
return 0.01
return 0.1
except Exception:
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -221,8 +221,8 @@ class WordLadderDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if answer is None:
return 0
if not isinstance(answer, str):
return 0.0
answer_words = tuple(s.strip() for s in answer.upper().split(","))
@ -239,17 +239,17 @@ class WordLadderDataset(ProceduralDataset):
# 4. all words are in our vocabulary
if len(answer_words) < 2:
return 0
return 0.0
if answer_words[0] != start_word or answer_words[-1] != end_word:
return 0.01
return 0.0
if not all(len(w) == word_length for w in answer_words):
return 0.01
return 0.0
for i in range(1, len(answer_words)):
if sum(1 for a, b in zip(answer_words[i - 1], answer_words[i]) if a != b) != 1:
return 0.01
return 0.0
reward = 1.0
for word in answer_words:

View file

@ -121,8 +121,6 @@ class WordSortingDataset(ProceduralDataset):
return 1.0
elif sorted(parsed_answer) == oracle_answer:
return 0.2
else:
return 0.01
return 0.0

View file

@ -199,7 +199,7 @@ class ArcAgiDataset(ProceduralDataset):
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -106,7 +106,7 @@ class ReArcDataset(ProceduralDataset):
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -160,17 +160,15 @@ class BitwiseArithmeticDataset(ProceduralDataset):
Returns:
float: 1.0 if the user's answer is correct; otherwise 0.0 (including when no answer is provided).
"""
if answer is None:
return 0.0
if isinstance(answer, str):
try:
solved = verify_solution(entry["metadata"]["problem"], answer)
if solved:
return 1.0
except Exception:
pass
try:
solved = verify_solution(entry["metadata"]["problem"], answer)
if solved:
return 1.0
except Exception:
return 0.01
return 0.01
return 0.0
# Register the dataset with the factory.

View file

@ -428,7 +428,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
# we suppose the answer is the last occurrence of the expected answer type
if answer is None:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
oracle_answer = entry["answer"]
@ -439,9 +439,6 @@ class CalendarArithmeticDataset(ProceduralDataset):
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
CalendarTask.WEEKDAY_OF_DATE.value,
}:
if not answer:
return 0.0
answer = answer.strip()
oracle_answer = oracle_answer
weekdays = {d.name.title() for d in Weekday}

View file

@ -178,7 +178,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
+ problem_str
)
return {"question": problem_str, "answer": answer, "metadata": {}}
return {"question": problem_str, "answer": str(answer), "metadata": {}}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""
@ -189,12 +189,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
Returns:
float: 1.0 if the user's answer is within tolerance; otherwise 0.0.
"""
if answer is None:
if not isinstance(answer, str):
return 0.0
try:
user_ans: Decimal = Decimal(answer)
correct_ans: Decimal = entry["answer"]
correct_ans: Decimal = Decimal(entry["answer"])
# Determine tolerance based on the desired precision.
precision: int = self.config.max_num_decimal_places
@ -202,9 +202,9 @@ class DecimalArithmeticDataset(ProceduralDataset):
if abs(user_ans - correct_ans) <= tol:
return 1.0
except Exception:
return 0.01
pass
return 0.01
return 0.0
# Register the dataset with the factory.

View file

@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from decimal import Decimal
from decimal import Decimal, InvalidOperation
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -129,7 +129,11 @@ class DecimalChainSumDataset(ProceduralDataset):
result -= c
expression = " ".join(expression_parts)
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
try:
q = Decimal(f"0.{'0' * max(decimal_places)}")
result = result.quantize(q)
except InvalidOperation:
pass
return expression, result
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -141,16 +145,19 @@ class DecimalChainSumDataset(ProceduralDataset):
Returns:
1.0 for exact numerical match, 0.0 otherwise
"""
if answer is None or len(answer.strip()) == 0:
if not isinstance(answer, str) or len(answer.strip()) == 0:
return 0.0
try:
student_answer = Decimal(answer.strip())
oracle_answer = Decimal(entry["answer"])
return 1.0 if student_answer == oracle_answer else 0.01
except (ValueError, TypeError, ArithmeticError):
return 0.01
if student_answer == oracle_answer:
return 1.0
except Exception:
pass
return 0.0
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig)

View file

@ -138,12 +138,11 @@ class DiceDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
return 1.0 # Yay
return 0.0
register_dataset("dice", DiceDataset, DiceConfig)

View file

@ -65,14 +65,13 @@ class NumberFormatDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["metadata"]["solution"]
if answer is not None and len(answer) > 0:
if isinstance(answer, str) and len(answer) > 0:
try:
answer = float(answer.strip().replace(",", ""))
if abs(answer - oracle_answer) < 1e-2:
return 1.0
return 0.01
except:
return 0.0
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -44,10 +44,8 @@ class PowerFunctionDataset(ProceduralDataset):
return 1.0
elif difference < 1e-1:
return 0.5
else:
return 0.01
except Exception as e:
return 0.01
except Exception:
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -246,7 +246,7 @@ class TimeIntervalsDataset(ProceduralDataset):
Returns a score between 0 and 1, with partial credit for answers that are
close to correct in the appropriate units/format
"""
if not answer:
if not isinstance(answer, str):
return 0.0
expected = entry["answer"]

View file

@ -121,20 +121,23 @@ int main() {{
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
if not isinstance(answer, str):
return 0.0
if answer != entry["answer"]:
if entry["answer"] in answer.splitlines():
# We can be quite confident that the correct answer was given
# It was likely just given alongside an explanation
return max(0.9 * len(answer) / len(entry["answer"]), 0.1)
if entry["answer"] in answer:
# Since answers are English words, some risk of the response coincidentally containing the answer
return max(0.5 * len(answer) / len(entry["answer"]), 0.1)
return 0.01
else:
if answer == entry["answer"]:
return 1.0 # Yay
if entry["answer"] in answer.splitlines():
# We can be quite confident that the correct answer was given
# It was likely just given alongside an explanation
return max(0.9 * len(answer) / len(entry["answer"]), 0.1)
if entry["answer"] in answer:
# Since answers are English words, some risk of the response coincidentally containing the answer
return max(0.5 * len(answer) / len(entry["answer"]), 0.1)
return 0.0
# Register the dataset
register_dataset("bf", BFDataset, BFConfig)

View file

@ -182,7 +182,7 @@ class FigletFontDataset(ProceduralDataset):
"""
correct_word = entry["answer"]
if not answer:
if not isinstance(answer, str):
return 0.0 # No answer given
# Normalize case

View file

@ -110,19 +110,17 @@ class NeedleHaystackDataset(ProceduralDataset):
Returns:
float: The computed score between 0.0 and 1.0.
"""
if isinstance(answer, str):
correct_word = entry["answer"]
correct_word = entry["answer"]
if not answer:
return 0.0 # No answer given
# Normalize case
answer = answer.replace(" ", "").strip().lower()
correct_word = correct_word.strip().lower()
# Normalize case
answer = answer.replace(" ", "").strip().lower()
correct_word = correct_word.strip().lower()
if answer == correct_word:
return 1.0 # Correct!
if answer == correct_word:
return 1.0 # Correct!
return 0.01
return 0.0
# Register the dataset

View file

@ -132,12 +132,10 @@ class RectangleCountDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
return 1.0 # Yay
return 0.0
register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig)

View file

@ -69,9 +69,6 @@ class ProceduralDataset(ABC, Sized, Iterable[dict[str, Any]]):
reward = 1.0
elif oracle_answer in answer:
reward = len(oracle_answer) / len(answer)
else:
reward = 0.01
return reward

View file

@ -1,10 +0,0 @@
+ + + + + + +
+ - * - - - +
+ - - - $ - +
+ X - - @ - +
+ - - - - - +
+ $ - + - - +
+ + - - - - +
+ X @ - $ - +
+ + - - - - +
+ + + + + + +

View file

@ -1,5 +0,0 @@
+ + + + + + +
+ * - @ - X +
+ + - @ - + +
+ X - - - - +
+ + + + + + +

View file

@ -1,6 +0,0 @@
- - + + + + + +
- + + - - - * +
+ + - - - + X +
+ X - @ - @ @ +
+ X X @ - - - +
+ + + + + + + +

View file

@ -1,7 +0,0 @@
- + + + + + + - - -
- + X - - X + - - -
+ + - @ @ + + - - -
+ - - - - + + - - -
+ - @ - - * + + + +
+ + - - - - - - X +
- + + + + + + + + +

View file

@ -1,7 +0,0 @@
- + + + + + + - -
+ + X - @ - + + +
+ - - - - - - - +
+ - @ + + X - @ +
+ - - - @ - + - +
+ + + * - X - X +
- - + + + + + + +

View file

@ -1,7 +0,0 @@
- + + + + + + + -
+ + - - + - - + +
+ - @ - - - @ - +
+ - - X * X - - +
+ + @ + + - - + +
+ - - X - - - + -
+ + + + + + + + -

View file

@ -1,9 +0,0 @@
- - - + + + + + + + +
- - - + - - - - - - +
- - + + - - - - @ - +
- + + - - + + - + + +
+ + - - + - - X - - +
+ - - + X @ @ - - + +
+ * + X - - - - + + -
+ + - - - - - + + - -
+ + + + + + + + - - -

View file

@ -1,6 +0,0 @@
+ + + + + + + +
+ - - @ - X * +
+ - @ - - + X +
+ X X @ - @ @ +
+ X X @ - - - +
+ + + + + + + +

View file

@ -13,7 +13,7 @@ from reasoning_gym.games.contrib.sokoban.src.utils import (
)
def astar(matrix, player_pos, debug=False, heuristic="manhattan"):
def astar(matrix, player_pos, debug: bool = False, heuristic: str = "manhattan", max_depth: int = 100):
# print(f'A* - {heuristic.title()} Heuristic')
heur = "[A*]" if heuristic == "manhattan" else "[Dijkstra]"
shape = matrix.shape
@ -67,15 +67,18 @@ def astar(matrix, player_pos, debug=False, heuristic="manhattan"):
return (path + direction[move], depth + 1)
if debug:
print(f"{heur} Solution Depth: {depth + 1}\n{path + direction[move]}", 20)
print(f"{heur} Solution not found!\n")
if depth > max_depth:
break
if debug:
print(f"{heur} Solution Not Found!\nDepth {depth + 1}", 20)
return (None, -1 if not heap else depth + 1)
def solve_astar(puzzle, visualizer=False, heuristic="manhattan"):
def solve_astar(puzzle, visualizer: bool = False, heuristic: str = "manhattan", max_depth: int = 100):
matrix = puzzle
where = np.where((matrix == "*") | (matrix == "%"))
player_pos = where[0][0], where[1][0]
return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic)
return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic, max_depth=max_depth)

View file

@ -29,8 +29,7 @@ class PuzzleElement:
class Game:
def __init__(self, width=19, height=10, level=None, path=None):
self.level = level
def __init__(self, width=19, height=10, path=None):
self.width = width
self.height = height
self.puzzle = np.empty((height, width), dtype=PuzzleElement)
@ -39,7 +38,7 @@ class Game:
self.puzzle_size = None
self.pad_x = 0
self.pad_y = 0
self.path = path or f"levels/lvl{level}.dat"
self.path = path
if path:
if type(self) == Game:
@ -108,7 +107,7 @@ class Game:
# Calculate puzzle size and padding
self.puzzle_size = (len(data), len(data[0]) if len(data) > 0 else 0)
pad_x = (self.width - self.puzzle_size[1] - 2) // 2 # -2 matches original file-based logic
pad_x = (self.width - self.puzzle_size[1]) // 2
pad_y = (self.height - self.puzzle_size[0]) // 2
self.pad_x, self.pad_y = pad_x, pad_y
@ -140,15 +139,15 @@ class Game:
class ReverseGame(Game):
def __init__(self, rng: Random, width=19, height=10, level=None):
super().__init__(width, height, level)
def __init__(self, rng: Random, width: int = 19, height: int = 10):
super().__init__(width, height)
self.rng = rng
self.pad_x = 0
self.pad_y = 0
def load_puzzle(self, puzzle):
self.puzzle_size = (len(puzzle), len(puzzle[0]) if len(puzzle) > 0 else 0)
pad_x = (self.width - len(puzzle[0]) - 2) // 2
pad_x = (self.width - len(puzzle[0])) // 2
pad_y = (self.height - len(puzzle)) // 2
self.pad_x, self.pad_y = pad_x, pad_y
for i, row in enumerate(puzzle):

View file

@ -7,6 +7,9 @@ from reasoning_gym.games.contrib.sokoban.src.game import Game, ReverseGame
def num_boxes(puzzle_area, min_boxes, max_boxes, min_w, min_h, max_w, max_h):
if min_w == max_w or min_h == max_h or min_boxes == max_boxes:
return max_boxes
m = (max_boxes - min_boxes) / (max_w * max_h - min_w * min_h)
b = min_boxes - m * min_w * min_h
return int(m * puzzle_area + b)
@ -19,31 +22,33 @@ def random_valid(rng: Random, width: int = 10, height: int = 10):
def generate(
rng: Random,
debug: bool = False,
path: str = None,
min_w: int = 6,
min_h: int = 6,
max_w: int = 15,
max_h: int = 10,
min_boxes: int = 4,
max_boxes: int = 10,
max_depth: int = 100,
path: str = None,
) -> tuple[str, str, dict]:
"""
Generates a level with the given configuration parameters.
Parameters:
rng: Random number generator for reproducibility.
visualizer: Whether to visualize the generation process.
path: Path to save the level file (default 'levels/lvl0.dat').
min_w: Minimum width of the puzzle.
min_h: Minimum height of the puzzle.
max_w: Maximum width of the puzzle.
max_h: Maximum height of the puzzle.
min_boxes: Minimum number of boxes.
max_boxes: Maximum number of boxes.
rng: Random number generator
debug: Whether to print debug output (e.g. the found solution) during generation
min_w: Minimum width of the puzzle
min_h: Minimum height of the puzzle
max_w: Maximum width of the puzzle
max_h: Maximum height of the puzzle
min_boxes: Minimum number of boxes
max_boxes: Maximum number of boxes
max_depth: Maximum search depth
path: Path to save the level file (optional)
Returns:
puzzle_string, solution
"""
path = path or "levels/lvl0.dat"
while True:
width = rng.randint(min_w, max_w)
height = rng.randint(min_h, max_h)
@ -60,7 +65,7 @@ def generate(
puzzle[box_pos] = "$"
boxes_created += 1
boxes_seen.add(box_pos)
reverse_game = ReverseGame(rng=rng, level=0)
reverse_game = ReverseGame(rng=rng, width=width, height=height)
reverse_game.load_puzzle(puzzle)
player = reverse_game.player
counter = round(height * width * rng.uniform(1.8, 3.6))
@ -79,16 +84,19 @@ def generate(
out_of_place_boxes = np.sum([str(x) == "@" for x in matrix.flatten()])
if out_of_place_boxes >= boxes // 2:
# Optionally save the puzzle to a file:
# np.savetxt(path, matrix, fmt='%s')
if path:
np.savetxt(path, matrix, fmt="%s")
puzzle_str = player.puzzle_to_string(matrix)
grid_list = [list(line) for line in puzzle_str.replace(" ", "").strip().split("\n")]
grid_array = np.array(grid_list)
solution, _ = solve_astar(grid_array)
solution, depth = solve_astar(grid_array, max_depth=max_depth)
if solution is None:
continue # retry generation
if debug:
print(f"solution={solution}")
game = Game()
game = Game(width=width, height=height)
game.load_puzzle_matrix(grid_array)
for step, move in enumerate(solution):

View file

@ -618,7 +618,7 @@ class FutoshikiDataset(ProceduralDataset):
return grid
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not answer:
if not isinstance(answer, str):
return 0.0
oracle_answer = entry["answer"]

View file

@ -314,36 +314,35 @@ class KnightSwapDataset(ProceduralDataset):
- 1.0 for correct answer (either "No" for impossible puzzles or valid solution of optimal length)
- A proportional score for correct but longer solutions
- 0.05 for valid moves that don't solve the puzzle
- 0.01 for invalid format
- 0.0 for None
- 0.0 for invalid format or None
"""
if answer is None:
if not isinstance(answer, str):
return 0.0
answer = answer.strip()
if not answer:
return 0.01
if len(answer) == 0:
return 0.0
# Handle impossible puzzles
if not entry["metadata"]["is_possible"]:
return 1.0 if answer.lower() == "no" else 0.01
return 1.0 if answer.lower() == "no" else 0.0
# Handle "No" answer for possible puzzles
if answer.lower() == "no":
return 0.01
return 0.0
try:
# Parse moves from JSON list
move_list = json.loads(answer)
if not isinstance(move_list, list):
return 0.01
return 0.0
# Parse moves
moves = []
for move_str in move_list:
color, start, end = move_str.split(",")
if color not in ("w", "B"):
return 0.01
return 0.0
moves.append((color, start, end))
# Validate and apply moves
@ -357,13 +356,13 @@ class KnightSwapDataset(ProceduralDataset):
for color, start, end in moves:
if color != current_turn:
return 0.01
return 0.0
if start not in pieces or pieces[start] != color:
return 0.01
return 0.0
if end not in board[start]:
return 0.01
return 0.0
if end in pieces and pieces[end] is not None:
return 0.01
return 0.0
# Apply move
pieces[end] = pieces[start]
@ -390,7 +389,7 @@ class KnightSwapDataset(ProceduralDataset):
return 0.05
except Exception:
return 0.01
return 0.0
register_dataset("knight_swap", KnightSwapDataset, KnightSwapConfig)

View file

@ -195,7 +195,7 @@ class MiniSudokuDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not answer:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
oracle_answer = entry["answer"]

View file

@ -138,8 +138,8 @@ class NQueensDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
valid_solutions = entry["metadata"]["valid_answers"]
if answer is not None:
if isinstance(answer, str):
valid_solutions = entry["metadata"]["valid_answers"]
if answer in valid_solutions:
return 1.0
try:
@ -147,7 +147,7 @@ class NQueensDataset(ProceduralDataset):
if answer in valid_solutions:
return 0.5
except Exception as e:
return 0.01
pass
return 0.0

View file

@ -171,7 +171,7 @@ class RushHourDataset(ProceduralDataset):
Returns:
1.0 if solution reaches goal state, 0.0 otherwise
"""
if not answer:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
try:

View file

@ -11,20 +11,26 @@ from ..factory import ProceduralDataset, register_dataset
class SokobanConfig:
"""Configuration for sokoban puzzle generation"""
min_w: int = 6 # Minimum width of the puzzle.
min_h: int = 6 # Minimum height of the puzzle.
max_w: int = 10 # Maximum width of the puzzle.
max_h: int = 10 # Maximum height of the puzzle.
min_boxes: int = 6 # Minimum number of boxes.
max_boxes: int = 10 # Maximum number of boxes.
min_w: int = 6 # Minimum width of the puzzle
min_h: int = 6 # Minimum height of the puzzle
max_w: int = 10 # Maximum width of the puzzle
max_h: int = 10 # Maximum height of the puzzle
min_boxes: int = 4 # Minimum number of boxes
max_boxes: int = 10 # Maximum number of boxes
max_depth: int = 80 # Maximum search depth
seed: Optional[int] = None
size: int = 500
def validate(self):
"""Validate configuration parameters"""
assert 0 < self.max_w <= 20
assert 0 < self.max_h <= 20
assert self.min_h > 0
assert self.min_w > 0
assert self.min_w <= self.max_w, "min_w must be lte max_w"
assert self.min_h <= self.max_h, "min_h must be lte max_h"
assert self.min_boxes <= self.max_boxes, "min_boxes must be lte max_boxes"
assert self.max_depth > 1
class SokobanDataset(ProceduralDataset):
@ -58,7 +64,16 @@ class SokobanDataset(ProceduralDataset):
# Make the Sokoban!
rng = Random(self.seed + idx)
gamestr, solution, difficulty = self._generate(rng=rng)
gamestr, solution, difficulty = self._generate(
rng=rng,
min_w=self.config.min_w,
min_h=self.config.min_h,
max_w=self.config.max_w,
max_h=self.config.max_h,
min_boxes=self.config.min_boxes,
max_boxes=self.config.max_boxes,
max_depth=self.config.max_depth,
)
return {
"question": """You are going to solve a 'sokoban' puzzle.
@ -93,14 +108,15 @@ Here is your puzzle:
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
if not isinstance(answer, str):
return 0.0
try:
grid_list = [list(line) for line in entry["metadata"]["gamestr"].replace(" ", "").strip().split("\n")]
matrix = np.array(grid_list)
game = self._Game()
h, w = matrix.shape
game = self._Game(height=h, width=w)
game.load_puzzle_matrix(matrix)
for move in answer:
@ -108,10 +124,10 @@ Here is your puzzle:
if self._is_solved(game.get_curr_state()):
return 1.0
except Exception as e:
return 0.01
except:
pass
return 0.1
return 0.0
register_dataset("sokoban", SokobanDataset, SokobanConfig)

View file

@ -214,7 +214,7 @@ class SudokuDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not answer:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
oracle_answer = entry["answer"]

View file

@ -266,7 +266,7 @@ class HanoiDataset(ProceduralDataset):
start_peg=peg_labels[start_peg],
target_peg=peg_labels[target_peg],
),
"answer": solution,
"answer": "\n".join(solution),
"metadata": {
"num_disks": num_disks,
"num_pegs": num_pegs,
@ -383,24 +383,14 @@ class HanoiDataset(ProceduralDataset):
Expected behavior:
- Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0.
- A correct solution of suboptimal length gives a proportional reward of optimal_move_count/user_move_count.
- A badly formatted answer gives a minimal reward (0.01).
- An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05).
- An empty string gives 0.01.
- None gives 0.0.
- A badly formatted or empty answer gives 0.0
"""
if answer is None:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
if answer == "":
return 0.01
# If answer is a string, split it into lines; if it's already a list, use it directly.
if isinstance(answer, str):
moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]
elif isinstance(answer, list):
moves = [line.strip() for line in answer if isinstance(line, str) and line.strip()]
else:
return 0.0
# Split the answer string into lines
moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]
# Build the initial peg state from metadata.
metadata = entry["metadata"]
@ -418,11 +408,11 @@ class HanoiDataset(ProceduralDataset):
try:
disk, from_peg, to_peg = self._parse_move(move)
except Exception:
return 0.01 # Invalid move format
return 0.0 # Invalid move format
# Validate the move using existing _validate_move method.
if not self._validate_move(peg_state, move):
return 0.01
return 0.0
# Execute the move.
peg_state[from_peg].pop()

View file

@ -358,16 +358,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
if answer is not None:
if isinstance(answer, str):
try:
answer_formatted = answer.strip().lower()
solved = answer_formatted == entry["answer"].strip().lower()
if solved:
oracle_answer = entry["answer"].strip().lower()
if answer_formatted == oracle_answer:
reward = 1.0
else:
reward = 0.01
except:
reward = 0.01
pass
return reward

View file

@ -169,21 +169,15 @@ Buttons:
The function awards 1.0 for a correct answer and less otherwise.
"""
if answer == None:
if not isinstance(answer, str):
return 0.0
# Get correct solution from metadata
correct_solution = entry["metadata"].get("solution_path", [])
# Normalize both answers
def normalize_seq(seq):
"""Handle both string and list inputs by converting to string first"""
# Convert sequence to string representation if it's a list
input_str = "".join(seq) if isinstance(seq, list) else str(seq or "")
return [c.upper() for c in re.findall(r"[A-C]", input_str.upper())]
def normalize_seq(seq: str) -> list[str]:
return [c.upper() for c in re.findall(r"[A-C]", seq.upper())]
user_sequence = normalize_seq(answer)
target_sequence = normalize_seq("".join(correct_solution))
target_sequence = normalize_seq(entry["answer"])
# Exact sequence match required
if user_sequence == target_sequence:
@ -196,7 +190,7 @@ Buttons:
return 1.0 # Different answer, but equally correct
return 0.5 # Alternative scoring - you're correct, but not optimal
return 0.1
return 0.0
def simulate_sequence(self, metadata: dict, sequence: list[str]) -> int:
"""Simulate button presses to verify solutions"""

View file

@ -125,8 +125,8 @@ class ShortestPathDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()
if answer is not None and len(answer) > 0:
if isinstance(answer, str) and len(answer) > 0:
oracle_answer = entry["answer"].strip()
answer = answer.strip()
# Exact answer
@ -145,8 +145,6 @@ class ShortestPathDataset(ProceduralDataset):
elif self._is_valid_path(matrix, answer):
return 0.5
return 0.01
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -401,16 +401,14 @@ class CircuitLogicDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if answer is None or len(answer) == 0:
return 0.0
if isinstance(answer, str) and len(answer) > 0:
oracle_answer = entry["answer"]
if oracle_answer == answer:
return 1.0
elif oracle_answer == answer.strip():
return len(oracle_answer) / len(answer)
oracle_answer = entry["answer"]
if oracle_answer == answer:
return 1.0
elif oracle_answer == answer.strip():
return len(oracle_answer) / len(answer)
return 0.01
return 0.0
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)

View file

@ -489,7 +489,7 @@ class KnightsKnavesDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score an answer against the oracle answer."""
if answer is None or len(answer) == 0:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
try:
@ -506,11 +506,9 @@ class KnightsKnavesDataset(ProceduralDataset):
if matching > 0:
return 0.3 + (0.7 * matching / len(oracle_assignments))
return 0.01
except Exception:
# If parsing fails, give minimal credit
return 0.01
pass
return 0.0
register_dataset("knights_knaves", KnightsKnavesDataset, KnightsKnavesConfig)

View file

@ -295,7 +295,7 @@ class PropositionalLogicDataset(ProceduralDataset):
def score_answer(self, answer: str | None, entry: dict[str, Any]) -> float:
"""Robust scoring implementation for propositional logic answers"""
if not answer:
if not isinstance(answer, str):
return 0.0
try:
@ -304,7 +304,7 @@ class PropositionalLogicDataset(ProceduralDataset):
valid_vars = set(entry["metadata"]["variables"])
answer_vars = re.findall(r"([A-Z])", cleaned_answer)
if any(var not in valid_vars for var in answer_vars):
return 0.01
return 0.0
premises = [Expression.from_string(p) for p in entry["metadata"]["premises"]]
answer_expr = Expression.from_string(cleaned_answer)
@ -316,7 +316,7 @@ class PropositionalLogicDataset(ProceduralDataset):
return 1.0
return 0.05
except (ValueError, KeyError, AttributeError):
return 0.01
return 0.0
def _is_trivial(self, expr: Expression) -> bool:
"""Check for trivial tautologies like P ¬P"""

View file

@ -339,9 +339,7 @@ class SelfReferenceDataset(ProceduralDataset):
# Solve puzzle
solutions = solve_puzzle_dynamic(puzzle)
for idx, sol in enumerate(solutions, start=1):
sol_str = ["True" if s else "False" for s in sol]
answer = len(solutions)
answer = str(len(solutions))
return {
"question": puzz_s,
@ -362,12 +360,10 @@ class SelfReferenceDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if str(answer) != str(entry["answer"]):
return 0.1
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer == str(entry["answer"]):
return 1.0 # Yay
return 0.0
register_dataset("self_reference", SelfReferenceDataset, SelfReferenceConfig)

View file

@ -68,12 +68,10 @@ class ZebraDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
return 1.0 # Yay
return 0.0
register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig)

View file

@ -103,7 +103,6 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm
"""
reward = 0.0
if answer is not None and len(answer) > 0:
reward = 0.01
try:
if strip_commas:
answer = answer.replace(",", "")