ensure reward is float

This commit is contained in:
Andreas Koepf 2025-02-16 16:27:12 +01:00
parent c858d1f236
commit 4c47b7966f
3 changed files with 5 additions and 5 deletions

View file

@ -90,7 +90,7 @@ class GroupAnagramsDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Score a single Group Anagrams question"""
reward = 0
reward = 0.0
if answer is not None:
try:
answer = json.loads(answer)
@ -98,11 +98,11 @@ class GroupAnagramsDataset(ProceduralDataset):
answer_str = json.dumps(self._sort_nested_list(answer))
oracle_str = json.dumps(self._sort_nested_list(oracle))
if answer_str == oracle_str:
reward = 1
reward = 1.0
else:
reward = 0.01
except Exception:
reward = 0
reward = 0.0
return reward
def __getitem__(self, idx: int) -> dict:

View file

@ -93,7 +93,7 @@ class SentenceReorderingDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
reward = 0
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:
try:

View file

@ -50,7 +50,7 @@ class SpellBackwardDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
reward = 0
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:
try: