Minor question template & score_answer improvements (#261)

* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
This commit is contained in:
Andreas Köpf 2025-03-04 21:55:09 +01:00 committed by GitHub
parent 061282e373
commit 5d7fbac0ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
106 changed files with 403 additions and 507 deletions

View file

@ -103,6 +103,10 @@ class ComplexArithmeticDataset(ProceduralDataset):
# Normalize the answer string by removing spaces and converting to lowercase
answer = answer.replace(" ", "").lower()
# remove brackets
while len(answer) > 1 and answer[0] == "(" and answer[-1] == ")":
answer = answer[1:-1]
# Convert mathematical notation 'i' to Python's 'j' for complex numbers
answer = answer.replace("i", "j")

View file

@ -77,9 +77,10 @@ class IntermediateIntegrationDataset(ProceduralDataset):
"Evaluate the indefinite integral: ∫ {integrand} dx",
]
self.added_instruction = """
In addition, when doing calculation, use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
When performing calculations, please follow these guidelines:
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
3. Use `exp(x)` or `E**(x)` for the exponential function (i.e. use capital E for Euler's number).
"""
def _get_outer_constant(self, rng: random.Random) -> int:
@ -245,7 +246,7 @@ In addition, when doing calculation, use the following instructions together wit
"""Determine if the solution provided solves the problem"""
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
if isinstance(answer, str):
try:
var = metadata["variable"]
x = sympy.Symbol(var)
@ -258,12 +259,8 @@ In addition, when doing calculation, use the following instructions together wit
# Check mathematical equivalence through simplification
if sympy.simplify(derivative - integrand) == 0:
reward = 1.0
elif answer.strip():
reward = 0.05
else:
reward = 0.01
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -27,8 +27,9 @@ class PolynomialEquationsConfig:
seed: Optional[int] = None
size: int = 500
# reward function hyperparameters
penalty_missing_factor = 0.1
penalty_extra_factor = 0.05
penalty_missing_factor = 0.5
penalty_extra_factor = 0.5
exp_distance_factor = -10.0
def validate(self) -> None:
"""Validate configuration parameters."""
@ -62,12 +63,15 @@ class PolynomialEquationsDataset(ProceduralDataset):
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
]
self.added_instruction = """
In solving the equations, please abide by the following instruction:
## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc.
## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005".
## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "".
## 4. If there are 2 answers, resolve the answers as comma-separated floats of 2 numbers, if 3 answers, make it comma-separated floats of 3 numbers.
## 5. Resolve all numbers as floats in the string of comma-separated numbers. Round the floats higher than 4 decimal place(d.p) down to 4 d.p.
In solving equations, please follow these instructions:
1. Provide all answers as comma-separated decimal values. For example: "-0.3773, 0.4005"
2. For solutions that can be expressed in exact form (like "u = 2 + sqrt(4560)/172" and "u = 2 - sqrt(4560)/172"), convert them to decimal form in your final answer.
3. If there are no real values that satisfy the equation, report your answer as an empty string: ""
4. Format your answer based on the number of solutions:
- For 1 solution: a single decimal number
- For 2 solutions: two comma-separated decimal numbers
- For 3 or more solutions: all values as comma-separated decimal numbers
5. Round all decimal values to 4 decimal places (rounding down when the 5th decimal place is 5 or greater).
"""
super().__init__(config=config, seed=config.seed, size=config.size)
@ -238,7 +242,7 @@ In solving the equations, please abide by the following instruction:
# Remove matched oracle solution
oracle_solutions.pop(matched_distance_index)
# Exponential decay reward
total_reward += math.exp(-matched_distance)
total_reward += math.exp(matched_distance * self.config.exp_distance_factor)
else:
# Extra predicted solution
extra_solutions += 1

View file

@ -69,9 +69,9 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
"Calculate the following: {polynomial_expr}",
]
self.added_instruction = """
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers.
When performing calculations, please follow these guidelines:
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
"""
super().__init__(config=config, seed=config.seed, size=config.size)
@ -106,10 +106,9 @@ In addition, When doing calculation, Use the following instructions together wit
return {
"question": question,
"answer": product,
"answer": str(product),
"metadata": {
"polynomial_expr": str(polynomial_expr),
"result": str(product),
"variables": list(product.free_symbols),
},
}
@ -147,21 +146,16 @@ In addition, When doing calculation, Use the following instructions together wit
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
try:
predicted_poly = sp.parse_expr(answer)
target_poly = sp.parse_expr(metadata["result"])
target_poly = sp.parse_expr(entry["answer"])
# Check if the difference simplifies to zero (i.e. they are equivalent).
if predicted_poly == target_poly:
reward = 1.0
elif answer.strip():
reward = 0.05
else:
reward = 0.01
except Exception:
reward = 0.01
reward = 0.0
return reward

View file

@ -42,9 +42,9 @@ class SimpleIntegrationDataset(ProceduralDataset):
"Evaluate the indefinite integral: ∫ {integrand} dx",
]
self.added_instruction = """
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
When performing calculations, please follow these guidelines:
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
"""
super().__init__(config=config, seed=config.seed, size=config.size)
@ -103,12 +103,8 @@ In addition, When doing calculation, Use the following instructions together wit
# Check mathematical equivalence through simplification
if sympy.simplify(derivative - integrand) == 0:
reward = 1.0
elif answer.strip():
reward = 0.05
else:
reward = 0.01
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -130,12 +130,9 @@ Return the final state of the program.
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer != entry["answer"]:
return 0.01
else:
if answer == entry["answer"]:
return 1.0 # Yay
return 0.0
# Register the dataset

View file

@ -108,9 +108,9 @@ class BinaryMatrixDataset(ProceduralDataset):
# check if answer is python list of lists
answer = self._matrix_to_str(eval(answer))
if answer == oracle_answer:
return 0.5
except Exception as e:
return 0.01
return 0.1
except Exception:
return 0.0
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -200,7 +200,7 @@ class CryptarithmDataset(ProceduralDataset):
Returns:
float: The computed score between 0.0 and 1.0.
"""
if not answer:
if not isinstance(answer, str):
return 0.0
correct_mapping = {}

View file

@ -106,7 +106,7 @@ class GameOfLifeDataset(ProceduralDataset):
ans_arr = json.loads(answer)
correct_arr = json.loads(entry["answer"])
except Exception:
return 0.01
return 0.0
total_cells = 0
correct_cells = 0

View file

@ -228,12 +228,13 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
try:
danswer = json.loads(answer)
solved, failure = verify_graph_coloring_solution(entry["metadata"]["puzzle"], danswer)
if not solved:
return 0.01 # json was parsable but solution incorrect
else:
if solved:
return 1.0 # Yay
else:
return 0.01 # json parsable
except Exception:
return 0.0
pass
return 0.0
register_dataset("graph_color", GraphColorDataset, GraphColorConfig)

View file

@ -95,7 +95,7 @@ class GroupAnagramsDataset(ProceduralDataset):
if answer_str == oracle_str:
reward = 1.0
else:
reward = 0.01
reward = 0.01 # json parsable
except Exception:
reward = 0.0
return reward

View file

@ -303,11 +303,11 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
danswer = json.loads(answer)
valid, _ = verify_solution(entry["metadata"]["puzzle"], danswer)
if not valid:
return 0.01
return 0.01 # json parsable
else:
return 1.0 # Yay
except Exception as e:
return 0.01
return 0.0
register_dataset("jugs", JugsDataset, JugsConfig)

View file

@ -116,7 +116,7 @@ class LetterJumbleDataset(ProceduralDataset):
# Each word in the expected answer is worth an equal fraction of 1.0
total_words = len(expected_words)
score_per_word = 1.0 / total_words if total_words else 0
score_per_word = 1.0 / total_words if total_words > 0 else 0
# Calculate scores word by word
scores = []
@ -142,18 +142,16 @@ class LetterJumbleDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if not answer:
if not isinstance(answer, str):
return 0.0
oracle_answer = entry["answer"].strip().lower()
if answer:
answer = answer.strip().lower()
if answer == oracle_answer:
return 1.0 # Perfect score!
else:
partial_score = self.partial(oracle_answer, answer)
return partial_score
return 0.01
answer = answer.strip().lower()
if answer == oracle_answer:
return 1.0 # Perfect score!
else:
partial_score = self.partial(oracle_answer, answer)
return partial_score
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)

View file

@ -144,8 +144,6 @@ class ManipulateMatrixDataset(ProceduralDataset):
if oracle_answer in answer:
return len(oracle_answer) / len(answer)
else:
return 0.01
return 0.0

View file

@ -92,14 +92,14 @@ class PalindromeDataset(ProceduralDataset):
- Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0
- An answer that is a palindrome, but not with the same letters as provided, gives 0.05
- An answer that is a string, but not a palindrome gives 0.02
- An empty string gives 0.01.
- An empty string gives 0.0
- None gives 0.0.
"""
if answer is None or not isinstance(answer, str):
return 0.0 # No answer given
if answer == "":
return 0.01
return 0.0
metadata = entry["metadata"]
answer = answer.strip().lower()

View file

@ -95,9 +95,8 @@ class PalindromePartitioningDataset(ProceduralDataset):
oracle = self.to_set_of_tuples(entry["metadata"]["solution"])
if answer == oracle:
return 1.0
return 0.01
except Exception:
return 0.0
pass
return 0.0
def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]:

View file

@ -80,7 +80,7 @@ class PoolMatrixDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score the answer based on the metadata"""
if not answer:
if not isinstance(answer, str):
return 0.0
reward = 0.0
@ -91,8 +91,6 @@ class PoolMatrixDataset(ProceduralDataset):
reward = 1.0
elif oracle_answer.shape == answer.shape:
reward = 0.1
else:
reward = 0.01
except Exception:
pass
return reward

View file

@ -108,14 +108,12 @@ class RansomNoteDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if isinstance(answer, str):
s_answer = answer.strip()
if s_answer == str(entry["answer"]):
return 1.0
s_answer = answer.strip()
if not s_answer == str(entry["answer"]):
return 0.01
else:
return 1.0
return 0.0
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig)

View file

@ -110,7 +110,7 @@ class SentenceReorderingDataset(ProceduralDataset):
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -52,14 +52,14 @@ class SpellBackwardDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:
if isinstance(answer, str):
try:
if expected_answer.lower() == answer.lower():
reward = 1.0
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -126,11 +126,9 @@ class SpiralMatrixDataset(ProceduralDataset):
try:
answer = " ".join(str(item) for item in eval(answer))
if answer == oracle_answer:
return 0.5
else:
return 0.01
except Exception as e:
return 0.01
return 0.1
except Exception:
pass
return 0.0

View file

@ -75,7 +75,7 @@ class StringInsertionDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
if answer is not None:
if isinstance(answer, str):
if answer == oracle_answer:
return 1.0
else:
@ -83,9 +83,9 @@ class StringInsertionDataset(ProceduralDataset):
# check if answer is python list of characters
answer = "".join(eval(answer))
if answer == oracle_answer:
return 0.5
except Exception as e:
return 0.01
return 0.1
except Exception:
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -221,8 +221,8 @@ class WordLadderDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if answer is None:
return 0
if not isinstance(answer, str):
return 0.0
answer_words = tuple(s.strip() for s in answer.upper().split(","))
@ -239,17 +239,17 @@ class WordLadderDataset(ProceduralDataset):
# 4. all words are in our vocabulary
if len(answer_words) < 2:
return 0
return 0.0
if answer_words[0] != start_word or answer_words[-1] != end_word:
return 0.01
return 0.0
if not all(len(w) == word_length for w in answer_words):
return 0.01
return 0.0
for i in range(1, len(answer_words)):
if sum(1 for a, b in zip(answer_words[i - 1], answer_words[i]) if a != b) != 1:
return 0.01
return 0.0
reward = 1.0
for word in answer_words:

View file

@ -121,8 +121,6 @@ class WordSortingDataset(ProceduralDataset):
return 1.0
elif sorted(parsed_answer) == oracle_answer:
return 0.2
else:
return 0.01
return 0.0

View file

@ -199,7 +199,7 @@ class ArcAgiDataset(ProceduralDataset):
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -106,7 +106,7 @@ class ReArcDataset(ProceduralDataset):
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -160,17 +160,15 @@ class BitwiseArithmeticDataset(ProceduralDataset):
Returns:
float: 1.0 if the user's answer is correct; otherwise 0.0 (including when no answer is provided).
"""
if answer is None:
return 0.0
if isinstance(answer, str):
try:
solved = verify_solution(entry["metadata"]["problem"], answer)
if solved:
return 1.0
except Exception:
pass
try:
solved = verify_solution(entry["metadata"]["problem"], answer)
if solved:
return 1.0
except Exception:
return 0.01
return 0.01
return 0.0
# Register the dataset with the factory.

View file

@ -428,7 +428,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
# we suppose the answer is the last occurrence of the expected answer type
if answer is None:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
oracle_answer = entry["answer"]
@ -439,9 +439,6 @@ class CalendarArithmeticDataset(ProceduralDataset):
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
CalendarTask.WEEKDAY_OF_DATE.value,
}:
if not answer:
return 0.0
answer = answer.strip()
oracle_answer = oracle_answer
weekdays = {d.name.title() for d in Weekday}

View file

@ -178,7 +178,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
+ problem_str
)
return {"question": problem_str, "answer": answer, "metadata": {}}
return {"question": problem_str, "answer": str(answer), "metadata": {}}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""
@ -189,12 +189,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
Returns:
float: 1.0 if the user's answer is within tolerance; otherwise, 0.0.
"""
if answer is None:
if not isinstance(answer, str):
return 0.0
try:
user_ans: Decimal = Decimal(answer)
correct_ans: Decimal = entry["answer"]
correct_ans: Decimal = Decimal(entry["answer"])
# Determine tolerance based on the desired precision.
precision: int = self.config.max_num_decimal_places
@ -202,9 +202,9 @@ class DecimalArithmeticDataset(ProceduralDataset):
if abs(user_ans - correct_ans) <= tol:
return 1.0
except Exception:
return 0.01
pass
return 0.01
return 0.0
# Register the dataset with the factory.

View file

@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from decimal import Decimal
from decimal import Decimal, InvalidOperation
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -129,7 +129,11 @@ class DecimalChainSumDataset(ProceduralDataset):
result -= c
expression = " ".join(expression_parts)
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
try:
q = Decimal(f"0.{'0' * max(decimal_places)}")
result = result.quantize(q)
except InvalidOperation:
pass
return expression, result
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -141,16 +145,19 @@ class DecimalChainSumDataset(ProceduralDataset):
Returns:
1.0 for exact numerical match, 0.0 otherwise
"""
if answer is None or len(answer.strip()) == 0:
if not isinstance(answer, str) or len(answer.strip()) == 0:
return 0.0
try:
student_answer = Decimal(answer.strip())
oracle_answer = Decimal(entry["answer"])
return 1.0 if student_answer == oracle_answer else 0.01
except (ValueError, TypeError, ArithmeticError):
return 0.01
if student_answer == oracle_answer:
return 1.0
except Exception:
pass
return 0.0
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig)

View file

@ -138,12 +138,11 @@ class DiceDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
return 1.0 # Yay
return 0.0
register_dataset("dice", DiceDataset, DiceConfig)

View file

@ -65,14 +65,13 @@ class NumberFormatDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["metadata"]["solution"]
if answer is not None and len(answer) > 0:
if isinstance(answer, str) and len(answer) > 0:
try:
answer = float(answer.strip().replace(",", ""))
if abs(answer - oracle_answer) < 1e-2:
return 1.0
return 0.01
except:
return 0.0
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -44,10 +44,8 @@ class PowerFunctionDataset(ProceduralDataset):
return 1.0
elif difference < 1e-1:
return 0.5
else:
return 0.01
except Exception as e:
return 0.01
except Exception:
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -246,7 +246,7 @@ class TimeIntervalsDataset(ProceduralDataset):
Returns a score between 0 and 1, with partial credit for answers that are
close to correct in the appropriate units/format
"""
if not answer:
if not isinstance(answer, str):
return 0.0
expected = entry["answer"]

View file

@ -121,20 +121,23 @@ int main() {{
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
if not isinstance(answer, str):
return 0.0
if answer != entry["answer"]:
if entry["answer"] in answer.splitlines():
# We can be quite confident that the correct answer was given
# It was likely just given alongside an explanation
return max(0.9 * len(answer) / len(entry["answer"]), 0.1)
if entry["answer"] in answer:
# Since answers are English words, some risk of the response coincidentally containing the answer
return max(0.5 * len(answer) / len(entry["answer"]), 0.1)
return 0.01
else:
if answer == entry["answer"]:
return 1.0 # Yay
if entry["answer"] in answer.splitlines():
# We can be quite confident that the correct answer was given
# It was likely just given alongside an explanation
return max(0.9 * len(answer) / len(entry["answer"]), 0.1)
if entry["answer"] in answer:
# Since answers are English words, some risk of the response coincidentally containing the answer
return max(0.5 * len(answer) / len(entry["answer"]), 0.1)
return 0.0
# Register the dataset
register_dataset("bf", BFDataset, BFConfig)

View file

@ -182,7 +182,7 @@ class FigletFontDataset(ProceduralDataset):
"""
correct_word = entry["answer"]
if not answer:
if not isinstance(answer, str):
return 0.0 # No answer given
# Normalize case

View file

@ -110,19 +110,17 @@ class NeedleHaystackDataset(ProceduralDataset):
Returns:
float: The computed score between 0.0 and 1.0.
"""
if isinstance(answer, str):
correct_word = entry["answer"]
correct_word = entry["answer"]
if not answer:
return 0.0 # No answer given
# Normalize case
answer = answer.replace(" ", "").strip().lower()
correct_word = correct_word.strip().lower()
# Normalize case
answer = answer.replace(" ", "").strip().lower()
correct_word = correct_word.strip().lower()
if answer == correct_word:
return 1.0 # Correct!
if answer == correct_word:
return 1.0 # Correct!
return 0.01
return 0.0
# Register the dataset

View file

@ -132,12 +132,10 @@ class RectangleCountDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
return 1.0 # Yay
return 0.0
register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig)

View file

@ -69,9 +69,6 @@ class ProceduralDataset(ABC, Sized, Iterable[dict[str, Any]]):
reward = 1.0
elif oracle_answer in answer:
reward = len(oracle_answer) / len(answer)
else:
reward = 0.01
return reward

View file

@ -1,10 +0,0 @@
+ + + + + + +
+ - * - - - +
+ - - - $ - +
+ X - - @ - +
+ - - - - - +
+ $ - + - - +
+ + - - - - +
+ X @ - $ - +
+ + - - - - +
+ + + + + + +

View file

@ -1,5 +0,0 @@
+ + + + + + +
+ * - @ - X +
+ + - @ - + +
+ X - - - - +
+ + + + + + +

View file

@ -1,6 +0,0 @@
- - + + + + + +
- + + - - - * +
+ + - - - + X +
+ X - @ - @ @ +
+ X X @ - - - +
+ + + + + + + +

View file

@ -1,7 +0,0 @@
- + + + + + + - - -
- + X - - X + - - -
+ + - @ @ + + - - -
+ - - - - + + - - -
+ - @ - - * + + + +
+ + - - - - - - X +
- + + + + + + + + +

View file

@ -1,7 +0,0 @@
- + + + + + + - -
+ + X - @ - + + +
+ - - - - - - - +
+ - @ + + X - @ +
+ - - - @ - + - +
+ + + * - X - X +
- - + + + + + + +

View file

@ -1,7 +0,0 @@
- + + + + + + + -
+ + - - + - - + +
+ - @ - - - @ - +
+ - - X * X - - +
+ + @ + + - - + +
+ - - X - - - + -
+ + + + + + + + -

View file

@ -1,9 +0,0 @@
- - - + + + + + + + +
- - - + - - - - - - +
- - + + - - - - @ - +
- + + - - + + - + + +
+ + - - + - - X - - +
+ - - + X @ @ - - + +
+ * + X - - - - + + -
+ + - - - - - + + - -
+ + + + + + + + - - -

View file

@ -1,6 +0,0 @@
+ + + + + + + +
+ - - @ - X * +
+ - @ - - + X +
+ X X @ - @ @ +
+ X X @ - - - +
+ + + + + + + +

View file

@ -13,7 +13,7 @@ from reasoning_gym.games.contrib.sokoban.src.utils import (
)
def astar(matrix, player_pos, debug=False, heuristic="manhattan"):
def astar(matrix, player_pos, debug: bool = False, heuristic: str = "manhattan", max_depth: int = 100):
# print(f'A* - {heuristic.title()} Heuristic')
heur = "[A*]" if heuristic == "manhattan" else "[Dijkstra]"
shape = matrix.shape
@ -67,15 +67,18 @@ def astar(matrix, player_pos, debug=False, heuristic="manhattan"):
return (path + direction[move], depth + 1)
if debug:
print(f"{heur} Solution Depth: {depth + 1}\n{path + direction[move]}", 20)
print(f"{heur} Solution not found!\n")
if depth > max_depth:
break
if debug:
print(f"{heur} Solution Not Found!\nDepth {depth + 1}", 20)
return (None, -1 if not heap else depth + 1)
def solve_astar(puzzle, visualizer=False, heuristic="manhattan"):
def solve_astar(puzzle, visualizer: bool = False, heuristic: str = "manhattan", max_depth: int = 100):
matrix = puzzle
where = np.where((matrix == "*") | (matrix == "%"))
player_pos = where[0][0], where[1][0]
return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic)
return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic, max_depth=max_depth)

View file

@ -29,8 +29,7 @@ class PuzzleElement:
class Game:
def __init__(self, width=19, height=10, level=None, path=None):
self.level = level
def __init__(self, width=19, height=10, path=None):
self.width = width
self.height = height
self.puzzle = np.empty((height, width), dtype=PuzzleElement)
@ -39,7 +38,7 @@ class Game:
self.puzzle_size = None
self.pad_x = 0
self.pad_y = 0
self.path = path or f"levels/lvl{level}.dat"
self.path = path
if path:
if type(self) == Game:
@ -108,7 +107,7 @@ class Game:
# Calculate puzzle size and padding
self.puzzle_size = (len(data), len(data[0]) if len(data) > 0 else 0)
pad_x = (self.width - self.puzzle_size[1] - 2) // 2 # -2 matches original file-based logic
pad_x = (self.width - self.puzzle_size[1]) // 2
pad_y = (self.height - self.puzzle_size[0]) // 2
self.pad_x, self.pad_y = pad_x, pad_y
@ -140,15 +139,15 @@ class Game:
class ReverseGame(Game):
def __init__(self, rng: Random, width=19, height=10, level=None):
super().__init__(width, height, level)
def __init__(self, rng: Random, width: int = 19, height: int = 10):
super().__init__(width, height)
self.rng = rng
self.pad_x = 0
self.pad_y = 0
def load_puzzle(self, puzzle):
self.puzzle_size = (len(puzzle), len(puzzle[0]) if len(puzzle) > 0 else 0)
pad_x = (self.width - len(puzzle[0]) - 2) // 2
pad_x = (self.width - len(puzzle[0])) // 2
pad_y = (self.height - len(puzzle)) // 2
self.pad_x, self.pad_y = pad_x, pad_y
for i, row in enumerate(puzzle):

View file

@ -7,6 +7,9 @@ from reasoning_gym.games.contrib.sokoban.src.game import Game, ReverseGame
def num_boxes(puzzle_area, min_boxes, max_boxes, min_w, min_h, max_w, max_h):
if min_w == max_w or min_h == max_h or min_boxes == max_boxes:
return max_boxes
m = (max_boxes - min_boxes) / (max_w * max_h - min_w * min_h)
b = min_boxes - m * min_w * min_h
return int(m * puzzle_area + b)
@ -19,31 +22,33 @@ def random_valid(rng: Random, width: int = 10, height: int = 10):
def generate(
rng: Random,
debug: bool = False,
path: str = None,
min_w: int = 6,
min_h: int = 6,
max_w: int = 15,
max_h: int = 10,
min_boxes: int = 4,
max_boxes: int = 10,
max_depth: int = 100,
path: str = None,
) -> tuple[str, str, dict]:
"""
Generates a level with the given configuration parameters.
Parameters:
rng: Random number generator for reproducibility.
visualizer: Whether to visualize the generation process.
path: Path to save the level file (default 'levels/lvl0.dat').
min_w: Minimum width of the puzzle.
min_h: Minimum height of the puzzle.
max_w: Maximum width of the puzzle.
max_h: Maximum height of the puzzle.
min_boxes: Minimum number of boxes.
max_boxes: Maximum number of boxes.
rng: Random number generator
visualizer: Whether to visualize the generation process
min_w: Minimum width of the puzzle
min_h: Minimum height of the puzzle
max_w: Maximum width of the puzzle
max_h: Maximum height of the puzzle
min_boxes: Minimum number of boxes
max_boxes: Maximum number of boxes
max_depth: Maximum search depth
path: Path to save the level file (optional)
Returns:
puzzle_string, solution
"""
path = path or "levels/lvl0.dat"
while True:
width = rng.randint(min_w, max_w)
height = rng.randint(min_h, max_h)
@ -60,7 +65,7 @@ def generate(
puzzle[box_pos] = "$"
boxes_created += 1
boxes_seen.add(box_pos)
reverse_game = ReverseGame(rng=rng, level=0)
reverse_game = ReverseGame(rng=rng, width=width, height=height)
reverse_game.load_puzzle(puzzle)
player = reverse_game.player
counter = round(height * width * rng.uniform(1.8, 3.6))
@ -79,16 +84,19 @@ def generate(
out_of_place_boxes = np.sum([str(x) == "@" for x in matrix.flatten()])
if out_of_place_boxes >= boxes // 2:
# Optionally save the puzzle to a file:
# np.savetxt(path, matrix, fmt='%s')
if path:
np.savetxt(path, matrix, fmt="%s")
puzzle_str = player.puzzle_to_string(matrix)
grid_list = [list(line) for line in puzzle_str.replace(" ", "").strip().split("\n")]
grid_array = np.array(grid_list)
solution, _ = solve_astar(grid_array)
solution, depth = solve_astar(grid_array, max_depth=max_depth)
if solution is None:
continue # retry generation
if debug:
print(f"solution={solution}")
game = Game()
game = Game(width=width, height=height)
game.load_puzzle_matrix(grid_array)
for step, move in enumerate(solution):

View file

@ -618,7 +618,7 @@ class FutoshikiDataset(ProceduralDataset):
return grid
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not answer:
if not isinstance(answer, str):
return 0.0
oracle_answer = entry["answer"]

View file

@ -314,36 +314,35 @@ class KnightSwapDataset(ProceduralDataset):
- 1.0 for correct answer (either "No" for impossible puzzles or valid solution of optimal length)
- A proportional score for correct but longer solutions
- 0.05 for valid moves that don't solve the puzzle
- 0.01 for invalid format
- 0.0 for None
- 0.0 for invalid format or None
"""
if answer is None:
if not isinstance(answer, str):
return 0.0
answer = answer.strip()
if not answer:
return 0.01
if len(answer) == 0:
return 0.0
# Handle impossible puzzles
if not entry["metadata"]["is_possible"]:
return 1.0 if answer.lower() == "no" else 0.01
return 1.0 if answer.lower() == "no" else 0.0
# Handle "No" answer for possible puzzles
if answer.lower() == "no":
return 0.01
return 0.0
try:
# Parse moves from JSON list
move_list = json.loads(answer)
if not isinstance(move_list, list):
return 0.01
return 0.0
# Parse moves
moves = []
for move_str in move_list:
color, start, end = move_str.split(",")
if color not in ("w", "B"):
return 0.01
return 0.0
moves.append((color, start, end))
# Validate and apply moves
@ -357,13 +356,13 @@ class KnightSwapDataset(ProceduralDataset):
for color, start, end in moves:
if color != current_turn:
return 0.01
return 0.0
if start not in pieces or pieces[start] != color:
return 0.01
return 0.0
if end not in board[start]:
return 0.01
return 0.0
if end in pieces and pieces[end] is not None:
return 0.01
return 0.0
# Apply move
pieces[end] = pieces[start]
@ -390,7 +389,7 @@ class KnightSwapDataset(ProceduralDataset):
return 0.05
except Exception:
return 0.01
return 0.0
register_dataset("knight_swap", KnightSwapDataset, KnightSwapConfig)

View file

@ -195,7 +195,7 @@ class MiniSudokuDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not answer:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
oracle_answer = entry["answer"]

View file

@ -138,8 +138,8 @@ class NQueensDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
valid_solutions = entry["metadata"]["valid_answers"]
if answer is not None:
if isinstance(answer, str):
valid_solutions = entry["metadata"]["valid_answers"]
if answer in valid_solutions:
return 1.0
try:
@ -147,7 +147,7 @@ class NQueensDataset(ProceduralDataset):
if answer in valid_solutions:
return 0.5
except Exception as e:
return 0.01
pass
return 0.0

View file

@ -171,7 +171,7 @@ class RushHourDataset(ProceduralDataset):
Returns:
1.0 if solution reaches goal state, 0.0 otherwise
"""
if not answer:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
try:

View file

@ -11,20 +11,26 @@ from ..factory import ProceduralDataset, register_dataset
class SokobanConfig:
"""Configuration for sokoban puzzle generation"""
min_w: int = 6 # Minimum width of the puzzle.
min_h: int = 6 # Minimum height of the puzzle.
max_w: int = 10 # Maximum width of the puzzle.
max_h: int = 10 # Maximum height of the puzzle.
min_boxes: int = 6 # Minimum number of boxes.
max_boxes: int = 10 # Maximum number of boxes.
min_w: int = 6 # Minimum width of the puzzle
min_h: int = 6 # Minimum height of the puzzle
max_w: int = 10 # Maximum width of the puzzle
max_h: int = 10 # Maximum height of the puzzle
min_boxes: int = 4 # Minimum number of boxes
max_boxes: int = 10 # Maximum number of boxes
max_depth: int = 80 # Maximum search depth
seed: Optional[int] = None
size: int = 500
def validate(self):
"""Validate configuration parameters"""
assert 0 < self.max_w <= 20
assert 0 < self.max_h <= 20
assert self.min_h > 0
assert self.min_w > 0
assert self.min_w <= self.max_w, "min_w must be lte max_w"
assert self.min_h <= self.max_h, "min_h must be lte max_h"
assert self.min_boxes <= self.max_boxes, "min_boxes must be lte max_boxes"
assert self.max_depth > 1
class SokobanDataset(ProceduralDataset):
@ -58,7 +64,16 @@ class SokobanDataset(ProceduralDataset):
# Make the Sokoban!
rng = Random(self.seed + idx)
gamestr, solution, difficulty = self._generate(rng=rng)
gamestr, solution, difficulty = self._generate(
rng=rng,
min_w=self.config.min_w,
min_h=self.config.min_h,
max_w=self.config.max_w,
max_h=self.config.max_h,
min_boxes=self.config.min_boxes,
max_boxes=self.config.max_boxes,
max_depth=self.config.max_depth,
)
return {
"question": """You are going to solve a 'sokoban' puzzle.
@ -93,14 +108,15 @@ Here is your puzzle:
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
if not isinstance(answer, str):
return 0.0
try:
grid_list = [list(line) for line in entry["metadata"]["gamestr"].replace(" ", "").strip().split("\n")]
matrix = np.array(grid_list)
game = self._Game()
h, w = matrix.shape
game = self._Game(height=h, width=w)
game.load_puzzle_matrix(matrix)
for move in answer:
@ -108,10 +124,10 @@ Here is your puzzle:
if self._is_solved(game.get_curr_state()):
return 1.0
except Exception as e:
return 0.01
except:
pass
return 0.1
return 0.0
register_dataset("sokoban", SokobanDataset, SokobanConfig)

View file

@ -214,7 +214,7 @@ class SudokuDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not answer:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
oracle_answer = entry["answer"]

View file

@ -266,7 +266,7 @@ class HanoiDataset(ProceduralDataset):
start_peg=peg_labels[start_peg],
target_peg=peg_labels[target_peg],
),
"answer": solution,
"answer": "\n".join(solution),
"metadata": {
"num_disks": num_disks,
"num_pegs": num_pegs,
@ -383,24 +383,14 @@ class HanoiDataset(ProceduralDataset):
Expected behavior:
- Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0.
- A correct solution that is suboptimal length gives a proportional reward of optimal_move_count/user_move_count
- A badly formatted answer gives a minimal reward (0.01).
- An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05).
- An empty string gives 0.01.
- None gives 0.0.
- A badly formatted or empty answer gives 0.0
"""
if answer is None:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
if answer == "":
return 0.01
# If answer is a string, split it into lines; if it's already a list, use it directly.
if isinstance(answer, str):
moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]
elif isinstance(answer, list):
moves = [line.strip() for line in answer if isinstance(line, str) and line.strip()]
else:
return 0.0
# Spilt answer string it into lines
moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]
# Build the initial peg state from metadata.
metadata = entry["metadata"]
@ -418,11 +408,11 @@ class HanoiDataset(ProceduralDataset):
try:
disk, from_peg, to_peg = self._parse_move(move)
except Exception:
return 0.01 # Invalid move format
return 0.0 # Invalid move format
# Validate the move using existing _validate_move method.
if not self._validate_move(peg_state, move):
return 0.01
return 0.0
# Execute the move.
peg_state[from_peg].pop()

View file

@ -358,16 +358,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
if answer is not None:
if isinstance(answer, str):
try:
answer_formatted = answer.strip().lower()
solved = answer_formatted == entry["answer"].strip().lower()
if solved:
oracle_answer = entry["answer"].strip().lower()
if answer_formatted == oracle_answer:
reward = 1.0
else:
reward = 0.01
except:
reward = 0.01
pass
return reward

View file

@ -169,21 +169,15 @@ Buttons:
The function awards 1.0 for a correct answer and less otherwise.
"""
if answer == None:
if not isinstance(answer, str):
return 0.0
# Get correct solution from metadata
correct_solution = entry["metadata"].get("solution_path", [])
# Normalize both answers
def normalize_seq(seq):
"""Handle both string and list inputs by converting to string first"""
# Convert sequence to string representation if it's a list
input_str = "".join(seq) if isinstance(seq, list) else str(seq or "")
return [c.upper() for c in re.findall(r"[A-C]", input_str.upper())]
def normalize_seq(seq: str) -> list[str]:
return [c.upper() for c in re.findall(r"[A-C]", seq.upper())]
user_sequence = normalize_seq(answer)
target_sequence = normalize_seq("".join(correct_solution))
target_sequence = normalize_seq(entry["answer"])
# Exact sequence match required
if user_sequence == target_sequence:
@ -196,7 +190,7 @@ Buttons:
return 1.0 # Different answer, but qually correct
return 0.5 # Alternative scoring - you're correct, but not optimal
return 0.1
return 0.0
def simulate_sequence(self, metadata: dict, sequence: list[str]) -> int:
"""Simulate button presses to verify solutions"""

View file

@ -125,8 +125,8 @@ class ShortestPathDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()
if answer is not None and len(answer) > 0:
if isinstance(answer, str) and len(answer) > 0:
oracle_answer = entry["answer"].strip()
answer = answer.strip()
# Exact answer
@ -145,8 +145,6 @@ class ShortestPathDataset(ProceduralDataset):
elif self._is_valid_path(matrix, answer):
return 0.5
return 0.01
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -401,16 +401,14 @@ class CircuitLogicDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if answer is None or len(answer) == 0:
return 0.0
if isinstance(answer, str) and len(answer) > 0:
oracle_answer = entry["answer"]
if oracle_answer == answer:
return 1.0
elif oracle_answer == answer.strip():
return len(oracle_answer) / len(answer)
oracle_answer = entry["answer"]
if oracle_answer == answer:
return 1.0
elif oracle_answer == answer.strip():
return len(oracle_answer) / len(answer)
return 0.01
return 0.0
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)

View file

@ -489,7 +489,7 @@ class KnightsKnavesDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score an answer against the oracle answer."""
if answer is None or len(answer) == 0:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
try:
@ -506,11 +506,9 @@ class KnightsKnavesDataset(ProceduralDataset):
if matching > 0:
return 0.3 + (0.7 * matching / len(oracle_assignments))
return 0.01
except Exception:
# If parsing fails, give minimal credit
return 0.01
pass
return 0.0
register_dataset("knights_knaves", KnightsKnavesDataset, KnightsKnavesConfig)

View file

@ -295,7 +295,7 @@ class PropositionalLogicDataset(ProceduralDataset):
def score_answer(self, answer: str | None, entry: dict[str, Any]) -> float:
"""Robust scoring implementation for propositional logic answers"""
if not answer:
if not isinstance(answer, str):
return 0.0
try:
@ -304,7 +304,7 @@ class PropositionalLogicDataset(ProceduralDataset):
valid_vars = set(entry["metadata"]["variables"])
answer_vars = re.findall(r"([A-Z])", cleaned_answer)
if any(var not in valid_vars for var in answer_vars):
return 0.01
return 0.0
premises = [Expression.from_string(p) for p in entry["metadata"]["premises"]]
answer_expr = Expression.from_string(cleaned_answer)
@ -316,7 +316,7 @@ class PropositionalLogicDataset(ProceduralDataset):
return 1.0
return 0.05
except (ValueError, KeyError, AttributeError):
return 0.01
return 0.0
def _is_trivial(self, expr: Expression) -> bool:
"""Check for trivial tautologies like P ¬P"""

View file

@ -339,9 +339,7 @@ class SelfReferenceDataset(ProceduralDataset):
# Solve puzzle
solutions = solve_puzzle_dynamic(puzzle)
for idx, sol in enumerate(solutions, start=1):
sol_str = ["True" if s else "False" for s in sol]
answer = len(solutions)
answer = str(len(solutions))
return {
"question": puzz_s,
@ -362,12 +360,10 @@ class SelfReferenceDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if str(answer) != str(entry["answer"]):
return 0.1
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer == str(entry["answer"]):
return 1.0 # Yay
return 0.0
register_dataset("self_reference", SelfReferenceDataset, SelfReferenceConfig)

View file

@ -68,12 +68,10 @@ class ZebraDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
return 1.0 # Yay
return 0.0
register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig)

View file

@ -103,7 +103,6 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm
"""
reward = 0.0
if answer is not None and len(answer) > 0:
reward = 0.01
try:
if strip_commas:
answer = answer.replace(",", "")

View file

@ -57,7 +57,7 @@ def test_ab_scoring():
# Test wrong answer
wrong_answer = "A# B#" if item["answer"] != "A# B#" else "B# A#"
assert dataset.score_answer(answer=wrong_answer, entry=item) == 0.01
assert dataset.score_answer(answer=wrong_answer, entry=item) == 0.0
# Test None answer
assert dataset.score_answer(answer=None, entry=item) == 0.0

View file

@ -103,7 +103,7 @@ def test_arc_1d_scoring():
assert dataset.score_answer(f"The answer is: {entry['answer']}", entry) > 0.5
# Test incorrect answer
assert dataset.score_answer("wrong answer", entry) == 0.01
assert dataset.score_answer("wrong answer", entry) == 0.0
# Test None answer
assert dataset.score_answer(None, entry) == 0.0

View file

@ -110,7 +110,7 @@ def test_arc_agi_scoring():
assert dataset.score_answer(item["answer"], entry=item) == 1.0
# Test invalid format
assert dataset.score_answer("invalid grid format", entry=item) == 0.01
assert dataset.score_answer("invalid grid format", entry=item) == 0.0
# Test None answer
assert dataset.score_answer(None, entry=item) == 0.0

View file

@ -23,7 +23,7 @@ def test_bf():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
assert dataset.score_answer(answer="Love is a battlefield", entry=item) == 0.01
assert dataset.score_answer(answer="Love is a battlefield", entry=item) == 0.0
# Medium
config = BFConfig(seed=43, size=20, difficulty=2)

View file

@ -115,7 +115,7 @@ def test_binary_matrix_answer():
# Answer is a python list (partially correct answer)
answer = "[[0, 0, 0], [0, 1, 0], [1, 2, 1]]"
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
assert dataset.score_answer(answer, entry) == 0.5
assert dataset.score_answer(answer, entry) == 0.1
# Answer is null
answer = None

View file

@ -43,7 +43,7 @@ def test_bitwise_arithmetic_items():
# Test scoring edge cases
assert dataset.score_answer(answer=None, entry=item) == 0.0
assert dataset.score_answer(answer="invalid", entry=item) == 0.01
assert dataset.score_answer(answer="invalid", entry=item) == 0.0
def test_bitwise_arithmetic_difficulty_levels():

View file

@ -58,6 +58,7 @@ def test_complex_arithmetic_scoring():
assert dataset.score_answer("3 + 2i", entry) == 1.0
assert dataset.score_answer("3+2i", entry) == 1.0
assert dataset.score_answer("3.0 + 2.0i", entry) == 1.0
assert dataset.score_answer("((3.0 + 2.0i ) )", entry) == 1.0
# Test answers with small errors (should get high but < 1.0 scores)
print(dataset.score_answer("3.1 + 2i", entry))

View file

@ -36,7 +36,7 @@ def test_reseeding_dataset_iteration():
# Test score_answer forwarding
test_item = next(iter(infinite_dataset))
assert infinite_dataset.score_answer("wrong", test_item) == 0.01
assert infinite_dataset.score_answer("wrong", test_item) == 0.0
assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0

View file

@ -0,0 +1,17 @@
import reasoning_gym
from reasoning_gym.factory import DATASETS
def test_score_answer_consistency():
for dataset_name in DATASETS.keys():
if dataset_name == "composite":
continue
dataset = reasoning_gym.create_dataset(dataset_name, size=10, seed=1234)
for entry in dataset:
assert entry["answer"] is None or isinstance(
entry["answer"], str
), f"{dataset_name} answer must be str, is {type(entry['answer'])}"
if entry["answer"] is not None:
assert (
dataset.score_answer(answer=entry["answer"], entry=entry) == 1.0
), f"inconsistent score_answer {dataset_name}"

View file

@ -242,11 +242,11 @@ def test_decimal_precision_scoring():
assert dataset.score_answer("0.3", {"answer": "0.300"}) == 1.0
# Test incorrect answers
assert dataset.score_answer("1.200000001", {"answer": "1.200"}) == 0.01
assert dataset.score_answer("1.199999999", {"answer": "1.200"}) == 0.01
assert dataset.score_answer("1.200000001", {"answer": "1.200"}) == 0.0
assert dataset.score_answer("1.199999999", {"answer": "1.200"}) == 0.0
# Test invalid inputs
assert dataset.score_answer(None, {"answer": "1.200"}) == 0.0
assert dataset.score_answer("", {"answer": "1.200"}) == 0.0
assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.01
assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.01
assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.0
assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.0

