mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
[feat] add test case
This commit is contained in:
parent
f27746be17
commit
c612e2abc1
2 changed files with 71 additions and 1 deletions
|
|
@ -103,3 +103,73 @@ def test_max_letters_constraint():
|
|||
|
||||
# Check total unique letters doesn't exceed 10 (digits 0-9)
|
||||
assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10"
|
||||
|
||||
def test_cryptarithm_score_answer():
|
||||
"""Test the CryptarithmDataset.score_answer method for various correctness levels."""
|
||||
dataset = create_dataset("cryptarithm", seed=42, size=1)
|
||||
puzzle = dataset[0]
|
||||
correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..."
|
||||
|
||||
# 1) Missing 'ANSWER:' => score should be 0.0
|
||||
score = dataset.score_answer(answer=None, answer_str=correct_answer_str)
|
||||
assert score == 0.0, f"Expected 0.0 when missing 'ANSWER:' prefix, got {score}"
|
||||
|
||||
# 2) Correct mapping => expecting 1.0
|
||||
user_answer = f"ANSWER:\n{correct_answer_str}"
|
||||
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
|
||||
assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"
|
||||
|
||||
# 3) Mismatch number of pairs => score should be 0.1
|
||||
# For instance, drop the last pair
|
||||
splitted = correct_answer_str.split(',')
|
||||
mismatch_str = ','.join(splitted[:-1])
|
||||
user_answer = f"ANSWER:\n{mismatch_str}"
|
||||
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
|
||||
assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"
|
||||
|
||||
# 4) Parse error => 0.15 (e.g. remove '=' from the first pair)
|
||||
splitted = correct_answer_str.split(',')
|
||||
splitted[0] = splitted[0].replace('=', '') # remove '=' in the first pair
|
||||
parse_error_str = ','.join(splitted)
|
||||
user_answer = f"ANSWER:\n{parse_error_str}"
|
||||
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
|
||||
assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"
|
||||
|
||||
# 5) Correct number of pairs, but duplicate alphabets => 0.3
|
||||
# This makes the dictionary have fewer unique keys than expected
|
||||
splitted = correct_answer_str.split(',')
|
||||
if len(splitted) > 1:
|
||||
splitted[0] = splitted[1] # Duplicate the second pair in the first position
|
||||
duplicates_str = ','.join(splitted)
|
||||
user_answer = f"ANSWER:\n{duplicates_str}"
|
||||
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
|
||||
assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"
|
||||
|
||||
# 6) Partial correctness => some correct, some incorrect
|
||||
splitted = correct_answer_str.split(',')
|
||||
correct_mapping = {}
|
||||
for pair in splitted:
|
||||
alpha, num_str = pair.split('=')
|
||||
correct_mapping[alpha] = int(num_str)
|
||||
|
||||
# Make exactly half of them correct, half incorrect
|
||||
total = len(correct_mapping)
|
||||
half = total // 2
|
||||
new_pairs = []
|
||||
i = 0
|
||||
for alpha, num in correct_mapping.items():
|
||||
if i < half:
|
||||
new_pairs.append(f"{alpha}={num}") # keep correct
|
||||
else:
|
||||
new_pairs.append(f"{alpha}={(num+1) % 10}") # make incorrect
|
||||
i += 1
|
||||
|
||||
partial_answer_str = ','.join(new_pairs)
|
||||
user_answer = f"ANSWER:\n{partial_answer_str}"
|
||||
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
|
||||
|
||||
# The formula is (num_correct / total) * 0.7 + 0.3
|
||||
expected_score = (half / total) * 0.7 + 0.3
|
||||
assert abs(score - expected_score) < 1e-9, (
|
||||
f"Partial correctness: expected {expected_score}, got {score}"
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue