diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index 7075e9ec..f3aa4c09 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -13,7 +13,7 @@ No leading letter can be zero (unless allow_leading_zero=True). from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Optional, Dict from ..factory import ProceduralDataset, register_dataset @@ -23,10 +23,28 @@ EXAMPLE_CASE = """ ------ GAMES -Answer (one possible solution): +* BASE + BALL = GAMES, two 4-digit numbers sum to 5 digits, so G = 1. -B=7, A=8, S=2, E=9, L=1, G=1, M=0 -Summation: 7829 + 7811 = 15640 (the puzzle might produce a different arrangement, but the principle is the same).""" +* Units: E + L = S (no carry). + +* Tens: S + L = E + 10 (carry 1). Substitute S = E + L to get E + 2L = E + 10, so L = 5. + +* Since S = E + 5 and S is one digit, E < 5. + +* Hundreds: 2A + 1 = M (with carry). + +* Thousands: 2B = A + 10 (carry makes G = 1). So A = 2B - 10. + +* Try B = 7: Then A = 4 and M = 2(4) + 1 = 9. + +* With E < 5, try E = 3: Then S = 8. + +* Solution: B = 7, A = 4, S = 8, E = 3, L = 5, M = 9, G = 1 + +* Verify: BASE (7483) + BALL (7455) = GAMES (14938). 
def score_answer(self, answer: Optional[str], answer_str: str) -> float:
    """Score a response against the reference letter-to-digit mapping.

    The response must contain the marker ``ANSWER:`` followed by a
    comma-separated list of ``LETTER=DIGIT`` pairs. Whitespace around the
    marker, the letters and the digits is ignored.

    Scoring tiers:
        * 0.0  -- ``answer`` is None or contains no ``ANSWER:`` marker.
        * 0.1  -- the number of comma-separated pairs differs from the
          number of letters in the reference mapping.
        * 0.15 -- a pair is not of the form ``LETTER=INTEGER``.
        * 0.3  -- every pair parses but a letter is repeated, so the
          number of distinct letters is wrong.
        * 0.3 + 0.7 * (fraction of reference letters mapped to the
          correct digit) otherwise; 1.0 when every letter is correct.

    Args:
        answer: The full text of the response, or None.
        answer_str: The reference mapping as a string such as
            ``"A=1,B=3"``. It is parsed with ``str.split``, so it must be
            a string.

    Returns:
        The score, between 0.0 and 1.0.

    Raises:
        ValueError: If ``answer_str`` itself is malformed.
    """
    marker = "ANSWER:"

    # Parse the reference first so that a malformed reference is reported
    # regardless of what the response looks like.
    correct_mapping = {}
    for pair in answer_str.split(","):
        letter, digit = pair.split("=")
        correct_mapping[letter.strip()] = int(digit)

    if answer is None or marker not in answer:
        return 0.0

    # Keep only the text after the *last* marker; stripping makes both
    # "ANSWER:\nA=1" and "ANSWER: A=1" acceptable.
    pairs = answer.rpartition(marker)[2].strip().split(",")

    # Tier 1: the list must contain one pair per reference letter.
    if len(pairs) != len(correct_mapping):
        return 0.1

    # Tier 2: every pair must look like LETTER=INTEGER.
    predict_mapping = {}
    for pair in pairs:
        try:
            # Both a wrong number of "=" and a non-integer digit raise
            # ValueError here.
            letter, digit = pair.split("=")
            predict_mapping[letter.strip()] = int(digit)
        except ValueError:
            return 0.15

    # Tier 3: repeated letters collapse in the dict, leaving too few keys.
    if len(predict_mapping) != len(correct_mapping):
        return 0.3

    # Tier 4: linear partial credit for each correctly mapped letter.
    total = len(correct_mapping)
    total_correct = sum(
        1
        for letter, digit in correct_mapping.items()
        if predict_mapping.get(letter) == digit
    )
    return (total_correct / total) * 0.7 + 0.3