From 288f632c5c94903d526d45e49052053ace85d1ca Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Tue, 18 Feb 2025 21:33:14 +0800 Subject: [PATCH 1/6] [feat] added score_answer function --- reasoning_gym/algorithmic/cryptarithm.py | 79 ++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 5 deletions(-) diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index 7075e9ec..f3aa4c09 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -13,7 +13,7 @@ No leading letter can be zero (unless allow_leading_zero=True). from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Optional, Dict from ..factory import ProceduralDataset, register_dataset @@ -23,10 +23,28 @@ EXAMPLE_CASE = """ ------ GAMES -Answer (one possible solution): +* BASE + BALL = GAMES, two 4-digit numbers sum to 5 digits, so G = 1. -B=7, A=8, S=2, E=9, L=1, G=1, M=0 -Summation: 7829 + 7811 = 15640 (the puzzle might produce a different arrangement, but the principle is the same).""" +* Units: E + L = S (no carry). + +* Tens: S + L = E + 10 (carry 1). Substitute S = E + L to get E + 2L = E + 10, so L = 5. + +* Since S = E + 5 and S is one digit, E < 5. + +* Hundreds: 2A + 1 = M (with carry). + +* Thousands: 2B = A + 10 (carry makes G = 1). So A = 2B - 10. + +* Try B = 7: Then A = 4 and M = 2(4) + 1 = 9. + +* With E < 5, try E = 3: Then S = 8. + +* Solution: B = 7, A = 4, S = 8, E = 3, L = 5, M = 9, G = 1 + +* Verify: BASE (7483) + BALL (7455) = GAMES (14938). + +ANSWER: +B=7, A=4, S=8, E=3, L=5, M=9, G=1""" @dataclass @@ -178,7 +196,7 @@ class CryptarithmDataset(ProceduralDataset): if self.config.allow_leading_zero else "No leading letter can be zero.\n" ) - + "Provide a mapping from letters to digits that satisfies the equation.\n" + + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\nANSWER:\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n" ) if self.config.include_example: question_str += "Here's an example:\n" + EXAMPLE_CASE @@ -202,5 +220,56 @@ class CryptarithmDataset(ProceduralDataset): }, } + def score_answer(self, answer: Optional[str], answer_str: Dict[str, any]) -> float: + """Determine if the solution provided solves the Cryptarithm task. + + The function awards 1.0 for a correct format and answers for all alphabet pairs. + + Args: + answer (Optional[str]): The user's answer. + answer_str (Dict[str, any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..." + + Returns: + float: The computed score between 0.0 and 1.0. + """ + correct_mapping = {} + for pair in answer_str.split(','): + alphabet, number = pair.split('=') + correct_mapping[alphabet] = int(number) + + if answer == None or 'ANSWER:' not in answer: + return 0.0 + + number_mapping_line = '' + if 'ANSWER:' in answer: + number_mapping_line = answer.split('ANSWER:\n')[-1] + + # case 1 : pairs are in a list format and the number of pairs matched up + if len(number_mapping_line.split(',')) != len(correct_mapping): + return 0.1 + + predict_mapping = {} + for pair in number_mapping_line.split(','): + try: + alphabet, number = pair.strip().split('=') + # as the unique alphabet grows we may want this to scale linearly with the number alphabet + predict_mapping[alphabet] = int(number) + except ValueError: + return 0.15 + # case 2 : all alphabet has correct format ALPHABET=NUMBER format + if len(predict_mapping) != len(correct_mapping): + return 0.3 + + # case 3 : partial score for the number of correct mapping answer + total_correct, total = 0, 0 + for alphabet, number in correct_mapping: + total += 1 + if alphabet in predict_mapping: + if predict_mapping[alphabet] == number: + total_correct += 1 + + # note: linear relationship is probably not good? + return (total_correct/total)*0.7 + 0.3 + register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig) From c612e2abc1e5aceb06d840ed92c75252aaddc952 Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Tue, 18 Feb 2025 21:45:51 +0800 Subject: [PATCH 2/6] [feat] add test case --- reasoning_gym/algorithmic/cryptarithm.py | 2 +- tests/test_cryptarithm.py | 70 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index f3aa4c09..bc930848 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -262,7 +262,7 @@ class CryptarithmDataset(ProceduralDataset): # case 3 : partial score for the number of correct mapping answer total_correct, total = 0, 0 - for alphabet, number in correct_mapping: + for alphabet, number in correct_mapping.items(): total += 1 if alphabet in predict_mapping: if predict_mapping[alphabet] == number: diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index 0ae3ea7f..e0ae9c85 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -103,3 +103,73 @@ def test_max_letters_constraint(): # Check total unique letters doesn't exceed 10 (digits 0-9) assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10" + +def test_cryptarithm_score_answer(): + """Test the CryptarithmDataset.score_answer method for various correctness levels.""" + dataset = create_dataset("cryptarithm", seed=42, size=1) + puzzle = dataset[0] + correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..." + + # 1) Missing 'ANSWER:' => score should be 0.0 + score = dataset.score_answer(answer=None, answer_str=correct_answer_str) + assert score == 0.0, f"Expected 0.0 when missing 'ANSWER:' prefix, got {score}" + + # 2) Correct mapping => expecting 1.0 + user_answer = f"ANSWER:\n{correct_answer_str}" + score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}" + + # 3) Mismatch number of pairs => score should be 0.1 + # For instance, drop the last pair + splitted = correct_answer_str.split(',') + mismatch_str = ','.join(splitted[:-1]) + user_answer = f"ANSWER:\n{mismatch_str}" + score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}" + + # 4) Parse error => 0.15 (e.g. remove '=' from the first pair) + splitted = correct_answer_str.split(',') + splitted[0] = splitted[0].replace('=', '') # remove '=' in the first pair + parse_error_str = ','.join(splitted) + user_answer = f"ANSWER:\n{parse_error_str}" + score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}" + + # 5) Correct number of pairs, but duplicate alphabets => 0.3 + # This makes the dictionary have fewer unique keys than expected + splitted = correct_answer_str.split(',') + if len(splitted) > 1: + splitted[0] = splitted[1] # Duplicate the second pair in the first position + duplicates_str = ','.join(splitted) + user_answer = f"ANSWER:\n{duplicates_str}" + score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}" + + # 6) Partial correctness => some correct, some incorrect + splitted = correct_answer_str.split(',') + correct_mapping = {} + for pair in splitted: + alpha, num_str = pair.split('=') + correct_mapping[alpha] = int(num_str) + + # Make exactly half of them correct, half incorrect + total = len(correct_mapping) + half = total // 2 + new_pairs = [] + i = 0 + for alpha, num in correct_mapping.items(): + if i < half: + new_pairs.append(f"{alpha}={num}") # keep correct + else: + new_pairs.append(f"{alpha}={(num+1) % 10}") # make incorrect + i += 1 + + partial_answer_str = ','.join(new_pairs) + user_answer = f"ANSWER:\n{partial_answer_str}" + score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + + # The formula is (num_correct / total) * 0.7 + 0.3 + expected_score = (half / total) * 0.7 + 0.3 + assert abs(score - expected_score) < 1e-9, ( + f"Partial correctness: expected {expected_score}, got {score}" + ) \ No newline at end of file From bbc31e42910a28c54dcb50b8486efcb458bb282f Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Tue, 18 Feb 2025 21:48:54 +0800 Subject: [PATCH 3/6] [fix] pre-commit fix --- reasoning_gym/algorithmic/cryptarithm.py | 26 +++++++++++------------ tests/test_cryptarithm.py | 27 ++++++++++++------------ 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index bc930848..4eccc9e4 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -13,7 +13,7 @@ No leading letter can be zero (unless allow_leading_zero=True). from dataclasses import dataclass from random import Random -from typing import Optional, Dict +from typing import Dict, Optional from ..factory import ProceduralDataset, register_dataset @@ -42,7 +42,7 @@ EXAMPLE_CASE = """ * Solution: B = 7, A = 4, S = 8, E = 3, L = 5, M = 9, G = 1 * Verify: BASE (7483) + BALL (7455) = GAMES (14938). - + ANSWER: B=7, A=4, S=8, E=3, L=5, M=9, G=1""" @@ -233,25 +233,25 @@ class CryptarithmDataset(ProceduralDataset): float: The computed score between 0.0 and 1.0. """ correct_mapping = {} - for pair in answer_str.split(','): - alphabet, number = pair.split('=') + for pair in answer_str.split(","): + alphabet, number = pair.split("=") correct_mapping[alphabet] = int(number) - if answer == None or 'ANSWER:' not in answer: + if answer == None or "ANSWER:" not in answer: return 0.0 - number_mapping_line = '' - if 'ANSWER:' in answer: - number_mapping_line = answer.split('ANSWER:\n')[-1] - + number_mapping_line = "" + if "ANSWER:" in answer: + number_mapping_line = answer.split("ANSWER:\n")[-1] + # case 1 : pairs are in a list format and the number of pairs matched up - if len(number_mapping_line.split(',')) != len(correct_mapping): + if len(number_mapping_line.split(",")) != len(correct_mapping): return 0.1 predict_mapping = {} - for pair in number_mapping_line.split(','): + for pair in number_mapping_line.split(","): try: - alphabet, number = pair.strip().split('=') + alphabet, number = pair.strip().split("=") # as the unique alphabet grows we may want this to scale linearly with the number alphabet predict_mapping[alphabet] = int(number) except ValueError: @@ -269,7 +269,7 @@ class CryptarithmDataset(ProceduralDataset): total_correct += 1 # note: linear relationship is probably not good? - return (total_correct/total)*0.7 + 0.3 + return (total_correct / total) * 0.7 + 0.3 register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig) diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index e0ae9c85..942f180e 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -104,6 +104,7 @@ def test_max_letters_constraint(): # Check total unique letters doesn't exceed 10 (digits 0-9) assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10" + def test_cryptarithm_score_answer(): """Test the CryptarithmDataset.score_answer method for various correctness levels.""" dataset = create_dataset("cryptarithm", seed=42, size=1) @@ -121,35 +122,35 @@ def test_cryptarithm_score_answer(): # 3) Mismatch number of pairs => score should be 0.1 # For instance, drop the last pair - splitted = correct_answer_str.split(',') - mismatch_str = ','.join(splitted[:-1]) + splitted = correct_answer_str.split(",") + mismatch_str = ",".join(splitted[:-1]) user_answer = f"ANSWER:\n{mismatch_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}" # 4) Parse error => 0.15 (e.g. remove '=' from the first pair) - splitted = correct_answer_str.split(',') - splitted[0] = splitted[0].replace('=', '') # remove '=' in the first pair - parse_error_str = ','.join(splitted) + splitted = correct_answer_str.split(",") + splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair + parse_error_str = ",".join(splitted) user_answer = f"ANSWER:\n{parse_error_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}" # 5) Correct number of pairs, but duplicate alphabets => 0.3 # This makes the dictionary have fewer unique keys than expected - splitted = correct_answer_str.split(',') + splitted = correct_answer_str.split(",") if len(splitted) > 1: splitted[0] = splitted[1] # Duplicate the second pair in the first position - duplicates_str = ','.join(splitted) + duplicates_str = ",".join(splitted) user_answer = f"ANSWER:\n{duplicates_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}" # 6) Partial correctness => some correct, some incorrect - splitted = correct_answer_str.split(',') + splitted = correct_answer_str.split(",") correct_mapping = {} for pair in splitted: - alpha, num_str = pair.split('=') + alpha, num_str = pair.split("=") correct_mapping[alpha] = int(num_str) # Make exactly half of them correct, half incorrect @@ -159,17 +160,15 @@ def test_cryptarithm_score_answer(): i = 0 for alpha, num in correct_mapping.items(): if i < half: - new_pairs.append(f"{alpha}={num}") # keep correct + new_pairs.append(f"{alpha}={num}") # keep correct else: new_pairs.append(f"{alpha}={(num+1) % 10}") # make incorrect i += 1 - partial_answer_str = ','.join(new_pairs) + partial_answer_str = ",".join(new_pairs) user_answer = f"ANSWER:\n{partial_answer_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) # The formula is (num_correct / total) * 0.7 + 0.3 expected_score = (half / total) * 0.7 + 0.3 - assert abs(score - expected_score) < 1e-9, ( - f"Partial correctness: expected {expected_score}, got {score}" - ) \ No newline at end of file + assert abs(score - expected_score) < 1e-9, f"Partial correctness: expected {expected_score}, got {score}" From 9a2e9e949edee6bdecab9b1a3de9fb1b2552e7bf Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Wed, 19 Feb 2025 08:40:31 +0800 Subject: [PATCH 4/6] [fix] normalize to --- reasoning_gym/algorithmic/cryptarithm.py | 14 ++++++++------ tests/test_cryptarithm.py | 19 ++++++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index 4eccc9e4..81ea5bd3 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -43,8 +43,7 @@ EXAMPLE_CASE = """ * Verify: BASE (7483) + BALL (7455) = GAMES (14938). -ANSWER: -B=7, A=4, S=8, E=3, L=5, M=9, G=1""" +B=7, A=4, S=8, E=3, L=5, M=9, G=1""" @dataclass @@ -196,7 +195,7 @@ class CryptarithmDataset(ProceduralDataset): if self.config.allow_leading_zero else "No leading letter can be zero.\n" ) - + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\nANSWER:\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n" + + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\n\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n" ) if self.config.include_example: question_str += "Here's an example:\n" + EXAMPLE_CASE @@ -237,12 +236,15 @@ class CryptarithmDataset(ProceduralDataset): alphabet, number = pair.split("=") correct_mapping[alphabet] = int(number) - if answer == None or "ANSWER:" not in answer: + if answer == None or "" not in answer: return 0.0 number_mapping_line = "" - if "ANSWER:" in answer: - number_mapping_line = answer.split("ANSWER:\n")[-1] + if "" in answer: + number_mapping_line = answer.split("")[-1] + if "" not in number_mapping_line: + return 0.0 + number_mapping_line = number_mapping_line.split("")[0].strip() # case 1 : pairs are in a list format and the number of pairs matched up if len(number_mapping_line.split(",")) != len(correct_mapping): diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index 942f180e..74563ef0 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -111,20 +111,25 @@ def test_cryptarithm_score_answer(): puzzle = dataset[0] correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..." - # 1) Missing 'ANSWER:' => score should be 0.0 + # 1) Missing '' => score should be 0.0 score = dataset.score_answer(answer=None, answer_str=correct_answer_str) - assert score == 0.0, f"Expected 0.0 when missing 'ANSWER:' prefix, got {score}" + assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}" # 2) Correct mapping => expecting 1.0 - user_answer = f"ANSWER:\n{correct_answer_str}" + user_answer = f"{correct_answer_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}" + # 2.1) Missing end tag => expecting 1.0 + user_answer = f"{correct_answer_str}" + score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + assert score == 0.0, f"Expected 0.0 for missing end answer tag, got {score}" + # 3) Mismatch number of pairs => score should be 0.1 # For instance, drop the last pair splitted = correct_answer_str.split(",") mismatch_str = ",".join(splitted[:-1]) - user_answer = f"ANSWER:\n{mismatch_str}" + user_answer = f"{mismatch_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}" @@ -132,7 +137,7 @@ def test_cryptarithm_score_answer(): splitted = correct_answer_str.split(",") splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair parse_error_str = ",".join(splitted) - user_answer = f"ANSWER:\n{parse_error_str}" + user_answer = f"{parse_error_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}" @@ -142,7 +147,7 @@ def test_cryptarithm_score_answer(): if len(splitted) > 1: splitted[0] = splitted[1] # Duplicate the second pair in the first position duplicates_str = ",".join(splitted) - user_answer = f"ANSWER:\n{duplicates_str}" + user_answer = f"{duplicates_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}" @@ -166,7 +171,7 @@ def test_cryptarithm_score_answer(): i += 1 partial_answer_str = ",".join(new_pairs) - user_answer = f"ANSWER:\n{partial_answer_str}" + user_answer = f"{partial_answer_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) # The formula is (num_correct / total) * 0.7 + 0.3 From 407b21232696fe2c8440ab3dcc81668aa03fcbef Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Thu, 20 Feb 2025 16:57:51 +0800 Subject: [PATCH 5/6] [feat] remove answer parsing since its already handled --- reasoning_gym/algorithmic/cryptarithm.py | 15 +++------------ tests/test_cryptarithm.py | 23 +++++++---------------- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index 81ea5bd3..5ec1603a 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -225,7 +225,7 @@ class CryptarithmDataset(ProceduralDataset): The function awards 1.0 for a correct format and answers for all alphabet pairs. Args: - answer (Optional[str]): The user's answer. + answer (Optional[str]): The user's answer already parsed by `extract_answer` answer_str (Dict[str, any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..." Returns: @@ -236,22 +236,13 @@ class CryptarithmDataset(ProceduralDataset): alphabet, number = pair.split("=") correct_mapping[alphabet] = int(number) - if answer == None or "" not in answer: - return 0.0 - - number_mapping_line = "" - if "" in answer: - number_mapping_line = answer.split("")[-1] - if "" not in number_mapping_line: - return 0.0 - number_mapping_line = number_mapping_line.split("")[0].strip() # case 1 : pairs are in a list format and the number of pairs matched up - if len(number_mapping_line.split(",")) != len(correct_mapping): + if len(answer.split(",")) != len(correct_mapping): return 0.1 predict_mapping = {} - for pair in number_mapping_line.split(","): + for pair in answer.split(","): try: alphabet, number = pair.strip().split("=") # as the unique alphabet grows we may want this to scale linearly with the number alphabet diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index 74563ef0..686ba949 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -112,33 +112,26 @@ def test_cryptarithm_score_answer(): correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..." # 1) Missing '' => score should be 0.0 - score = dataset.score_answer(answer=None, answer_str=correct_answer_str) - assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}" + # score = dataset.score_answer(answer=None, answer_str=correct_answer_str) + # assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}" # 2) Correct mapping => expecting 1.0 - user_answer = f"{correct_answer_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=correct_answer_str, answer_str=correct_answer_str) assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}" - # 2.1) Missing end tag => expecting 1.0 - user_answer = f"{correct_answer_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) - assert score == 0.0, f"Expected 0.0 for missing end answer tag, got {score}" # 3) Mismatch number of pairs => score should be 0.1 # For instance, drop the last pair splitted = correct_answer_str.split(",") mismatch_str = ",".join(splitted[:-1]) - user_answer = f"{mismatch_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=mismatch_str, answer_str=correct_answer_str) assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}" # 4) Parse error => 0.15 (e.g. remove '=' from the first pair) splitted = correct_answer_str.split(",") splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair parse_error_str = ",".join(splitted) - user_answer = f"{parse_error_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=parse_error_str, answer_str=correct_answer_str) assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}" # 5) Correct number of pairs, but duplicate alphabets => 0.3 @@ -147,8 +140,7 @@ def test_cryptarithm_score_answer(): if len(splitted) > 1: splitted[0] = splitted[1] # Duplicate the second pair in the first position duplicates_str = ",".join(splitted) - user_answer = f"{duplicates_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=duplicates_str, answer_str=correct_answer_str) assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}" # 6) Partial correctness => some correct, some incorrect @@ -171,8 +163,7 @@ def test_cryptarithm_score_answer(): i += 1 partial_answer_str = ",".join(new_pairs) - user_answer = f"{partial_answer_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=partial_answer_str, answer_str=correct_answer_str) # The formula is (num_correct / total) * 0.7 + 0.3 expected_score = (half / total) * 0.7 + 0.3 From 44559aac952a0d1ed1499014794deb2d27a7cac5 Mon Sep 17 00:00:00 2001 From: theblackcat102 <13172147+theblackcat102@users.noreply.github.com> Date: Thu, 20 Feb 2025 17:00:18 +0800 Subject: [PATCH 6/6] [fix] precommit not happy --- reasoning_gym/algorithmic/cryptarithm.py | 1 - tests/test_cryptarithm.py | 1 - 2 files changed, 2 deletions(-) diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index 5ec1603a..a7b5236f 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -236,7 +236,6 @@ class CryptarithmDataset(ProceduralDataset): alphabet, number = pair.split("=") correct_mapping[alphabet] = int(number) - # case 1 : pairs are in a list format and the number of pairs matched up if len(answer.split(",")) != len(correct_mapping): return 0.1 diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index 686ba949..b704e3b4 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -119,7 +119,6 @@ def test_cryptarithm_score_answer(): score = dataset.score_answer(answer=correct_answer_str, answer_str=correct_answer_str) assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}" - # 3) Mismatch number of pairs => score should be 0.1 # For instance, drop the last pair splitted = correct_answer_str.split(",")