Mirror of https://github.com/open-thought/reasoning-gym.git (synced 2026-05-01 17:45:24 +00:00)

Commit 4e49806d22 (parent 2ad0965fdc): add ArcAgiDataset class, fix score_entry() metadata params

20 changed files with 194 additions and 93 deletions
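Every hunk below makes the same mechanical change, so it is worth stating once: score_answer no longer receives a bare metadata dict; it receives the full dataset entry and reads entry["metadata"] itself. A minimal sketch of the two calling conventions follows, assuming ComplexArithmeticDataset and its config are importable from reasoning_gym.arithmetic and that the config accepts size and seed (the import path and config values are assumptions for illustration, not taken from this diff):

# Sketch of the scoring-API change in this commit (illustrative, not verbatim repo code).
from reasoning_gym.arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset

# size/seed values assumed for illustration; the tests build their own configs.
dataset = ComplexArithmeticDataset(ComplexArithmeticConfig(size=10, seed=42))
item = dataset[0]  # a dict carrying "question", "answer", and "metadata" keys

# Old convention (removed in this commit): pass the metadata dict directly.
#     score = dataset.score_answer(item["answer"], item["metadata"])

# New convention: pass the whole entry; score_answer looks up entry["metadata"].
score = dataset.score_answer(answer=item["answer"], entry=item)
assert score == 1.0

Where a test only has a metadata dict on hand, the updated code wraps it on the fly as {"metadata": metadata} before calling score_answer (see the dummy_entry hunks below).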
@@ -52,30 +52,30 @@ def test_complex_arithmetic_scoring():
     dataset = ComplexArithmeticDataset(config)
 
     # Test case with answer 3 + 2i
-    metadata = {"result": (3.0, 2.0)}
+    entry = {"metadata": {"result": (3.0, 2.0)}}
 
     # Test exact matches (should get score of 1.0)
-    assert dataset.score_answer("3 + 2i", metadata) == 1.0
-    assert dataset.score_answer("3+2i", metadata) == 1.0
-    assert dataset.score_answer("3.0 + 2.0i", metadata) == 1.0
+    assert dataset.score_answer("3 + 2i", entry) == 1.0
+    assert dataset.score_answer("3+2i", entry) == 1.0
+    assert dataset.score_answer("3.0 + 2.0i", entry) == 1.0
 
     # Test answers with small errors (should get high but < 1.0 scores)
-    print(dataset.score_answer("3.1 + 2i", metadata))
-    assert 0.9 < dataset.score_answer("3.1 + 2i", metadata) < 1.0
-    assert 0.9 < dataset.score_answer("3 + 2.1i", metadata) < 1.0
-    assert 0.7 < dataset.score_answer("3.1 + 2.1i", metadata) < 0.95
+    print(dataset.score_answer("3.1 + 2i", entry))
+    assert 0.9 < dataset.score_answer("3.1 + 2i", entry) < 1.0
+    assert 0.9 < dataset.score_answer("3 + 2.1i", entry) < 1.0
+    assert 0.7 < dataset.score_answer("3.1 + 2.1i", entry) < 0.95
 
     # Test answers with moderate errors (should get medium scores)
-    assert 0.3 < dataset.score_answer("4 + 2i", metadata) < 0.4
-    assert 0.3 < dataset.score_answer("3 + 3i", metadata) < 0.4
+    assert 0.3 < dataset.score_answer("4 + 2i", entry) < 0.4
+    assert 0.3 < dataset.score_answer("3 + 3i", entry) < 0.4
 
     # Test answers with large errors (should get very low scores)
-    assert dataset.score_answer("10 + 10i", metadata) < 0.01
+    assert dataset.score_answer("10 + 10i", entry) < 0.01
 
     # Test invalid answers (should get 0.0)
-    assert dataset.score_answer("invalid", metadata) == 0.0
-    assert dataset.score_answer(None, metadata) == 0.0
-    assert dataset.score_answer("inf + 2i", metadata) == 0.0
+    assert dataset.score_answer("invalid", entry) == 0.0
+    assert dataset.score_answer(None, entry) == 0.0
+    assert dataset.score_answer("inf + 2i", entry) == 0.0
 
 
 def test_complex_arithmetic_division_by_zero():

@@ -66,13 +66,13 @@ def test_countdown_game_items():
         expr = item["metadata"]["expression"]
 
         # check score
-        assert dataset.score_answer(answer=expr, metadata=item["metadata"]) == 1.0 # correct answer
-        assert dataset.score_answer(answer="45+2", metadata=item["metadata"]) == 0.05 # wrong answer but an attempt
+        assert dataset.score_answer(answer=expr, entry=item) == 1.0 # correct answer
+        assert dataset.score_answer(answer="45+2", entry=item) == 0.05 # wrong answer but an attempt
         assert (
-            dataset.score_answer(answer="a wrong solution", metadata=item["metadata"]) == 0.01
+            dataset.score_answer(answer="a wrong solution", entry=item) == 0.01
         ) # wrong answer but incorrectly formatted
-        assert dataset.score_answer(answer="", metadata=item["metadata"]) == 0.01 # wrong answer but empty string
-        assert dataset.score_answer(answer=None, metadata=item["metadata"]) == 0.0 # no answer
+        assert dataset.score_answer(answer="", entry=item) == 0.01 # wrong answer but empty string
+        assert dataset.score_answer(answer=None, entry=item) == 0.0 # no answer
 
         try:
             result = eval(expr) # Safe here since we control expression generation

@@ -100,7 +100,7 @@ def test_verify_answer():
     dataset = IntermediateIntegrationDataset(config)
     for i in range(len(dataset)):
         item = dataset[i]
-        score = dataset.score_answer(item["answer"], item["metadata"])
+        score = dataset.score_answer(answer=item["answer"], entry=item)
         assert score == 1.0

@@ -140,5 +140,6 @@ def test_score_answer_cases():
     ]
 
     for answer, metadata, expected in test_cases:
-        score = dataset.score_answer(answer, metadata)
+        dummy_entry = {"metadata": metadata}
+        score = dataset.score_answer(answer, entry=dummy_entry)
         assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"

@@ -72,21 +72,20 @@ def test_score_answer():
 
     for item in dataset:
         correct_answer = item["answer"]
-        metadata = item["metadata"]
 
         # Correct answer should score 1.0
-        assert dataset.score_answer(correct_answer, metadata) == 1.0
+        assert dataset.score_answer(correct_answer, entry=item) == 1.0
 
         # Incorrect answer (palindrome, but not correct one) should score 0.05
         pal_letters = "racecar" if "racecar" != correct_answer else "aba"
-        assert dataset.score_answer(pal_letters, metadata) == 0.05
+        assert dataset.score_answer(pal_letters, entry=item) == 0.05
 
         # Incorrect answer (not palindrome) should score 0.02
         wrong_letters = "abcd" if "abcd" != correct_answer else "efgh"
-        assert dataset.score_answer(wrong_letters, metadata) == 0.02
+        assert dataset.score_answer(wrong_letters, entry=item) == 0.02
 
         # Empty String input should score 0.01
-        assert dataset.score_answer("", metadata) == 0.01
+        assert dataset.score_answer("", entry=item) == 0.01
 
         # Empty input should score 0.0
-        assert dataset.score_answer(None, metadata) == 0.0
+        assert dataset.score_answer(None, entry=item) == 0.0

@@ -137,10 +137,10 @@ def test_score_function():
         seed=42,
     )
 
-    assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
-    assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]["metadata"]) == 1
-    assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
-    assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
+    assert ds.score_answer(None, ds[0]) == 0.00
+    assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]) == 1
+    assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
+    assert ds.score_answer("x**4", ds[0]) == 0.05
 
 
 def test_multivariate_score_function():

@@ -160,7 +160,7 @@ def test_multivariate_score_function():
         seed=42,
    )
 
-    assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
-    assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]["metadata"]) == 1
-    assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
-    assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
+    assert ds.score_answer(None, ds[0]) == 0.00
+    assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]) == 1
+    assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
+    assert ds.score_answer("x**4", ds[0]) == 0.05

@@ -54,7 +54,7 @@ def test_rearc_solution_validation():
     for item in dataset:
         # Test correct solution
         correct = format_board(item["metadata"]["output"], dataset.board_format_opts)
-        assert dataset.score_answer(correct, item["metadata"]) == 1.0
+        assert dataset.score_answer(correct, entry=item) == 1.0
 
         # Test invalid format
         invalid_grid = """

@@ -63,10 +63,10 @@ def test_rearc_solution_validation():
         7 8 7
         0 0 0
         """
-        assert dataset.score_answer(invalid_grid, item["metadata"]) == 0.05
+        assert dataset.score_answer(invalid_grid, entry=item) == 0.05
 
         # Test empty answer
-        assert dataset.score_answer(None, item["metadata"]) == 0.0
+        assert dataset.score_answer(None, entry=item) == 0.0
 
 
 def test_rearc_scoring_edge_cases():

@@ -77,11 +77,11 @@ def test_rearc_scoring_edge_cases():
     for item in dataset:
         # Partial match
         partial = format_board([[0, 0], [0, 0]], dataset.board_format_opts)
-        assert 0.0 < dataset.score_answer(partial, item["metadata"]) < 1.0
+        assert 0.0 < dataset.score_answer(partial, entry=item) < 1.0
 
         # Malformed answer
-        assert dataset.score_answer("[[invalid", item["metadata"]) == 0.01
+        assert dataset.score_answer("[[invalid", entry=item) == 0.01
 
         # Case sensitivity
         answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower()
-        assert dataset.score_answer(answer, item["metadata"]) == 1.0
+        assert dataset.score_answer(answer, entry=item) == 1.0

@@ -73,7 +73,7 @@ def test_verify_answer():
     dataset = SimpleIntegrationDataset(config)
     for i in range(len(dataset)):
         item = dataset[i]
-        score = dataset.score_answer(item["answer"], item["metadata"])
+        score = dataset.score_answer(item["answer"], item)
         assert score == 1.0

@@ -113,5 +113,6 @@ def test_score_answer_cases():
     ]
 
     for answer, metadata, expected in test_cases:
-        score = dataset.score_answer(answer, metadata)
+        dummy_entry = {"metadata": metadata}
+        score = dataset.score_answer(answer=answer, entry=dummy_entry)
         assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"

@@ -245,27 +245,26 @@ def test_score_answer():
     dataset = HanoiDataset(config)
     # Pick one instance from the dataset for testing.
     item = dataset[0]
-    metadata = item["metadata"]
     correct_answer = item["answer"]
 
     # 1. Correct answer should yield full reward.
-    score_correct = dataset.score_answer(answer=correct_answer, metadata=metadata)
+    score_correct = dataset.score_answer(answer=correct_answer, entry=item)
     assert score_correct == 1.0, f"Correct answer score {score_correct} is not 1.0."
 
     # 2. A badly formatted answer should yield minimal reward (0.01).
-    score_bad_format = dataset.score_answer(answer="a wrong solution", metadata=metadata)
+    score_bad_format = dataset.score_answer(answer="a wrong solution", entry=item)
     assert score_bad_format == 0.01, f"Badly formatted answer score {score_bad_format} is not 0.01."
 
     # 3. An answer that is validly formatted but unsolved.
     # For example, remove the last move from the correct answer.
     unfinished_answer = correct_answer[:-1]
-    score_unsolved = dataset.score_answer(answer=unfinished_answer, metadata=metadata)
+    score_unsolved = dataset.score_answer(answer=unfinished_answer, entry=item)
     assert score_unsolved == 0.05, f"Unsolved answer score {score_unsolved} is not 0.05."
 
     # 4. An empty answer should yield 0.01.
-    score_empty = dataset.score_answer(answer="", metadata=metadata)
+    score_empty = dataset.score_answer(answer="", entry=item)
     assert score_empty == 0.01, f"Empty answer score {score_empty} is not 0.01."
 
     # 5. A None answer should yield 0.0.
-    score_none = dataset.score_answer(answer=None, metadata=metadata)
+    score_none = dataset.score_answer(answer=None, entry=item)
    assert score_none == 0.0, f"None answer score {score_none} is not 0.0."