mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
Minor question template & score_answer improvements (#261)
* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
This commit is contained in:
parent
061282e373
commit
5d7fbac0ad
106 changed files with 403 additions and 507 deletions
|
|
@ -103,6 +103,10 @@ class ComplexArithmeticDataset(ProceduralDataset):
|
|||
# Normalize the answer string by removing spaces and converting to lowercase
|
||||
answer = answer.replace(" ", "").lower()
|
||||
|
||||
# remove brackets
|
||||
while len(answer) > 1 and answer[0] == "(" and answer[-1] == ")":
|
||||
answer = answer[1:-1]
|
||||
|
||||
# Convert mathematical notation 'i' to Python's 'j' for complex numbers
|
||||
answer = answer.replace("i", "j")
|
||||
|
||||
|
|
|
|||
|
|
@ -77,9 +77,10 @@ class IntermediateIntegrationDataset(ProceduralDataset):
|
|||
"Evaluate the indefinite integral: ∫ {integrand} dx",
|
||||
]
|
||||
self.added_instruction = """
|
||||
In addition, when doing calculation, use the following instructions together with your mathematical ingenuity to solve the integral problems
|
||||
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
|
||||
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
|
||||
When performing calculations, please follow these guidelines:
|
||||
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
|
||||
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
|
||||
3. Use `exp(x)` or `E**(x)` for the exponential function (i.e. use capital E for Euler's number).
|
||||
"""
|
||||
|
||||
def _get_outer_constant(self, rng: random.Random) -> int:
|
||||
|
|
@ -245,7 +246,7 @@ In addition, when doing calculation, use the following instructions together wit
|
|||
"""Determine if the solution provided solves the problem"""
|
||||
reward = 0.0
|
||||
metadata = entry["metadata"]
|
||||
if answer is not None:
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
var = metadata["variable"]
|
||||
x = sympy.Symbol(var)
|
||||
|
|
@ -258,12 +259,8 @@ In addition, when doing calculation, use the following instructions together wit
|
|||
# Check mathematical equivalence through simplification
|
||||
if sympy.simplify(derivative - integrand) == 0:
|
||||
reward = 1.0
|
||||
elif answer.strip():
|
||||
reward = 0.05
|
||||
else:
|
||||
reward = 0.01
|
||||
except:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -27,8 +27,9 @@ class PolynomialEquationsConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
# reward function hyperparameters
|
||||
penalty_missing_factor = 0.1
|
||||
penalty_extra_factor = 0.05
|
||||
penalty_missing_factor = 0.5
|
||||
penalty_extra_factor = 0.5
|
||||
exp_distance_factor = -10.0
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
|
|
@ -62,12 +63,15 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
|
||||
]
|
||||
self.added_instruction = """
|
||||
In solving the equations, please abide by the following instruction:
|
||||
## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc.
|
||||
## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005".
|
||||
## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "".
|
||||
## 4. If there are 2 answers, resolve the answers as comma-separated floats of 2 numbers, if 3 answers, make it comma-separated floats of 3 numbers.
|
||||
## 5. Resolve all numbers as floats in the string of comma-separated numbers. Round the floats higher than 4 decimal place(d.p) down to 4 d.p.
|
||||
In solving equations, please follow these instructions:
|
||||
1. Provide all answers as comma-separated decimal values. For example: "-0.3773, 0.4005"
|
||||
2. For solutions that can be expressed in exact form (like "u = 2 + sqrt(4560)/172" and "u = 2 - sqrt(4560)/172"), convert them to decimal form in your final answer.
|
||||
3. If there are no real values that satisfy the equation, report your answer as an empty string: ""
|
||||
4. Format your answer based on the number of solutions:
|
||||
- For 1 solution: a single decimal number
|
||||
- For 2 solutions: two comma-separated decimal numbers
|
||||
- For 3 or more solutions: all values as comma-separated decimal numbers
|
||||
5. Round all decimal values to 4 decimal places (rounding down when the 5th decimal place is 5 or greater).
|
||||
"""
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
|
|
@ -238,7 +242,7 @@ In solving the equations, please abide by the following instruction:
|
|||
# Remove matched oracle solution
|
||||
oracle_solutions.pop(matched_distance_index)
|
||||
# Exponential decay reward
|
||||
total_reward += math.exp(-matched_distance)
|
||||
total_reward += math.exp(matched_distance * self.config.exp_distance_factor)
|
||||
else:
|
||||
# Extra predicted solution
|
||||
extra_solutions += 1
|
||||
|
|
|
|||
|
|
@ -69,9 +69,9 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
|||
"Calculate the following: {polynomial_expr}",
|
||||
]
|
||||
self.added_instruction = """
|
||||
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
|
||||
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
|
||||
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers.
|
||||
When performing calculations, please follow these guidelines:
|
||||
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
|
||||
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
|
||||
"""
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
|
|
@ -106,10 +106,9 @@ In addition, When doing calculation, Use the following instructions together wit
|
|||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": product,
|
||||
"answer": str(product),
|
||||
"metadata": {
|
||||
"polynomial_expr": str(polynomial_expr),
|
||||
"result": str(product),
|
||||
"variables": list(product.free_symbols),
|
||||
},
|
||||
}
|
||||
|
|
@ -147,21 +146,16 @@ In addition, When doing calculation, Use the following instructions together wit
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
metadata = entry["metadata"]
|
||||
if answer is not None:
|
||||
try:
|
||||
predicted_poly = sp.parse_expr(answer)
|
||||
target_poly = sp.parse_expr(metadata["result"])
|
||||
target_poly = sp.parse_expr(entry["answer"])
|
||||
|
||||
# Check whether the predicted and target expressions compare equal.
|
||||
if predicted_poly == target_poly:
|
||||
reward = 1.0
|
||||
elif answer.strip():
|
||||
reward = 0.05
|
||||
else:
|
||||
reward = 0.01
|
||||
except Exception:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -42,9 +42,9 @@ class SimpleIntegrationDataset(ProceduralDataset):
|
|||
"Evaluate the indefinite integral: ∫ {integrand} dx",
|
||||
]
|
||||
self.added_instruction = """
|
||||
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
|
||||
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
|
||||
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
|
||||
When performing calculations, please follow these guidelines:
|
||||
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
|
||||
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
|
||||
"""
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
|
|
@ -103,12 +103,8 @@ In addition, When doing calculation, Use the following instructions together wit
|
|||
# Check mathematical equivalence through simplification
|
||||
if sympy.simplify(derivative - integrand) == 0:
|
||||
reward = 1.0
|
||||
elif answer.strip():
|
||||
reward = 0.05
|
||||
else:
|
||||
reward = 0.01
|
||||
except:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -130,12 +130,9 @@ Return the final state of the program.
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if answer != entry["answer"]:
|
||||
return 0.01
|
||||
else:
|
||||
if answer == entry["answer"]:
|
||||
return 1.0 # Yay
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset
|
||||
|
|
|
|||
|
|
@ -108,9 +108,9 @@ class BinaryMatrixDataset(ProceduralDataset):
|
|||
# check if answer is python list of lists
|
||||
answer = self._matrix_to_str(eval(answer))
|
||||
if answer == oracle_answer:
|
||||
return 0.5
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.1
|
||||
except Exception:
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
correct_mapping = {}
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ class GameOfLifeDataset(ProceduralDataset):
|
|||
ans_arr = json.loads(answer)
|
||||
correct_arr = json.loads(entry["answer"])
|
||||
except Exception:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
total_cells = 0
|
||||
correct_cells = 0
|
||||
|
|
|
|||
|
|
@ -228,12 +228,13 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
|
|||
try:
|
||||
danswer = json.loads(answer)
|
||||
solved, failure = verify_graph_coloring_solution(entry["metadata"]["puzzle"], danswer)
|
||||
if not solved:
|
||||
return 0.01 # json was parsable but solution incorrect
|
||||
else:
|
||||
if solved:
|
||||
return 1.0 # Yay
|
||||
else:
|
||||
return 0.01 # json parsable
|
||||
except Exception:
|
||||
return 0.0
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("graph_color", GraphColorDataset, GraphColorConfig)
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class GroupAnagramsDataset(ProceduralDataset):
|
|||
if answer_str == oracle_str:
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.01
|
||||
reward = 0.01 # json parsable
|
||||
except Exception:
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
|
|
|||
|
|
@ -303,11 +303,11 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
|
|||
danswer = json.loads(answer)
|
||||
valid, _ = verify_solution(entry["metadata"]["puzzle"], danswer)
|
||||
if not valid:
|
||||
return 0.01
|
||||
return 0.01 # json parsable
|
||||
else:
|
||||
return 1.0 # Yay
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("jugs", JugsDataset, JugsConfig)
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
|
||||
# Each word in the expected answer is worth an equal fraction of 1.0
|
||||
total_words = len(expected_words)
|
||||
score_per_word = 1.0 / total_words if total_words else 0
|
||||
score_per_word = 1.0 / total_words if total_words > 0 else 0
|
||||
|
||||
# Calculate scores word by word
|
||||
scores = []
|
||||
|
|
@ -142,18 +142,16 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"].strip().lower()
|
||||
if answer:
|
||||
answer = answer.strip().lower()
|
||||
if answer == oracle_answer:
|
||||
return 1.0 # Perfect score!
|
||||
else:
|
||||
partial_score = self.partial(oracle_answer, answer)
|
||||
return partial_score
|
||||
return 0.01
|
||||
answer = answer.strip().lower()
|
||||
if answer == oracle_answer:
|
||||
return 1.0 # Perfect score!
|
||||
else:
|
||||
partial_score = self.partial(oracle_answer, answer)
|
||||
return partial_score
|
||||
|
||||
|
||||
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)
|
||||
|
|
|
|||
|
|
@ -144,8 +144,6 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
|
||||
if oracle_answer in answer:
|
||||
return len(oracle_answer) / len(answer)
|
||||
else:
|
||||
return 0.01
|
||||
|
||||
return 0.0
|
||||
|
||||
|
|
|
|||
|
|
@ -92,14 +92,14 @@ class PalindromeDataset(ProceduralDataset):
|
|||
- Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0
|
||||
- An answer that is a palindrome, but not with the same letters as provided, gives 0.05
|
||||
- An answer that is a string, but not a palindrome gives 0.02
|
||||
- An empty string gives 0.01.
|
||||
- An empty string gives 0.0
|
||||
- None gives 0.0.
|
||||
"""
|
||||
if answer is None or not isinstance(answer, str):
|
||||
return 0.0 # No answer given
|
||||
|
||||
if answer == "":
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
metadata = entry["metadata"]
|
||||
answer = answer.strip().lower()
|
||||
|
|
|
|||
|
|
@ -95,9 +95,8 @@ class PalindromePartitioningDataset(ProceduralDataset):
|
|||
oracle = self.to_set_of_tuples(entry["metadata"]["solution"])
|
||||
if answer == oracle:
|
||||
return 1.0
|
||||
return 0.01
|
||||
except Exception:
|
||||
return 0.0
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score the answer based on the metadata"""
|
||||
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
reward = 0.0
|
||||
|
|
@ -91,8 +91,6 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
reward = 1.0
|
||||
elif oracle_answer.shape == answer.shape:
|
||||
reward = 0.1
|
||||
else:
|
||||
reward = 0.01
|
||||
except Exception:
|
||||
pass
|
||||
return reward
|
||||
|
|
|
|||
|
|
@ -108,14 +108,12 @@ class RansomNoteDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if isinstance(answer, str):
|
||||
s_answer = answer.strip()
|
||||
if s_answer == str(entry["answer"]):
|
||||
return 1.0
|
||||
|
||||
s_answer = answer.strip()
|
||||
if not s_answer == str(entry["answer"]):
|
||||
return 0.01
|
||||
else:
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig)
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class SentenceReorderingDataset(ProceduralDataset):
|
|||
else:
|
||||
reward = 0.05
|
||||
except:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -52,14 +52,14 @@ class SpellBackwardDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
expected_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
if expected_answer.lower() == answer.lower():
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.05
|
||||
except:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -126,11 +126,9 @@ class SpiralMatrixDataset(ProceduralDataset):
|
|||
try:
|
||||
answer = " ".join(str(item) for item in eval(answer))
|
||||
if answer == oracle_answer:
|
||||
return 0.5
|
||||
else:
|
||||
return 0.01
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return 0.0
|
||||
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class StringInsertionDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
if isinstance(answer, str):
|
||||
if answer == oracle_answer:
|
||||
return 1.0
|
||||
else:
|
||||
|
|
@ -83,9 +83,9 @@ class StringInsertionDataset(ProceduralDataset):
|
|||
# check if answer is python list of characters
|
||||
answer = "".join(eval(answer))
|
||||
if answer == oracle_answer:
|
||||
return 0.5
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.1
|
||||
except Exception:
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -221,8 +221,8 @@ class WordLadderDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if answer is None:
|
||||
return 0
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
answer_words = tuple(s.strip() for s in answer.upper().split(","))
|
||||
|
||||
|
|
@ -239,17 +239,17 @@ class WordLadderDataset(ProceduralDataset):
|
|||
# 4. all words are in our vocabulary
|
||||
|
||||
if len(answer_words) < 2:
|
||||
return 0
|
||||
return 0.0
|
||||
|
||||
if answer_words[0] != start_word or answer_words[-1] != end_word:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
if not all(len(w) == word_length for w in answer_words):
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
for i in range(1, len(answer_words)):
|
||||
if sum(1 for a, b in zip(answer_words[i - 1], answer_words[i]) if a != b) != 1:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
reward = 1.0
|
||||
for word in answer_words:
|
||||
|
|
|
|||
|
|
@ -121,8 +121,6 @@ class WordSortingDataset(ProceduralDataset):
|
|||
return 1.0
|
||||
elif sorted(parsed_answer) == oracle_answer:
|
||||
return 0.2
|
||||
else:
|
||||
return 0.01
|
||||
|
||||
return 0.0
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
else:
|
||||
reward = 0.05
|
||||
except:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ class ReArcDataset(ProceduralDataset):
|
|||
else:
|
||||
reward = 0.05
|
||||
except:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -160,17 +160,15 @@ class BitwiseArithmeticDataset(ProceduralDataset):
|
|||
Returns:
|
||||
float: 1.0 if the user's answer is correct; otherwise 0.0 (including when no answer is provided).
|
||||
"""
|
||||
if answer is None:
|
||||
return 0.0
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
solved = verify_solution(entry["metadata"]["problem"], answer)
|
||||
if solved:
|
||||
return 1.0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
solved = verify_solution(entry["metadata"]["problem"], answer)
|
||||
if solved:
|
||||
return 1.0
|
||||
except Exception:
|
||||
return 0.01
|
||||
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset with the factory.
|
||||
|
|
|
|||
|
|
@ -428,7 +428,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
# we suppose the answer is the last occurrence of the expected answer type
|
||||
if answer is None:
|
||||
if not isinstance(answer, str) or len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"]
|
||||
|
|
@ -439,9 +439,6 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
|
||||
CalendarTask.WEEKDAY_OF_DATE.value,
|
||||
}:
|
||||
if not answer:
|
||||
return 0.0
|
||||
|
||||
answer = answer.strip()
|
||||
oracle_answer = oracle_answer
|
||||
weekdays = {d.name.title() for d in Weekday}
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
+ problem_str
|
||||
)
|
||||
|
||||
return {"question": problem_str, "answer": answer, "metadata": {}}
|
||||
return {"question": problem_str, "answer": str(answer), "metadata": {}}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""
|
||||
|
|
@ -189,12 +189,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
Returns:
|
||||
float: 1.0 if the user's answer is within tolerance; otherwise, 0.0.
|
||||
"""
|
||||
if answer is None:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
user_ans: Decimal = Decimal(answer)
|
||||
correct_ans: Decimal = entry["answer"]
|
||||
correct_ans: Decimal = Decimal(entry["answer"])
|
||||
|
||||
# Determine tolerance based on the desired precision.
|
||||
precision: int = self.config.max_num_decimal_places
|
||||
|
|
@ -202,9 +202,9 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
if abs(user_ans - correct_ans) <= tol:
|
||||
return 1.0
|
||||
except Exception:
|
||||
return 0.01
|
||||
pass
|
||||
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset with the factory.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
|
@ -129,7 +129,11 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
result -= c
|
||||
|
||||
expression = " ".join(expression_parts)
|
||||
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
|
||||
try:
|
||||
q = Decimal(f"0.{'0' * max(decimal_places)}")
|
||||
result = result.quantize(q)
|
||||
except InvalidOperation:
|
||||
pass
|
||||
return expression, result
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
|
|
@ -141,16 +145,19 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
Returns:
|
||||
1.0 for exact numerical match, 0.0 otherwise
|
||||
"""
|
||||
if answer is None or len(answer.strip()) == 0:
|
||||
if not isinstance(answer, str) or len(answer.strip()) == 0:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
student_answer = Decimal(answer.strip())
|
||||
oracle_answer = Decimal(entry["answer"])
|
||||
|
||||
return 1.0 if student_answer == oracle_answer else 0.01
|
||||
except (ValueError, TypeError, ArithmeticError):
|
||||
return 0.01
|
||||
if student_answer == oracle_answer:
|
||||
return 1.0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig)
|
||||
|
|
|
|||
|
|
@ -138,12 +138,11 @@ class DiceDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
|
||||
return 0.01
|
||||
else:
|
||||
return 1.0 # Yay
|
||||
if isinstance(answer, str):
|
||||
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
|
||||
return 1.0 # Yay
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("dice", DiceDataset, DiceConfig)
|
||||
|
|
|
|||
|
|
@ -65,14 +65,13 @@ class NumberFormatDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["metadata"]["solution"]
|
||||
if answer is not None and len(answer) > 0:
|
||||
if isinstance(answer, str) and len(answer) > 0:
|
||||
try:
|
||||
answer = float(answer.strip().replace(",", ""))
|
||||
if abs(answer - oracle_answer) < 1e-2:
|
||||
return 1.0
|
||||
return 0.01
|
||||
except:
|
||||
return 0.0
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -44,10 +44,8 @@ class PowerFunctionDataset(ProceduralDataset):
|
|||
return 1.0
|
||||
elif difference < 1e-1:
|
||||
return 0.5
|
||||
else:
|
||||
return 0.01
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
except Exception:
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ class TimeIntervalsDataset(ProceduralDataset):
|
|||
Returns a score between 0 and 1, with partial credit for answers that are
|
||||
close to correct in the appropriate units/format
|
||||
"""
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
expected = entry["answer"]
|
||||
|
|
|
|||
|
|
@ -121,20 +121,23 @@ int main() {{
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
if answer != entry["answer"]:
|
||||
if entry["answer"] in answer.splitlines():
|
||||
# We can be quite confident that the correct answer was given
|
||||
# It was likely just given alongside an explanation
|
||||
return max(0.9 * len(answer) / len(entry["answer"]), 0.1)
|
||||
if entry["answer"] in answer:
|
||||
# Since answers are English words, some risk of the response coincidentally containing the answer
|
||||
return max(0.5 * len(answer) / len(entry["answer"]), 0.1)
|
||||
return 0.01
|
||||
else:
|
||||
|
||||
if answer == entry["answer"]:
|
||||
return 1.0 # Yay
|
||||
|
||||
if entry["answer"] in answer.splitlines():
|
||||
# We can be quite confident that the correct answer was given
|
||||
# It was likely just given alongside an explanation
|
||||
return max(0.9 * len(answer) / len(entry["answer"]), 0.1)
|
||||
|
||||
if entry["answer"] in answer:
|
||||
# Since answers are English words, some risk of the response coincidentally containing the answer
|
||||
return max(0.5 * len(answer) / len(entry["answer"]), 0.1)
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("bf", BFDataset, BFConfig)
|
||||
|
|
|
|||
|
|
@ -182,7 +182,7 @@ class FigletFontDataset(ProceduralDataset):
|
|||
"""
|
||||
|
||||
correct_word = entry["answer"]
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0 # No answer given
|
||||
|
||||
# Normalize case
|
||||
|
|
|
|||
|
|
@ -110,19 +110,17 @@ class NeedleHaystackDataset(ProceduralDataset):
|
|||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
if isinstance(answer, str):
|
||||
correct_word = entry["answer"]
|
||||
|
||||
correct_word = entry["answer"]
|
||||
if not answer:
|
||||
return 0.0 # No answer given
|
||||
# Normalize case
|
||||
answer = answer.replace(" ", "").strip().lower()
|
||||
correct_word = correct_word.strip().lower()
|
||||
|
||||
# Normalize case
|
||||
answer = answer.replace(" ", "").strip().lower()
|
||||
correct_word = correct_word.strip().lower()
|
||||
if answer == correct_word:
|
||||
return 1.0 # Correct!
|
||||
|
||||
if answer == correct_word:
|
||||
return 1.0 # Correct!
|
||||
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset
|
||||
|
|
|
|||
|
|
@ -132,12 +132,10 @@ class RectangleCountDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
|
||||
return 0.01
|
||||
else:
|
||||
return 1.0 # Yay
|
||||
if isinstance(answer, str):
|
||||
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
|
||||
return 1.0 # Yay
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig)
|
||||
|
|
|
|||
|
|
@ -69,9 +69,6 @@ class ProceduralDataset(ABC, Sized, Iterable[dict[str, Any]]):
|
|||
reward = 1.0
|
||||
elif oracle_answer in answer:
|
||||
reward = len(oracle_answer) / len(answer)
|
||||
else:
|
||||
reward = 0.01
|
||||
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +0,0 @@
|
|||
+ + + + + + +
|
||||
+ - * - - - +
|
||||
+ - - - $ - +
|
||||
+ X - - @ - +
|
||||
+ - - - - - +
|
||||
+ $ - + - - +
|
||||
+ + - - - - +
|
||||
+ X @ - $ - +
|
||||
+ + - - - - +
|
||||
+ + + + + + +
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
+ + + + + + +
|
||||
+ * - @ - X +
|
||||
+ + - @ - + +
|
||||
+ X - - - - +
|
||||
+ + + + + + +
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
- - + + + + + +
|
||||
- + + - - - * +
|
||||
+ + - - - + X +
|
||||
+ X - @ - @ @ +
|
||||
+ X X @ - - - +
|
||||
+ + + + + + + +
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
- + + + + + + - - -
|
||||
- + X - - X + - - -
|
||||
+ + - @ @ + + - - -
|
||||
+ - - - - + + - - -
|
||||
+ - @ - - * + + + +
|
||||
+ + - - - - - - X +
|
||||
- + + + + + + + + +
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
- + + + + + + - -
|
||||
+ + X - @ - + + +
|
||||
+ - - - - - - - +
|
||||
+ - @ + + X - @ +
|
||||
+ - - - @ - + - +
|
||||
+ + + * - X - X +
|
||||
- - + + + + + + +
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
- + + + + + + + -
|
||||
+ + - - + - - + +
|
||||
+ - @ - - - @ - +
|
||||
+ - - X * X - - +
|
||||
+ + @ + + - - + +
|
||||
+ - - X - - - + -
|
||||
+ + + + + + + + -
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
- - - + + + + + + + +
|
||||
- - - + - - - - - - +
|
||||
- - + + - - - - @ - +
|
||||
- + + - - + + - + + +
|
||||
+ + - - + - - X - - +
|
||||
+ - - + X @ @ - - + +
|
||||
+ * + X - - - - + + -
|
||||
+ + - - - - - + + - -
|
||||
+ + + + + + + + - - -
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
+ + + + + + + +
|
||||
+ - - @ - X * +
|
||||
+ - @ - - + X +
|
||||
+ X X @ - @ @ +
|
||||
+ X X @ - - - +
|
||||
+ + + + + + + +
|
||||
|
|
@ -13,7 +13,7 @@ from reasoning_gym.games.contrib.sokoban.src.utils import (
|
|||
)
|
||||
|
||||
|
||||
def astar(matrix, player_pos, debug=False, heuristic="manhattan"):
|
||||
def astar(matrix, player_pos, debug: bool = False, heuristic: str = "manhattan", max_depth: int = 100):
|
||||
# print(f'A* - {heuristic.title()} Heuristic')
|
||||
heur = "[A*]" if heuristic == "manhattan" else "[Dijkstra]"
|
||||
shape = matrix.shape
|
||||
|
|
@ -67,15 +67,18 @@ def astar(matrix, player_pos, debug=False, heuristic="manhattan"):
|
|||
return (path + direction[move], depth + 1)
|
||||
if debug:
|
||||
print(f"{heur} Solution Depth: {depth + 1}\n{path + direction[move]}", 20)
|
||||
print(f"{heur} Solution not found!\n")
|
||||
|
||||
if depth > max_depth:
|
||||
break
|
||||
|
||||
if debug:
|
||||
print(f"{heur} Solution Not Found!\nDepth {depth + 1}", 20)
|
||||
|
||||
return (None, -1 if not heap else depth + 1)
|
||||
|
||||
|
||||
def solve_astar(puzzle, visualizer=False, heuristic="manhattan"):
|
||||
def solve_astar(puzzle, visualizer: bool = False, heuristic: str = "manhattan", max_depth: int = 100):
|
||||
matrix = puzzle
|
||||
where = np.where((matrix == "*") | (matrix == "%"))
|
||||
player_pos = where[0][0], where[1][0]
|
||||
return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic)
|
||||
return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic, max_depth=max_depth)
|
||||
|
|
|
|||
|
|
@ -29,8 +29,7 @@ class PuzzleElement:
|
|||
|
||||
|
||||
class Game:
|
||||
def __init__(self, width=19, height=10, level=None, path=None):
|
||||
self.level = level
|
||||
def __init__(self, width=19, height=10, path=None):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.puzzle = np.empty((height, width), dtype=PuzzleElement)
|
||||
|
|
@ -39,7 +38,7 @@ class Game:
|
|||
self.puzzle_size = None
|
||||
self.pad_x = 0
|
||||
self.pad_y = 0
|
||||
self.path = path or f"levels/lvl{level}.dat"
|
||||
self.path = path
|
||||
|
||||
if path:
|
||||
if type(self) == Game:
|
||||
|
|
@ -108,7 +107,7 @@ class Game:
|
|||
|
||||
# Calculate puzzle size and padding
|
||||
self.puzzle_size = (len(data), len(data[0]) if len(data) > 0 else 0)
|
||||
pad_x = (self.width - self.puzzle_size[1] - 2) // 2 # -2 matches original file-based logic
|
||||
pad_x = (self.width - self.puzzle_size[1]) // 2
|
||||
pad_y = (self.height - self.puzzle_size[0]) // 2
|
||||
self.pad_x, self.pad_y = pad_x, pad_y
|
||||
|
||||
|
|
@ -140,15 +139,15 @@ class Game:
|
|||
|
||||
|
||||
class ReverseGame(Game):
|
||||
def __init__(self, rng: Random, width=19, height=10, level=None):
|
||||
super().__init__(width, height, level)
|
||||
def __init__(self, rng: Random, width: int = 19, height: int = 10):
|
||||
super().__init__(width, height)
|
||||
self.rng = rng
|
||||
self.pad_x = 0
|
||||
self.pad_y = 0
|
||||
|
||||
def load_puzzle(self, puzzle):
|
||||
self.puzzle_size = (len(puzzle), len(puzzle[0]) if len(puzzle) > 0 else 0)
|
||||
pad_x = (self.width - len(puzzle[0]) - 2) // 2
|
||||
pad_x = (self.width - len(puzzle[0])) // 2
|
||||
pad_y = (self.height - len(puzzle)) // 2
|
||||
self.pad_x, self.pad_y = pad_x, pad_y
|
||||
for i, row in enumerate(puzzle):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@ from reasoning_gym.games.contrib.sokoban.src.game import Game, ReverseGame
|
|||
|
||||
|
||||
def num_boxes(puzzle_area, min_boxes, max_boxes, min_w, min_h, max_w, max_h):
    """Return how many boxes to place in a puzzle of the given area.

    The count is interpolated linearly over the puzzle area: the smallest
    configured puzzle (``min_w * min_h``) maps to ``min_boxes`` and the
    largest (``max_w * max_h``) maps to ``max_boxes``. The interpolated value
    is truncated to an int and clamped to ``[min_boxes, max_boxes]``.

    Parameters:
        puzzle_area: Area (width * height) of the puzzle being generated.
        min_boxes: Number of boxes for the smallest puzzle.
        max_boxes: Number of boxes for the largest puzzle.
        min_w, min_h: Minimum puzzle width and height.
        max_w, max_h: Maximum puzzle width and height.

    Returns:
        The number of boxes as an int within ``[min_boxes, max_boxes]``.
    """
    # Degenerate ranges: there is nothing to interpolate over, so use the
    # upper bound directly. With min <= max this also rules out a zero
    # denominator in the slope computation below.
    if min_w == max_w or min_h == max_h or min_boxes == max_boxes:
        return max_boxes

    m = (max_boxes - min_boxes) / (max_w * max_h - min_w * min_h)
    b = min_boxes - m * min_w * min_h
    boxes = int(m * puzzle_area + b)

    # int() truncates, so floating-point rounding (e.g. 3.9999999999999996)
    # could otherwise yield one box fewer than min_boxes, and an area outside
    # the configured range would extrapolate past the limits.
    return max(min_boxes, min(max_boxes, boxes))
|
||||
|
|
@ -19,31 +22,33 @@ def random_valid(rng: Random, width: int = 10, height: int = 10):
|
|||
def generate(
|
||||
rng: Random,
|
||||
debug: bool = False,
|
||||
path: str = None,
|
||||
min_w: int = 6,
|
||||
min_h: int = 6,
|
||||
max_w: int = 15,
|
||||
max_h: int = 10,
|
||||
min_boxes: int = 4,
|
||||
max_boxes: int = 10,
|
||||
max_depth: int = 100,
|
||||
path: str = None,
|
||||
) -> tuple[str, str, dict]:
|
||||
"""
|
||||
Generates a level with the given configuration parameters.
|
||||
|
||||
Parameters:
|
||||
rng: Random number generator for reproducibility.
|
||||
visualizer: Whether to visualize the generation process.
|
||||
path: Path to save the level file (default 'levels/lvl0.dat').
|
||||
min_w: Minimum width of the puzzle.
|
||||
min_h: Minimum height of the puzzle.
|
||||
max_w: Maximum width of the puzzle.
|
||||
max_h: Maximum height of the puzzle.
|
||||
min_boxes: Minimum number of boxes.
|
||||
max_boxes: Maximum number of boxes.
|
||||
rng: Random number generator
|
||||
debug: Whether to print debug output (e.g. the solution that was found)
|
||||
min_w: Minimum width of the puzzle
|
||||
min_h: Minimum height of the puzzle
|
||||
max_w: Maximum width of the puzzle
|
||||
max_h: Maximum height of the puzzle
|
||||
min_boxes: Minimum number of boxes
|
||||
max_boxes: Maximum number of boxes
|
||||
max_depth: Maximum search depth
|
||||
path: Path to save the level file (optional)
|
||||
Returns:
|
||||
puzzle_string, solution, difficulty (dict)
|
||||
"""
|
||||
path = path or "levels/lvl0.dat"
|
||||
|
||||
while True:
|
||||
width = rng.randint(min_w, max_w)
|
||||
height = rng.randint(min_h, max_h)
|
||||
|
|
@ -60,7 +65,7 @@ def generate(
|
|||
puzzle[box_pos] = "$"
|
||||
boxes_created += 1
|
||||
boxes_seen.add(box_pos)
|
||||
reverse_game = ReverseGame(rng=rng, level=0)
|
||||
reverse_game = ReverseGame(rng=rng, width=width, height=height)
|
||||
reverse_game.load_puzzle(puzzle)
|
||||
player = reverse_game.player
|
||||
counter = round(height * width * rng.uniform(1.8, 3.6))
|
||||
|
|
@ -79,16 +84,19 @@ def generate(
|
|||
out_of_place_boxes = np.sum([str(x) == "@" for x in matrix.flatten()])
|
||||
if out_of_place_boxes >= boxes // 2:
|
||||
# Optionally save the puzzle to a file:
|
||||
# np.savetxt(path, matrix, fmt='%s')
|
||||
if path:
|
||||
np.savetxt(path, matrix, fmt="%s")
|
||||
puzzle_str = player.puzzle_to_string(matrix)
|
||||
|
||||
grid_list = [list(line) for line in puzzle_str.replace(" ", "").strip().split("\n")]
|
||||
grid_array = np.array(grid_list)
|
||||
solution, _ = solve_astar(grid_array)
|
||||
solution, depth = solve_astar(grid_array, max_depth=max_depth)
|
||||
if solution is None:
|
||||
continue # retry generation
|
||||
|
||||
if debug:
|
||||
print(f"solution={solution}")
|
||||
game = Game()
|
||||
game = Game(width=width, height=height)
|
||||
game.load_puzzle_matrix(grid_array)
|
||||
|
||||
for step, move in enumerate(solution):
|
||||
|
|
|
|||
|
|
@ -618,7 +618,7 @@ class FutoshikiDataset(ProceduralDataset):
|
|||
return grid
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"]
|
||||
|
|
|
|||
|
|
@ -314,36 +314,35 @@ class KnightSwapDataset(ProceduralDataset):
|
|||
- 1.0 for correct answer (either "No" for impossible puzzles or valid solution of optimal length)
|
||||
- A proportional score for correct but longer solutions
|
||||
- 0.05 for valid moves that don't solve the puzzle
|
||||
- 0.01 for invalid format
|
||||
- 0.0 for None
|
||||
- 0.0 for invalid format or None
|
||||
"""
|
||||
if answer is None:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
answer = answer.strip()
|
||||
if not answer:
|
||||
return 0.01
|
||||
if len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
# Handle impossible puzzles
|
||||
if not entry["metadata"]["is_possible"]:
|
||||
return 1.0 if answer.lower() == "no" else 0.01
|
||||
return 1.0 if answer.lower() == "no" else 0.0
|
||||
|
||||
# Handle "No" answer for possible puzzles
|
||||
if answer.lower() == "no":
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
# Parse moves from JSON list
|
||||
move_list = json.loads(answer)
|
||||
if not isinstance(move_list, list):
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
# Parse moves
|
||||
moves = []
|
||||
for move_str in move_list:
|
||||
color, start, end = move_str.split(",")
|
||||
if color not in ("w", "B"):
|
||||
return 0.01
|
||||
return 0.0
|
||||
moves.append((color, start, end))
|
||||
|
||||
# Validate and apply moves
|
||||
|
|
@ -357,13 +356,13 @@ class KnightSwapDataset(ProceduralDataset):
|
|||
|
||||
for color, start, end in moves:
|
||||
if color != current_turn:
|
||||
return 0.01
|
||||
return 0.0
|
||||
if start not in pieces or pieces[start] != color:
|
||||
return 0.01
|
||||
return 0.0
|
||||
if end not in board[start]:
|
||||
return 0.01
|
||||
return 0.0
|
||||
if end in pieces and pieces[end] is not None:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
# Apply move
|
||||
pieces[end] = pieces[start]
|
||||
|
|
@ -390,7 +389,7 @@ class KnightSwapDataset(ProceduralDataset):
|
|||
return 0.05
|
||||
|
||||
except Exception:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("knight_swap", KnightSwapDataset, KnightSwapConfig)
|
||||
|
|
|
|||
|
|
@ -195,7 +195,7 @@ class MiniSudokuDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if not answer:
|
||||
if not isinstance(answer, str) or len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"]
|
||||
|
|
|
|||
|
|
@ -138,8 +138,8 @@ class NQueensDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
valid_solutions = entry["metadata"]["valid_answers"]
|
||||
if answer is not None:
|
||||
if isinstance(answer, str):
|
||||
valid_solutions = entry["metadata"]["valid_answers"]
|
||||
if answer in valid_solutions:
|
||||
return 1.0
|
||||
try:
|
||||
|
|
@ -147,7 +147,7 @@ class NQueensDataset(ProceduralDataset):
|
|||
if answer in valid_solutions:
|
||||
return 0.5
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class RushHourDataset(ProceduralDataset):
|
|||
Returns:
|
||||
1.0 if solution reaches goal state, 0.0 otherwise
|
||||
"""
|
||||
if not answer:
|
||||
if not isinstance(answer, str) or len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -11,20 +11,26 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
class SokobanConfig:
|
||||
"""Configuration for sokoban puzzle generation"""
|
||||
|
||||
min_w: int = 6 # Minimum width of the puzzle.
|
||||
min_h: int = 6 # Minimum height of the puzzle.
|
||||
max_w: int = 10 # Maximum width of the puzzle.
|
||||
max_h: int = 10 # Maximum height of the puzzle.
|
||||
min_boxes: int = 6 # Minimum number of boxes.
|
||||
max_boxes: int = 10 # Maximum number of boxes.
|
||||
min_w: int = 6 # Minimum width of the puzzle
|
||||
min_h: int = 6 # Minimum height of the puzzle
|
||||
max_w: int = 10 # Maximum width of the puzzle
|
||||
max_h: int = 10 # Maximum height of the puzzle
|
||||
min_boxes: int = 4 # Minimum number of boxes
|
||||
max_boxes: int = 10 # Maximum number of boxes
|
||||
max_depth: int = 80 # Maximum search depth
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 0 < self.max_w <= 20
|
||||
assert 0 < self.max_h <= 20
|
||||
assert self.min_h > 0
|
||||
assert self.min_w > 0
|
||||
assert self.min_w <= self.max_w, "min_w must be lte max_w"
|
||||
assert self.min_h <= self.max_h, "min_h must be lte max_h"
|
||||
assert self.min_boxes <= self.max_boxes, "min_boxes must be lte max_boxes"
|
||||
assert self.max_depth > 1
|
||||
|
||||
|
||||
class SokobanDataset(ProceduralDataset):
|
||||
|
|
@ -58,7 +64,16 @@ class SokobanDataset(ProceduralDataset):
|
|||
|
||||
# Make the Sokoban!
|
||||
rng = Random(self.seed + idx)
|
||||
gamestr, solution, difficulty = self._generate(rng=rng)
|
||||
gamestr, solution, difficulty = self._generate(
|
||||
rng=rng,
|
||||
min_w=self.config.min_w,
|
||||
min_h=self.config.min_h,
|
||||
max_w=self.config.max_w,
|
||||
max_h=self.config.max_h,
|
||||
min_boxes=self.config.min_boxes,
|
||||
max_boxes=self.config.max_boxes,
|
||||
max_depth=self.config.max_depth,
|
||||
)
|
||||
|
||||
return {
|
||||
"question": """You are going to solve a 'sokoban' puzzle.
|
||||
|
|
@ -93,14 +108,15 @@ Here is your puzzle:
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
grid_list = [list(line) for line in entry["metadata"]["gamestr"].replace(" ", "").strip().split("\n")]
|
||||
matrix = np.array(grid_list)
|
||||
|
||||
game = self._Game()
|
||||
h, w = matrix.shape
|
||||
game = self._Game(height=h, width=w)
|
||||
game.load_puzzle_matrix(matrix)
|
||||
|
||||
for move in answer:
|
||||
|
|
@ -108,10 +124,10 @@ Here is your puzzle:
|
|||
|
||||
if self._is_solved(game.get_curr_state()):
|
||||
return 1.0
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
except:
|
||||
pass
|
||||
|
||||
return 0.1
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("sokoban", SokobanDataset, SokobanConfig)
|
||||
|
|
|
|||
|
|
@ -214,7 +214,7 @@ class SudokuDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if not answer:
|
||||
if not isinstance(answer, str) or len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"]
|
||||
|
|
|
|||
|
|
@ -266,7 +266,7 @@ class HanoiDataset(ProceduralDataset):
|
|||
start_peg=peg_labels[start_peg],
|
||||
target_peg=peg_labels[target_peg],
|
||||
),
|
||||
"answer": solution,
|
||||
"answer": "\n".join(solution),
|
||||
"metadata": {
|
||||
"num_disks": num_disks,
|
||||
"num_pegs": num_pegs,
|
||||
|
|
@ -383,24 +383,14 @@ class HanoiDataset(ProceduralDataset):
|
|||
Expected behavior:
|
||||
- Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0.
|
||||
- A correct solution that is suboptimal length gives a proportional reward of optimal_move_count/user_move_count
|
||||
- A badly formatted answer gives a minimal reward (0.01).
|
||||
- An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05).
|
||||
- An empty string gives 0.01.
|
||||
- None gives 0.0.
|
||||
- A badly formatted or empty answer gives 0.0
|
||||
"""
|
||||
if answer is None:
|
||||
if not isinstance(answer, str) or len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
if answer == "":
|
||||
return 0.01
|
||||
|
||||
# If answer is a string, split it into lines; if it's already a list, use it directly.
|
||||
if isinstance(answer, str):
|
||||
moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]
|
||||
elif isinstance(answer, list):
|
||||
moves = [line.strip() for line in answer if isinstance(line, str) and line.strip()]
|
||||
else:
|
||||
return 0.0
|
||||
# Split the answer string into lines
|
||||
moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]
|
||||
|
||||
# Build the initial peg state from metadata.
|
||||
metadata = entry["metadata"]
|
||||
|
|
@ -418,11 +408,11 @@ class HanoiDataset(ProceduralDataset):
|
|||
try:
|
||||
disk, from_peg, to_peg = self._parse_move(move)
|
||||
except Exception:
|
||||
return 0.01 # Invalid move format
|
||||
return 0.0 # Invalid move format
|
||||
|
||||
# Validate the move using existing _validate_move method.
|
||||
if not self._validate_move(peg_state, move):
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
# Execute the move.
|
||||
peg_state[from_peg].pop()
|
||||
|
|
|
|||
|
|
@ -358,16 +358,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
if answer is not None:
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
answer_formatted = answer.strip().lower()
|
||||
solved = answer_formatted == entry["answer"].strip().lower()
|
||||
if solved:
|
||||
oracle_answer = entry["answer"].strip().lower()
|
||||
if answer_formatted == oracle_answer:
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.01
|
||||
except:
|
||||
reward = 0.01
|
||||
pass
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -169,21 +169,15 @@ Buttons:
|
|||
|
||||
The function awards 1.0 for a correct answer and less otherwise.
|
||||
"""
|
||||
if answer == None:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
# Get correct solution from metadata
|
||||
correct_solution = entry["metadata"].get("solution_path", [])
|
||||
|
||||
# Normalize both answers
|
||||
def normalize_seq(seq):
|
||||
"""Handle both string and list inputs by converting to string first"""
|
||||
# Convert sequence to string representation if it's a list
|
||||
input_str = "".join(seq) if isinstance(seq, list) else str(seq or "")
|
||||
return [c.upper() for c in re.findall(r"[A-C]", input_str.upper())]
|
||||
def normalize_seq(seq: str) -> list[str]:
|
||||
return [c.upper() for c in re.findall(r"[A-C]", seq.upper())]
|
||||
|
||||
user_sequence = normalize_seq(answer)
|
||||
target_sequence = normalize_seq("".join(correct_solution))
|
||||
target_sequence = normalize_seq(entry["answer"])
|
||||
|
||||
# Exact sequence match required
|
||||
if user_sequence == target_sequence:
|
||||
|
|
@ -196,7 +190,7 @@ Buttons:
|
|||
return 1.0 # Different answer, but equally correct
|
||||
return 0.5 # Alternative scoring - you're correct, but not optimal
|
||||
|
||||
return 0.1
|
||||
return 0.0
|
||||
|
||||
def simulate_sequence(self, metadata: dict, sequence: list[str]) -> int:
|
||||
"""Simulate button presses to verify solutions"""
|
||||
|
|
|
|||
|
|
@ -125,8 +125,8 @@ class ShortestPathDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"].strip()
|
||||
if answer is not None and len(answer) > 0:
|
||||
if isinstance(answer, str) and len(answer) > 0:
|
||||
oracle_answer = entry["answer"].strip()
|
||||
answer = answer.strip()
|
||||
|
||||
# Exact answer
|
||||
|
|
@ -145,8 +145,6 @@ class ShortestPathDataset(ProceduralDataset):
|
|||
elif self._is_valid_path(matrix, answer):
|
||||
return 0.5
|
||||
|
||||
return 0.01
|
||||
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -401,16 +401,14 @@ class CircuitLogicDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if answer is None or len(answer) == 0:
|
||||
return 0.0
|
||||
if isinstance(answer, str) and len(answer) > 0:
|
||||
oracle_answer = entry["answer"]
|
||||
if oracle_answer == answer:
|
||||
return 1.0
|
||||
elif oracle_answer == answer.strip():
|
||||
return len(oracle_answer) / len(answer)
|
||||
|
||||
oracle_answer = entry["answer"]
|
||||
if oracle_answer == answer:
|
||||
return 1.0
|
||||
elif oracle_answer == answer.strip():
|
||||
return len(oracle_answer) / len(answer)
|
||||
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)
|
||||
|
|
|
|||
|
|
@ -489,7 +489,7 @@ class KnightsKnavesDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score an answer against the oracle answer."""
|
||||
if answer is None or len(answer) == 0:
|
||||
if not isinstance(answer, str) or len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
|
|
@ -506,11 +506,9 @@ class KnightsKnavesDataset(ProceduralDataset):
|
|||
if matching > 0:
|
||||
return 0.3 + (0.7 * matching / len(oracle_assignments))
|
||||
|
||||
return 0.01
|
||||
|
||||
except Exception:
|
||||
# If parsing fails, give minimal credit
|
||||
return 0.01
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("knights_knaves", KnightsKnavesDataset, KnightsKnavesConfig)
|
||||
|
|
|
|||
|
|
@ -295,7 +295,7 @@ class PropositionalLogicDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: str | None, entry: dict[str, Any]) -> float:
|
||||
"""Robust scoring implementation for propositional logic answers"""
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
|
|
@ -304,7 +304,7 @@ class PropositionalLogicDataset(ProceduralDataset):
|
|||
valid_vars = set(entry["metadata"]["variables"])
|
||||
answer_vars = re.findall(r"([A-Z])", cleaned_answer)
|
||||
if any(var not in valid_vars for var in answer_vars):
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
premises = [Expression.from_string(p) for p in entry["metadata"]["premises"]]
|
||||
answer_expr = Expression.from_string(cleaned_answer)
|
||||
|
|
@ -316,7 +316,7 @@ class PropositionalLogicDataset(ProceduralDataset):
|
|||
return 1.0
|
||||
return 0.05
|
||||
except (ValueError, KeyError, AttributeError):
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
def _is_trivial(self, expr: Expression) -> bool:
|
||||
"""Check for trivial tautologies like P ∨ ¬P"""
|
||||
|
|
|
|||
|
|
@ -339,9 +339,7 @@ class SelfReferenceDataset(ProceduralDataset):
|
|||
|
||||
# Solve puzzle
|
||||
solutions = solve_puzzle_dynamic(puzzle)
|
||||
for idx, sol in enumerate(solutions, start=1):
|
||||
sol_str = ["True" if s else "False" for s in sol]
|
||||
answer = len(solutions)
|
||||
answer = str(len(solutions))
|
||||
|
||||
return {
|
||||
"question": puzz_s,
|
||||
|
|
@ -362,12 +360,10 @@ class SelfReferenceDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if str(answer) != str(entry["answer"]):
|
||||
return 0.1
|
||||
else:
|
||||
return 1.0 # Yay
|
||||
if isinstance(answer, str):
|
||||
if answer == str(entry["answer"]):
|
||||
return 1.0 # Yay
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("self_reference", SelfReferenceDataset, SelfReferenceConfig)
|
||||
|
|
|
|||
|
|
@ -68,12 +68,10 @@ class ZebraDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
|
||||
return 0.01
|
||||
else:
|
||||
return 1.0 # Yay
|
||||
if isinstance(answer, str):
|
||||
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
|
||||
return 1.0 # Yay
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig)
|
||||
|
|
|
|||
|
|
@ -103,7 +103,6 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm
|
|||
"""
|
||||
reward = 0.0
|
||||
if answer is not None and len(answer) > 0:
|
||||
reward = 0.01
|
||||
try:
|
||||
if strip_commas:
|
||||
answer = answer.replace(",", "")
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ def test_ab_scoring():
|
|||
|
||||
# Test wrong answer
|
||||
wrong_answer = "A# B#" if item["answer"] != "A# B#" else "B# A#"
|
||||
assert dataset.score_answer(answer=wrong_answer, entry=item) == 0.01
|
||||
assert dataset.score_answer(answer=wrong_answer, entry=item) == 0.0
|
||||
|
||||
# Test None answer
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ def test_arc_1d_scoring():
|
|||
assert dataset.score_answer(f"The answer is: {entry['answer']}", entry) > 0.5
|
||||
|
||||
# Test incorrect answer
|
||||
assert dataset.score_answer("wrong answer", entry) == 0.01
|
||||
assert dataset.score_answer("wrong answer", entry) == 0.0
|
||||
|
||||
# Test None answer
|
||||
assert dataset.score_answer(None, entry) == 0.0
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ def test_arc_agi_scoring():
|
|||
assert dataset.score_answer(item["answer"], entry=item) == 1.0
|
||||
|
||||
# Test invalid format
|
||||
assert dataset.score_answer("invalid grid format", entry=item) == 0.01
|
||||
assert dataset.score_answer("invalid grid format", entry=item) == 0.0
|
||||
|
||||
# Test None answer
|
||||
assert dataset.score_answer(None, entry=item) == 0.0
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def test_bf():
|
|||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
assert dataset.score_answer(answer="Love is a battlefield", entry=item) == 0.01
|
||||
assert dataset.score_answer(answer="Love is a battlefield", entry=item) == 0.0
|
||||
|
||||
# Medium
|
||||
config = BFConfig(seed=43, size=20, difficulty=2)
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ def test_binary_matrix_answer():
|
|||
# Answer is a python list (partially correct answer)
|
||||
answer = "[[0, 0, 0], [0, 1, 0], [1, 2, 1]]"
|
||||
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
|
||||
assert dataset.score_answer(answer, entry) == 0.5
|
||||
assert dataset.score_answer(answer, entry) == 0.1
|
||||
|
||||
# Answer is null
|
||||
answer = None
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def test_bitwise_arithmetic_items():
|
|||
|
||||
# Test scoring edge cases
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
assert dataset.score_answer(answer="invalid", entry=item) == 0.01
|
||||
assert dataset.score_answer(answer="invalid", entry=item) == 0.0
|
||||
|
||||
|
||||
def test_bitwise_arithmetic_difficulty_levels():
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ def test_complex_arithmetic_scoring():
|
|||
assert dataset.score_answer("3 + 2i", entry) == 1.0
|
||||
assert dataset.score_answer("3+2i", entry) == 1.0
|
||||
assert dataset.score_answer("3.0 + 2.0i", entry) == 1.0
|
||||
assert dataset.score_answer("((3.0 + 2.0i ) )", entry) == 1.0
|
||||
|
||||
# Test answers with small errors (should get high but < 1.0 scores)
|
||||
print(dataset.score_answer("3.1 + 2i", entry))
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def test_reseeding_dataset_iteration():
|
|||
|
||||
# Test score_answer forwarding
|
||||
test_item = next(iter(infinite_dataset))
|
||||
assert infinite_dataset.score_answer("wrong", test_item) == 0.01
|
||||
assert infinite_dataset.score_answer("wrong", test_item) == 0.0
|
||||
assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0
|
||||
|
||||
|
||||
|
|
|
|||
17
tests/test_dataset_common.py
Normal file
17
tests/test_dataset_common.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import reasoning_gym
|
||||
from reasoning_gym.factory import DATASETS
|
||||
|
||||
|
||||
def test_score_answer_consistency():
    """Check that each registered dataset scores its own oracle answer as 1.0.

    For every dataset in ``DATASETS`` except "composite", a small dataset is
    created with a fixed seed. Each entry's ``answer`` must be ``None`` or a
    ``str``; when it is a string, passing it back to ``score_answer`` must
    yield exactly 1.0.
    """
    for dataset_name in DATASETS.keys():
        # The "composite" dataset is excluded from this generic check.
        if dataset_name != "composite":
            dataset = reasoning_gym.create_dataset(dataset_name, size=10, seed=1234)
            for entry in dataset:
                oracle = entry["answer"]
                if oracle is None:
                    # Entries without an oracle answer cannot be self-scored.
                    continue
                assert isinstance(oracle, str), f"{dataset_name} answer must be str, is {type(entry['answer'])}"
                score = dataset.score_answer(answer=oracle, entry=entry)
                assert score == 1.0, f"inconsistent score_answer {dataset_name}"
|
||||
|
|
@ -242,11 +242,11 @@ def test_decimal_precision_scoring():
|
|||
assert dataset.score_answer("0.3", {"answer": "0.300"}) == 1.0
|
||||
|
||||
# Test incorrect answers
|
||||
assert dataset.score_answer("1.200000001", {"answer": "1.200"}) == 0.01
|
||||
assert dataset.score_answer("1.199999999", {"answer": "1.200"}) == 0.01
|
||||
assert dataset.score_answer("1.200000001", {"answer": "1.200"}) == 0.0
|
||||
assert dataset.score_answer("1.199999999", {"answer": "1.200"}) == 0.0
|
||||
|
||||
# Test invalid inputs
|
||||
assert dataset.score_answer(None, {"answer": "1.200"}) == 0.0
|
||||
assert dataset.score_answer("", {"answer": "1.200"}) == 0.0
|
||||
assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.01
|
||||
assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.01
|
||||
assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.0
|
||||
assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.0
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def test_game_of_life_basic_properties():
|
|||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
assert dataset.score_answer(answer="invalid json", entry=item) == 0.01
|
||||
assert dataset.score_answer(answer="invalid json", entry=item) == 0.0
|
||||
|
||||
config = GameOfLifeConfig(seed=43, size=1, grid_size_x=3, grid_size_y=3, filled_cells=1, simulation_steps=1)
|
||||
dataset = GameOfLifeDataset(config)
|
||||
|
|
|
|||
|
|
@ -121,16 +121,16 @@ def test_score_answer_cases():
|
|||
("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
|
||||
("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
|
||||
# Incorrect but properly formatted
|
||||
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
|
||||
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
|
||||
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
|
||||
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.0),
|
||||
# Malformed expressions
|
||||
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
|
||||
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
|
||||
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.0),
|
||||
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.0),
|
||||
# Empty answer
|
||||
("", {"variable": "x", "integrand": "2*x"}, 0.01),
|
||||
("", {"variable": "x", "integrand": "2*x"}, 0.0),
|
||||
# Case sensitivity
|
||||
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
|
||||
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
|
||||
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.0),
|
||||
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
|
||||
# Alternative constant notation
|
||||
("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
|
||||
("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),
|
||||
|
|
|
|||
|
|
@ -155,8 +155,8 @@ def test_score_calculation():
|
|||
|
||||
# Test invalid answers
|
||||
assert dataset.score_answer(None, puzzle) == 0.0
|
||||
assert dataset.score_answer("", puzzle) == 0.01
|
||||
assert dataset.score_answer("Invalid", puzzle) == 0.01
|
||||
assert dataset.score_answer("", puzzle) == 0.0
|
||||
assert dataset.score_answer("Invalid", puzzle) == 0.0
|
||||
|
||||
# Test correct answer
|
||||
assert dataset.score_answer(puzzle["answer"], puzzle) == 1.0
|
||||
|
|
|
|||
|
|
@ -99,8 +99,7 @@ def test_score_answer():
|
|||
assert dataset.score_answer(correct_answer, problem) == 1.0
|
||||
assert abs(dataset.score_answer(half_answer, problem) - 0.65) < 1e-10
|
||||
assert dataset.score_answer(modified_answer, problem) == 1.0
|
||||
assert dataset.score_answer(wrong_answer, problem) == 0.01
|
||||
print("flipped")
|
||||
assert dataset.score_answer(wrong_answer, problem) == 0.0
|
||||
assert dataset.score_answer(flipped_answer, problem) == 1.0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -214,7 +214,7 @@ def test_manipulate_matrix_score_answer():
|
|||
|
||||
# incorrect answer
|
||||
answer = "1 2 3\n4 5 6\n7 8 8"
|
||||
assert dataset.score_answer(answer, entry) == 0.01
|
||||
assert dataset.score_answer(answer, entry) == 0.0
|
||||
|
||||
# answer is none
|
||||
answer = None
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ def test_nqueens_score_answer():
|
|||
|
||||
# Test invalid answer gets score 0.0
|
||||
invalid_answer = "_ _ _ _\n_ _ _ _\n_ _ _ _\n_ _ _ _"
|
||||
assert dataset.score_answer(invalid_answer, item) == 0.01
|
||||
assert dataset.score_answer(invalid_answer, item) == 0.0
|
||||
|
||||
# Test None answer gets score 0.0
|
||||
assert dataset.score_answer(None, item) == 0.0
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def test_needle_haystack():
|
|||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.01
|
||||
assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
||||
config = NeedleHaystackConfig(seed=42, size=1, num_statements=500)
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ def test_number_format_answer():
|
|||
|
||||
# Incorrect answer (diff larger than 1e-2)
|
||||
model_answer = "54245.9"
|
||||
assert dataset.score_answer(model_answer, entry) == 0.01
|
||||
assert dataset.score_answer(model_answer, entry) == 0.0
|
||||
|
||||
# Answer is null
|
||||
model_answer = None
|
||||
|
|
|
|||
|
|
@ -84,8 +84,8 @@ def test_score_answer():
|
|||
wrong_letters = "abcd" if "abcd" != correct_answer else "efgh"
|
||||
assert dataset.score_answer(wrong_letters, entry=item) == 0.02
|
||||
|
||||
# Empty String input should score 0.01
|
||||
assert dataset.score_answer("", entry=item) == 0.01
|
||||
# Empty String input should score 0.0
|
||||
assert dataset.score_answer("", entry=item) == 0.0
|
||||
|
||||
# Empty input should score 0.0
|
||||
assert dataset.score_answer(None, entry=item) == 0.0
|
||||
|
|
|
|||
|
|
@ -95,17 +95,17 @@ def test_palindrome_partitioning_score_answer():
|
|||
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
|
||||
assert dataset.score_answer(answer, item) == 1
|
||||
|
||||
# Verify the score is 0.01 when incorrect
|
||||
# Verify the score is 0.0 when incorrect
|
||||
answer = json.dumps([["n", "o", "o", "n"], ["no", "on"]])
|
||||
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
|
||||
assert dataset.score_answer(answer, item) == 0.01
|
||||
assert dataset.score_answer(answer, item) == 0.0
|
||||
|
||||
# Verify the score is 0 when answer is None
|
||||
# Verify the score is 0.0 when answer is None
|
||||
answer = None
|
||||
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
|
||||
assert dataset.score_answer(answer, item) == 0
|
||||
assert dataset.score_answer(answer, item) == 0.0
|
||||
|
||||
# Verify the score is 0 when answer is malformed JSON
|
||||
# Verify the score is 0.0 when answer is malformed JSON
|
||||
answer = '["n", "o", "o", "n"], ["no", "on"], ["noon"]'
|
||||
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
|
||||
assert dataset.score_answer(answer, item) == 0
|
||||
assert dataset.score_answer(answer, item) == 0.0
|
||||
|
|
|
|||
|
|
@ -122,11 +122,11 @@ def test_polynomial_solutions_evaluation():
|
|||
"oracle_answer, predicted_answer, expected_reward",
|
||||
[
|
||||
("4,-4.12", "4,-4.12", 1.0), # Exact match
|
||||
("4,-4.12", "4.0001,-4.120001", approx(0.9999, rel=1e-3)), # Very close match
|
||||
("4,-4.12", "4.1,-4.2", approx(0.9139, rel=1e-3)),
|
||||
("4,8", "4", approx(0.9, rel=1e-3)), # Missing an oracle solution -> missing solution penalty applies
|
||||
("4", "4,8", approx(0.95, rel=1e-3)), # extra solution -> extra solution penalty
|
||||
("-1,-2", "1,4", approx(0.06890, rel=1e-3)), # -1 matched w/ 1 and -2 matched w/ 4
|
||||
("4,-4.12", "4.0001,-4.120001", approx(0.9994, rel=1e-3)), # Very close match
|
||||
("4,-4.12", "4.1,-4.2", approx(0.4086, rel=1e-3)),
|
||||
("4,8", "4", approx(0.5, rel=1e-3)), # Missing an oracle solution -> missing solution penalty applies
|
||||
("4", "4,8", approx(0.5, rel=1e-3)), # extra solution -> extra solution penalty
|
||||
("-1,-2", "1,4", approx(1.0305e-9, rel=1e-3)), # -1 matched w/ 1 and -2 matched w/ 4
|
||||
("", "1", approx(0, rel=1e-4)), # oracle no solution, predicted extra solution
|
||||
("1", "", approx(0, rel=1e-4)), # oracle has a solution, predicted no solution
|
||||
],
|
||||
|
|
|
|||
|
|
@ -92,7 +92,6 @@ def test_polynomial_equations_dataset_items():
|
|||
|
||||
# Check metadata
|
||||
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||
assert isinstance(item["metadata"]["result"], str)
|
||||
assert isinstance(item["metadata"]["variables"], list)
|
||||
|
||||
# Check polynomial_expr existence
|
||||
|
|
@ -127,42 +126,6 @@ def test_cross_polynomial_equations_dataset_items():
|
|||
|
||||
# Check metadata
|
||||
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||
assert isinstance(item["metadata"]["result"], str)
|
||||
assert isinstance(item["metadata"]["variables"], list)
|
||||
|
||||
# Check polynomial_expr existence
|
||||
poly_str = item["metadata"]["polynomial_expr"]
|
||||
# Ensure it can parse with sympy
|
||||
sp.sympify(poly_str)
|
||||
|
||||
|
||||
def test_cross_polynomial_equations_dataset_items():
|
||||
"""Test that generated items have correct structure"""
|
||||
ds = create_dataset(
|
||||
"polynomial_multiplication",
|
||||
min_terms=2,
|
||||
max_terms=3,
|
||||
min_value=1,
|
||||
max_value=5,
|
||||
min_degree=1,
|
||||
max_degree=2,
|
||||
min_polynomials=2,
|
||||
max_polynomials=5,
|
||||
variables=tuple("xyz"),
|
||||
allow_cross_variable_product=True,
|
||||
allow_multivariate_polynomials=False,
|
||||
size=3,
|
||||
seed=100,
|
||||
)
|
||||
|
||||
for item in ds:
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||
assert isinstance(item["metadata"]["result"], str)
|
||||
assert isinstance(item["metadata"]["variables"], list)
|
||||
|
||||
# Check polynomial_expr existence
|
||||
|
|
@ -197,7 +160,6 @@ def test_multivariate_polynomial_equations_dataset_items():
|
|||
|
||||
# Check metadata
|
||||
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||
assert isinstance(item["metadata"]["result"], str)
|
||||
assert isinstance(item["metadata"]["variables"], list)
|
||||
|
||||
# Check polynomial_expr existence
|
||||
|
|
@ -242,7 +204,7 @@ def test_polynomial_solutions_evaluation():
|
|||
poly_expr = sp.expand(poly_str)
|
||||
|
||||
# Verify that each solution satisfies the polynomial
|
||||
assert poly_expr == item["answer"]
|
||||
assert str(poly_expr) == item["answer"]
|
||||
|
||||
|
||||
def test_score_function():
|
||||
|
|
@ -266,11 +228,11 @@ def test_score_function():
|
|||
|
||||
for item in ds:
|
||||
poly_str = item["metadata"]["polynomial_expr"]
|
||||
assert ds.score_answer(poly_str, item) == 0.05
|
||||
assert ds.score_answer(poly_str, item) == 0.0
|
||||
|
||||
poly_expr = str(sp.expand(poly_str))
|
||||
assert ds.score_answer(poly_expr, item) == 1.0
|
||||
|
||||
assert ds.score_answer(None, item) == 0.00
|
||||
assert ds.score_answer("Not a polynomial", item) == 0.01
|
||||
assert ds.score_answer("x**4", item) == 0.05
|
||||
assert ds.score_answer(None, item) == 0.0
|
||||
assert ds.score_answer("Not a polynomial", item) == 0.0
|
||||
assert ds.score_answer("x**4", item) == 0.0
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ def test_pool_matrix_score_answer():
|
|||
dataset = PoolMatrixDataset(config)
|
||||
for entry in dataset:
|
||||
assert dataset.score_answer(entry["answer"], entry=entry) == 1
|
||||
assert 0.0 < dataset.score_answer("1 2.0\n3.0 4", entry=entry) <= 0.1
|
||||
assert dataset.score_answer("1 2.0\n3.0 4", entry=entry) in [0.0, 0.1]
|
||||
assert dataset.score_answer("one two three", entry=entry) == 0.0
|
||||
assert dataset.score_answer("", entry=entry) == 0.0
|
||||
assert dataset.score_answer(None, entry=entry) == 0.0
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def test_power_function_score_function():
|
|||
|
||||
# Answer is far from solution
|
||||
answer = str(item["metadata"]["solution"] - 1)
|
||||
assert dataset.score_answer(answer, item) == 0.01
|
||||
assert dataset.score_answer(answer, item) == 0.0
|
||||
|
||||
# Answer is None
|
||||
answer = None
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ def test_products_scoring():
|
|||
assert dataset.score_answer(item["answer"], item) == 1.0, "Exact match should score 1.0"
|
||||
|
||||
# Test scoring with wrong answer
|
||||
assert dataset.score_answer("wrong", item) == 0.01, "Wrong answer should score 0.01"
|
||||
assert dataset.score_answer("wrong", item) == 0.0, "Wrong answer should score 0.0"
|
||||
|
||||
# Test scoring with partial match (answer contained in response)
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -100,4 +100,4 @@ def test_propositional_logic_dataset_score_answer_incorrect():
|
|||
dataset = PropositionalLogicDataset(PropositionalLogicConfig(size=100, seed=101))
|
||||
for i, item in enumerate(dataset):
|
||||
score = dataset.score_answer("Wrong", item)
|
||||
assert score == 0.01
|
||||
assert score == 0.0
|
||||
|
|
|
|||
|
|
@ -43,7 +43,8 @@ def test_quantumlock_items():
|
|||
assert "target_value" in item["metadata"]
|
||||
|
||||
# Verify solution works
|
||||
assert dataset.score_answer(answer=item["metadata"]["solution_path"], entry=item) == 1.0
|
||||
answer = "".join(item["metadata"]["solution_path"])
|
||||
assert dataset.score_answer(answer=answer, entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
||||
|
||||
|
|
@ -98,17 +99,17 @@ def test_quantumlock_scoring():
|
|||
dataset = QuantumLockDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
solution = item["metadata"]["solution_path"]
|
||||
solution = item["answer"]
|
||||
|
||||
# Test correct solution
|
||||
assert dataset.score_answer(solution, item) == 1.0
|
||||
|
||||
# Test empty/None answers
|
||||
assert dataset.score_answer(None, item) == 0.0
|
||||
assert dataset.score_answer("", item) == 0.1
|
||||
assert dataset.score_answer("", item) == 0.0
|
||||
|
||||
# Test invalid buttons
|
||||
assert dataset.score_answer("XYZ", item) == 0.1
|
||||
assert dataset.score_answer("XYZ", item) == 0.0
|
||||
|
||||
# Test case insensitivity
|
||||
if solution:
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ def test_group_anagrams_dataset_items():
|
|||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer="gibberish", entry=item) == 0.01
|
||||
assert dataset.score_answer(answer="gibberish", entry=item) == 0.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ def test_rearc_scoring_edge_cases():
|
|||
assert 0.0 < dataset.score_answer(partial, entry=item) < 1.0
|
||||
|
||||
# Malformed answer
|
||||
assert dataset.score_answer("[[invalid", entry=item) == 0.01
|
||||
assert dataset.score_answer("[[invalid", entry=item) == 0.0
|
||||
|
||||
# Case sensitivity
|
||||
answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower()
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ def test_self_reference():
|
|||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=99, entry=item) == 0.1
|
||||
assert dataset.score_answer(answer="99", entry=item) == 0.1
|
||||
assert dataset.score_answer(answer=99, entry=item) == 0.0
|
||||
assert dataset.score_answer(answer="99", entry=item) == 0.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
||||
# # Medium
|
||||
|
|
@ -34,8 +34,8 @@ def test_self_reference():
|
|||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=99, entry=item) == 0.1
|
||||
assert dataset.score_answer(answer="99", entry=item) == 0.1
|
||||
assert dataset.score_answer(answer=99, entry=item) == 0.0
|
||||
assert dataset.score_answer(answer="99", entry=item) == 0.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
||||
# # Hard
|
||||
|
|
@ -50,6 +50,6 @@ def test_self_reference():
|
|||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=99, entry=item) == 0.1
|
||||
assert dataset.score_answer(answer="99", entry=item) == 0.1
|
||||
assert dataset.score_answer(answer=99, entry=item) == 0.0
|
||||
assert dataset.score_answer(answer="99", entry=item) == 0.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ def test_shortest_path_answer():
|
|||
]
|
||||
},
|
||||
}
|
||||
assert dataset.score_answer("right down right down", entry) == 0.01
|
||||
assert dataset.score_answer("right down right down", entry) == 0.0
|
||||
|
||||
# Answer is None
|
||||
entry = {
|
||||
|
|
|
|||
|
|
@ -94,16 +94,16 @@ def test_score_answer_cases():
|
|||
("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
|
||||
("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
|
||||
# Incorrect but properly formatted
|
||||
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
|
||||
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
|
||||
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
|
||||
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.0),
|
||||
# Malformed expressions
|
||||
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
|
||||
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
|
||||
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.0),
|
||||
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.0),
|
||||
# Empty answer
|
||||
("", {"variable": "x", "integrand": "2*x"}, 0.01),
|
||||
("", {"variable": "x", "integrand": "2*x"}, 0.0),
|
||||
# Case sensitivity
|
||||
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
|
||||
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
|
||||
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.0),
|
||||
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
|
||||
# Alternative constant notation
|
||||
("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
|
||||
("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ from reasoning_gym.games.sokoban import SokobanConfig, SokobanDataset
|
|||
def test_sokoban():
|
||||
"""Test basic properties and solution of generated items"""
|
||||
|
||||
dataset = SokobanDataset(SokobanConfig(size=10, seed=1234))
|
||||
for i, item in enumerate(dataset):
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
|
||||
# Easy
|
||||
config = SokobanConfig(seed=42, size=20)
|
||||
dataset = SokobanDataset(config)
|
||||
|
|
@ -18,25 +22,29 @@ def test_sokoban():
|
|||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer="RU", entry=item) == 0.1
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
||||
# Medium
|
||||
config = SokobanConfig(seed=42, min_h=40, max_h=50, min_w=40, max_w=50, min_boxes=20, max_boxes=30, size=3)
|
||||
dataset = SokobanDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer="RU", entry=item) == 0.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
||||
# Hard
|
||||
config = SokobanConfig(seed=42, min_h=400, max_h=500, min_w=400, max_w=500, min_boxes=50, max_boxes=50, size=1)
|
||||
config = SokobanConfig(
|
||||
seed=42, min_h=15, max_h=20, min_w=15, max_w=20, min_boxes=10, max_boxes=15, size=3, max_depth=90
|
||||
)
|
||||
dataset = SokobanDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
|
||||
# min == max ranges
|
||||
config = SokobanConfig(
|
||||
seed=42, min_h=11, max_h=11, min_w=11, max_w=11, min_boxes=11, max_boxes=11, size=3, max_depth=60
|
||||
)
|
||||
dataset = SokobanDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue