diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index 81ea5bd3..5ec1603a 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -225,7 +225,7 @@ class CryptarithmDataset(ProceduralDataset): The function awards 1.0 for a correct format and answers for all alphabet pairs. Args: - answer (Optional[str]): The user's answer. + answer (Optional[str]): The user's answer already parsed by `extract_answer` answer_str (Dict[str, any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..." Returns: @@ -236,22 +236,13 @@ class CryptarithmDataset(ProceduralDataset): alphabet, number = pair.split("=") correct_mapping[alphabet] = int(number) - if answer == None or "" not in answer: - return 0.0 - - number_mapping_line = "" - if "" in answer: - number_mapping_line = answer.split("")[-1] - if "" not in number_mapping_line: - return 0.0 - number_mapping_line = number_mapping_line.split("")[0].strip() # case 1 : pairs are in a list format and the number of pairs matched up - if len(number_mapping_line.split(",")) != len(correct_mapping): + if len(answer.split(",")) != len(correct_mapping): return 0.1 predict_mapping = {} - for pair in number_mapping_line.split(","): + for pair in answer.split(","): try: alphabet, number = pair.strip().split("=") # as the unique alphabet grows we may want this to scale linearly with the number alphabet diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index 74563ef0..686ba949 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -112,33 +112,26 @@ def test_cryptarithm_score_answer(): correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..." # 1) Missing '' => score should be 0.0 - score = dataset.score_answer(answer=None, answer_str=correct_answer_str) - assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}" + # score = dataset.score_answer(answer=None, answer_str=correct_answer_str) + # assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}" # 2) Correct mapping => expecting 1.0 - user_answer = f"{correct_answer_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=correct_answer_str, answer_str=correct_answer_str) assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}" - # 2.1) Missing end tag => expecting 1.0 - user_answer = f"{correct_answer_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) - assert score == 0.0, f"Expected 0.0 for missing end answer tag, got {score}" # 3) Mismatch number of pairs => score should be 0.1 # For instance, drop the last pair splitted = correct_answer_str.split(",") mismatch_str = ",".join(splitted[:-1]) - user_answer = f"{mismatch_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=mismatch_str, answer_str=correct_answer_str) assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}" # 4) Parse error => 0.15 (e.g. remove '=' from the first pair) splitted = correct_answer_str.split(",") splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair parse_error_str = ",".join(splitted) - user_answer = f"{parse_error_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=parse_error_str, answer_str=correct_answer_str) assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}" # 5) Correct number of pairs, but duplicate alphabets => 0.3 @@ -147,8 +140,7 @@ def test_cryptarithm_score_answer(): if len(splitted) > 1: splitted[0] = splitted[1] # Duplicate the second pair in the first position duplicates_str = ",".join(splitted) - user_answer = f"{duplicates_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=duplicates_str, answer_str=correct_answer_str) assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}" # 6) Partial correctness => some correct, some incorrect @@ -171,8 +163,7 @@ def test_cryptarithm_score_answer(): i += 1 partial_answer_str = ",".join(new_pairs) - user_answer = f"{partial_answer_str}" - score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + score = dataset.score_answer(answer=partial_answer_str, answer_str=correct_answer_str) # The formula is (num_correct / total) * 0.7 + 0.3 expected_score = (half / total) * 0.7 + 0.3