View file

@ -50,7 +50,7 @@ def test_game_of_life_basic_properties():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
assert dataset.score_answer(answer="invalid json", entry=item) == 0.01
assert dataset.score_answer(answer="invalid json", entry=item) == 0.0
config = GameOfLifeConfig(seed=43, size=1, grid_size_x=3, grid_size_y=3, filled_cells=1, simulation_steps=1)
dataset = GameOfLifeDataset(config)

View file

@ -121,16 +121,16 @@ def test_score_answer_cases():
("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
# Incorrect but properly formatted
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.0),
# Malformed expressions
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.0),
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.0),
# Empty answer
("", {"variable": "x", "integrand": "2*x"}, 0.01),
("", {"variable": "x", "integrand": "2*x"}, 0.0),
# Case sensitivity
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.0),
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
# Alternative constant notation
("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),

View file

@ -155,8 +155,8 @@ def test_score_calculation():
# Test invalid answers
assert dataset.score_answer(None, puzzle) == 0.0
assert dataset.score_answer("", puzzle) == 0.01
assert dataset.score_answer("Invalid", puzzle) == 0.01
assert dataset.score_answer("", puzzle) == 0.0
assert dataset.score_answer("Invalid", puzzle) == 0.0
# Test correct answer
assert dataset.score_answer(puzzle["answer"], puzzle) == 1.0

View file

@ -99,8 +99,7 @@ def test_score_answer():
assert dataset.score_answer(correct_answer, problem) == 1.0
assert abs(dataset.score_answer(half_answer, problem) - 0.65) < 1e-10
assert dataset.score_answer(modified_answer, problem) == 1.0
assert dataset.score_answer(wrong_answer, problem) == 0.01
print("flipped")
assert dataset.score_answer(wrong_answer, problem) == 0.0
assert dataset.score_answer(flipped_answer, problem) == 1.0

View file

@ -214,7 +214,7 @@ def test_manipulate_matrix_score_answer():
# incorrect answer
answer = "1 2 3\n4 5 6\n7 8 8"
assert dataset.score_answer(answer, entry) == 0.01
assert dataset.score_answer(answer, entry) == 0.0
# answer is none
answer = None

View file

@ -117,7 +117,7 @@ def test_nqueens_score_answer():
# Test invalid answer gets score 0.01
invalid_answer = "_ _ _ _\n_ _ _ _\n_ _ _ _\n_ _ _ _"
assert dataset.score_answer(invalid_answer, item) == 0.01
assert dataset.score_answer(invalid_answer, item) == 0.0
# Test None answer gets score 0.0
assert dataset.score_answer(None, item) == 0.0

View file

@ -16,7 +16,7 @@ def test_needle_haystack():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.01
assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
config = NeedleHaystackConfig(seed=42, size=1, num_statements=500)

View file

@ -110,7 +110,7 @@ def test_number_format_answer():
# Incorrect answer (diff larger than 1e-2)
model_answer = "54245.9"
assert dataset.score_answer(model_answer, entry) == 0.01
assert dataset.score_answer(model_answer, entry) == 0.0
# Answer is null
model_answer = None

View file

@ -84,8 +84,8 @@ def test_score_answer():
wrong_letters = "abcd" if "abcd" != correct_answer else "efgh"
assert dataset.score_answer(wrong_letters, entry=item) == 0.02
# Empty String input should score 0.01
assert dataset.score_answer("", entry=item) == 0.01
# Empty String input should score 0.0
assert dataset.score_answer("", entry=item) == 0.0
# Empty input should score 0.0
assert dataset.score_answer(None, entry=item) == 0.0

View file

@ -95,17 +95,17 @@ def test_palindrome_partitioning_score_answer():
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
assert dataset.score_answer(answer, item) == 1
# Verify the score is 0.01 when incorrect
# Verify the score is 0.0 when incorrect
answer = json.dumps([["n", "o", "o", "n"], ["no", "on"]])
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
assert dataset.score_answer(answer, item) == 0.01
assert dataset.score_answer(answer, item) == 0.0
# Verify the score is 0 when answer is None
# Verify the score is 0.0 when answer is None
answer = None
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
assert dataset.score_answer(answer, item) == 0
assert dataset.score_answer(answer, item) == 0.0
# Verify the score is 0 when answer is malformed JSON
# Verify the score is 0.0 when answer is malformed JSON
answer = '["n", "o", "o", "n"], ["no", "on"], ["noon"]'
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
assert dataset.score_answer(answer, item) == 0
assert dataset.score_answer(answer, item) == 0.0

View file

@ -122,11 +122,11 @@ def test_polynomial_solutions_evaluation():
"oracle_answer, predicted_answer, expected_reward",
[
("4,-4.12", "4,-4.12", 1.0), # Exact match
("4,-4.12", "4.0001,-4.120001", approx(0.9999, rel=1e-3)), # Very close match
("4,-4.12", "4.1,-4.2", approx(0.9139, rel=1e-3)),
("4,8", "4", approx(0.9, rel=1e-3)), # Missing an oracle solution -> missing solution penalty applies
("4", "4,8", approx(0.95, rel=1e-3)), # extra solution -> extra solution penalty
("-1,-2", "1,4", approx(0.06890, rel=1e-3)), # -1 matched w/ 1 and -2 matched w/ 4
("4,-4.12", "4.0001,-4.120001", approx(0.9994, rel=1e-3)), # Very close match
("4,-4.12", "4.1,-4.2", approx(0.4086, rel=1e-3)),
("4,8", "4", approx(0.5, rel=1e-3)), # Missing an oracle solution -> missing solution penalty applies
("4", "4,8", approx(0.5, rel=1e-3)), # extra solution -> extra solution penalty
("-1,-2", "1,4", approx(1.0305e-9, rel=1e-3)), # -1 matched w/ 1 and -2 matched w/ 4
("", "1", approx(0, rel=1e-4)), # oracle no solution, predicted extra solution
("1", "", approx(0, rel=1e-4)), # oracle has a solution, predicted no solution
],

View file

@ -92,7 +92,6 @@ def test_polynomial_equations_dataset_items():
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
@ -127,42 +126,6 @@ def test_cross_polynomial_equations_dataset_items():
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
poly_str = item["metadata"]["polynomial_expr"]
# Ensure it can parse with sympy
sp.sympify(poly_str)
def test_cross_polynomial_equations_dataset_items():
"""Test that generated items have correct structure"""
ds = create_dataset(
"polynomial_multiplication",
min_terms=2,
max_terms=3,
min_value=1,
max_value=5,
min_degree=1,
max_degree=2,
min_polynomials=2,
max_polynomials=5,
variables=tuple("xyz"),
allow_cross_variable_product=True,
allow_multivariate_polynomials=False,
size=3,
seed=100,
)
for item in ds:
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
@ -197,7 +160,6 @@ def test_multivariate_polynomial_equations_dataset_items():
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
@ -242,7 +204,7 @@ def test_polynomial_solutions_evaluation():
poly_expr = sp.expand(poly_str)
# Verify that each solution satisfies the polynomial
assert poly_expr == item["answer"]
assert str(poly_expr) == item["answer"]
def test_score_function():
@ -266,11 +228,11 @@ def test_score_function():
for item in ds:
poly_str = item["metadata"]["polynomial_expr"]
assert ds.score_answer(poly_str, item) == 0.05
assert ds.score_answer(poly_str, item) == 0.0
poly_expr = str(sp.expand(poly_str))
assert ds.score_answer(poly_expr, item) == 1.0
assert ds.score_answer(None, item) == 0.00
assert ds.score_answer("Not a polynomial", item) == 0.01
assert ds.score_answer("x**4", item) == 0.05
assert ds.score_answer(None, item) == 0.0
assert ds.score_answer("Not a polynomial", item) == 0.0
assert ds.score_answer("x**4", item) == 0.0

View file

@ -143,7 +143,7 @@ def test_pool_matrix_score_answer():
dataset = PoolMatrixDataset(config)
for entry in dataset:
assert dataset.score_answer(entry["answer"], entry=entry) == 1
assert 0.0 < dataset.score_answer("1 2.0\n3.0 4", entry=entry) <= 0.1
assert dataset.score_answer("1 2.0\n3.0 4", entry=entry) in [0.0, 0.1]
assert dataset.score_answer("one two three", entry=entry) == 0.0
assert dataset.score_answer("", entry=entry) == 0.0
assert dataset.score_answer(None, entry=entry) == 0.0

View file

@ -71,7 +71,7 @@ def test_power_function_score_function():
# Answer is far from solution
answer = str(item["metadata"]["solution"] - 1)
assert dataset.score_answer(answer, item) == 0.01
assert dataset.score_answer(answer, item) == 0.0
# Answer is None
answer = None

View file

@ -139,7 +139,7 @@ def test_products_scoring():
assert dataset.score_answer(item["answer"], item) == 1.0, "Exact match should score 1.0"
# Test scoring with wrong answer
assert dataset.score_answer("wrong", item) == 0.01, "Wrong answer should score 0.01"
assert dataset.score_answer("wrong", item) == 0.0, "Wrong answer should score 0.0"
# Test scoring with partial match (answer contained in response)
assert (

View file

@ -100,4 +100,4 @@ def test_propositional_logic_dataset_score_answer_incorrect():
dataset = PropositionalLogicDataset(PropositionalLogicConfig(size=100, seed=101))
for i, item in enumerate(dataset):
score = dataset.score_answer("Wrong", item)
assert score == 0.01
assert score == 0.0

View file

@ -43,7 +43,8 @@ def test_quantumlock_items():
assert "target_value" in item["metadata"]
# Verify solution works
assert dataset.score_answer(answer=item["metadata"]["solution_path"], entry=item) == 1.0
answer = "".join(item["metadata"]["solution_path"])
assert dataset.score_answer(answer=answer, entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
@ -98,17 +99,17 @@ def test_quantumlock_scoring():
dataset = QuantumLockDataset(config)
for item in dataset:
solution = item["metadata"]["solution_path"]
solution = item["answer"]
# Test correct solution
assert dataset.score_answer(solution, item) == 1.0
# Test empty/None answers
assert dataset.score_answer(None, item) == 0.0
assert dataset.score_answer("", item) == 0.1
assert dataset.score_answer("", item) == 0.0
# Test invalid buttons
assert dataset.score_answer("XYZ", item) == 0.1
assert dataset.score_answer("XYZ", item) == 0.0
# Test case insensitivity
if solution:

View file

@ -86,7 +86,7 @@ def test_group_anagrams_dataset_items():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer="gibberish", entry=item) == 0.01
assert dataset.score_answer(answer="gibberish", entry=item) == 0.0
assert dataset.score_answer(answer=None, entry=item) == 0.0

View file

@ -80,7 +80,7 @@ def test_rearc_scoring_edge_cases():
assert 0.0 < dataset.score_answer(partial, entry=item) < 1.0
# Malformed answer
assert dataset.score_answer("[[invalid", entry=item) == 0.01
assert dataset.score_answer("[[invalid", entry=item) == 0.0
# Case sensitivity
answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower()

View file

@ -18,8 +18,8 @@ def test_self_reference():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=99, entry=item) == 0.1
assert dataset.score_answer(answer="99", entry=item) == 0.1
assert dataset.score_answer(answer=99, entry=item) == 0.0
assert dataset.score_answer(answer="99", entry=item) == 0.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
# # Medium
@ -34,8 +34,8 @@ def test_self_reference():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=99, entry=item) == 0.1
assert dataset.score_answer(answer="99", entry=item) == 0.1
assert dataset.score_answer(answer=99, entry=item) == 0.0
assert dataset.score_answer(answer="99", entry=item) == 0.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
# # Hard
@ -50,6 +50,6 @@ def test_self_reference():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=99, entry=item) == 0.1
assert dataset.score_answer(answer="99", entry=item) == 0.1
assert dataset.score_answer(answer=99, entry=item) == 0.0
assert dataset.score_answer(answer="99", entry=item) == 0.0
assert dataset.score_answer(answer=None, entry=item) == 0.0

View file

@ -164,7 +164,7 @@ def test_shortest_path_answer():
]
},
}
assert dataset.score_answer("right down right down", entry) == 0.01
assert dataset.score_answer("right down right down", entry) == 0.0
# Answer is None
entry = {

View file

@ -94,16 +94,16 @@ def test_score_answer_cases():
("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
# Incorrect but properly formatted
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.0),
# Malformed expressions
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.0),
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.0),
# Empty answer
("", {"variable": "x", "integrand": "2*x"}, 0.01),
("", {"variable": "x", "integrand": "2*x"}, 0.0),
# Case sensitivity
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.0),
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
# Alternative constant notation
("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),

View file

@ -6,6 +6,10 @@ from reasoning_gym.games.sokoban import SokobanConfig, SokobanDataset
def test_sokoban():
"""Test basic properties and solution of generated items"""
dataset = SokobanDataset(SokobanConfig(size=10, seed=1234))
for i, item in enumerate(dataset):
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
# Easy
config = SokobanConfig(seed=42, size=20)
dataset = SokobanDataset(config)
@ -18,25 +22,29 @@ def test_sokoban():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer="RU", entry=item) == 0.1
assert dataset.score_answer(answer=None, entry=item) == 0.0
# Medium
config = SokobanConfig(seed=42, min_h=40, max_h=50, min_w=40, max_w=50, min_boxes=20, max_boxes=30, size=3)
dataset = SokobanDataset(config)
for item in dataset:
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer="RU", entry=item) == 0.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
# Hard
config = SokobanConfig(seed=42, min_h=400, max_h=500, min_w=400, max_w=500, min_boxes=50, max_boxes=50, size=1)
config = SokobanConfig(
seed=42, min_h=15, max_h=20, min_w=15, max_w=20, min_boxes=10, max_boxes=15, size=3, max_depth=90
)
dataset = SokobanDataset(config)
for item in dataset:
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
# min == max ranges
config = SokobanConfig(
seed=42, min_h=11, max_h=11, min_w=11, max_w=11, min_boxes=11, max_boxes=11, size=3, max_depth=60
)
dataset = SokobanDataset(config)
for item in dataset:

Some files were not shown because too many files have changed in this diff Show more