[fix] pre-commit fix

This commit is contained in:
theblackcat102 2025-02-18 21:48:54 +08:00
parent c612e2abc1
commit bbc31e4291
2 changed files with 26 additions and 27 deletions

View file

@ -13,7 +13,7 @@ No leading letter can be zero (unless allow_leading_zero=True).
from dataclasses import dataclass
from random import Random
from typing import Optional, Dict
from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
@ -42,7 +42,7 @@ EXAMPLE_CASE = """
* Solution: B = 7, A = 4, S = 8, E = 3, L = 5, M = 9, G = 1
* Verify: BASE (7483) + BALL (7455) = GAMES (14938).
ANSWER:
B=7, A=4, S=8, E=3, L=5, M=9, G=1"""
@ -233,25 +233,25 @@ class CryptarithmDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
correct_mapping = {}
for pair in answer_str.split(','):
alphabet, number = pair.split('=')
for pair in answer_str.split(","):
alphabet, number = pair.split("=")
correct_mapping[alphabet] = int(number)
if answer == None or 'ANSWER:' not in answer:
if answer == None or "ANSWER:" not in answer:
return 0.0
number_mapping_line = ''
if 'ANSWER:' in answer:
number_mapping_line = answer.split('ANSWER:\n')[-1]
number_mapping_line = ""
if "ANSWER:" in answer:
number_mapping_line = answer.split("ANSWER:\n")[-1]
# case 1 : pairs are in a list format and the number of pairs matched up
if len(number_mapping_line.split(',')) != len(correct_mapping):
if len(number_mapping_line.split(",")) != len(correct_mapping):
return 0.1
predict_mapping = {}
for pair in number_mapping_line.split(','):
for pair in number_mapping_line.split(","):
try:
alphabet, number = pair.strip().split('=')
alphabet, number = pair.strip().split("=")
# as the unique alphabet grows we may want this to scale linearly with the number alphabet
predict_mapping[alphabet] = int(number)
except ValueError:
@ -269,7 +269,7 @@ class CryptarithmDataset(ProceduralDataset):
total_correct += 1
# note: linear relationship is probably not good?
return (total_correct/total)*0.7 + 0.3
return (total_correct / total) * 0.7 + 0.3
register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig)