diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index bc930848..4eccc9e4 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -13,7 +13,7 @@ No leading letter can be zero (unless allow_leading_zero=True). from dataclasses import dataclass from random import Random -from typing import Optional, Dict +from typing import Dict, Optional from ..factory import ProceduralDataset, register_dataset @@ -42,7 +42,7 @@ EXAMPLE_CASE = """ * Solution: B = 7, A = 4, S = 8, E = 3, L = 5, M = 9, G = 1 * Verify: BASE (7483) + BALL (7455) = GAMES (14938). - + ANSWER: B=7, A=4, S=8, E=3, L=5, M=9, G=1""" @@ -233,25 +233,25 @@ class CryptarithmDataset(ProceduralDataset): float: The computed score between 0.0 and 1.0. """ correct_mapping = {} - for pair in answer_str.split(','): - alphabet, number = pair.split('=') + for pair in answer_str.split(","): + alphabet, number = pair.split("=") correct_mapping[alphabet] = int(number) - if answer == None or 'ANSWER:' not in answer: + if answer == None or "ANSWER:" not in answer: return 0.0 - number_mapping_line = '' - if 'ANSWER:' in answer: - number_mapping_line = answer.split('ANSWER:\n')[-1] - + number_mapping_line = "" + if "ANSWER:" in answer: + number_mapping_line = answer.split("ANSWER:\n")[-1] + # case 1 : pairs are in a list format and the number of pairs matched up - if len(number_mapping_line.split(',')) != len(correct_mapping): + if len(number_mapping_line.split(",")) != len(correct_mapping): return 0.1 predict_mapping = {} - for pair in number_mapping_line.split(','): + for pair in number_mapping_line.split(","): try: - alphabet, number = pair.strip().split('=') + alphabet, number = pair.strip().split("=") # as the unique alphabet grows we may want this to scale linearly with the number alphabet predict_mapping[alphabet] = int(number) except ValueError: @@ -269,7 +269,7 @@ class CryptarithmDataset(ProceduralDataset): total_correct += 1 # note: linear relationship is probably not good? - return (total_correct/total)*0.7 + 0.3 + return (total_correct / total) * 0.7 + 0.3 register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig) diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index e0ae9c85..942f180e 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -104,6 +104,7 @@ def test_max_letters_constraint(): # Check total unique letters doesn't exceed 10 (digits 0-9) assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10" + def test_cryptarithm_score_answer(): """Test the CryptarithmDataset.score_answer method for various correctness levels.""" dataset = create_dataset("cryptarithm", seed=42, size=1) @@ -121,35 +122,35 @@ def test_cryptarithm_score_answer(): # 3) Mismatch number of pairs => score should be 0.1 # For instance, drop the last pair - splitted = correct_answer_str.split(',') - mismatch_str = ','.join(splitted[:-1]) + splitted = correct_answer_str.split(",") + mismatch_str = ",".join(splitted[:-1]) user_answer = f"ANSWER:\n{mismatch_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}" # 4) Parse error => 0.15 (e.g. remove '=' from the first pair) - splitted = correct_answer_str.split(',') - splitted[0] = splitted[0].replace('=', '') # remove '=' in the first pair - parse_error_str = ','.join(splitted) + splitted = correct_answer_str.split(",") + splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair + parse_error_str = ",".join(splitted) user_answer = f"ANSWER:\n{parse_error_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}" # 5) Correct number of pairs, but duplicate alphabets => 0.3 # This makes the dictionary have fewer unique keys than expected - splitted = correct_answer_str.split(',') + splitted = correct_answer_str.split(",") if len(splitted) > 1: splitted[0] = splitted[1] # Duplicate the second pair in the first position - duplicates_str = ','.join(splitted) + duplicates_str = ",".join(splitted) user_answer = f"ANSWER:\n{duplicates_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}" # 6) Partial correctness => some correct, some incorrect - splitted = correct_answer_str.split(',') + splitted = correct_answer_str.split(",") correct_mapping = {} for pair in splitted: - alpha, num_str = pair.split('=') + alpha, num_str = pair.split("=") correct_mapping[alpha] = int(num_str) # Make exactly half of them correct, half incorrect @@ -159,17 +160,15 @@ def test_cryptarithm_score_answer(): i = 0 for alpha, num in correct_mapping.items(): if i < half: - new_pairs.append(f"{alpha}={num}") # keep correct + new_pairs.append(f"{alpha}={num}") # keep correct else: new_pairs.append(f"{alpha}={(num+1) % 10}") # make incorrect i += 1 - partial_answer_str = ','.join(new_pairs) + partial_answer_str = ",".join(new_pairs) user_answer = f"ANSWER:\n{partial_answer_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) # The formula is (num_correct / total) * 0.7 + 0.3 expected_score = (half / total) * 0.7 + 0.3 - assert abs(score - expected_score) < 1e-9, ( - f"Partial correctness: expected {expected_score}, got {score}" - ) \ No newline at end of file + assert abs(score - expected_score) < 1e-9, f"Partial correctness: expected {expected_score}, got {score}"