ensure reward is float

This commit is contained in:
Andreas Koepf 2025-02-16 16:27:12 +01:00
parent c858d1f236
commit 4c47b7966f
3 changed files with 5 additions and 5 deletions

View file

@@ -90,7 +90,7 @@ class GroupAnagramsDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Score a single Group Anagrams question""" """Score a single Group Anagrams question"""
reward = 0 reward = 0.0
if answer is not None: if answer is not None:
try: try:
answer = json.loads(answer) answer = json.loads(answer)
@@ -98,11 +98,11 @@ class GroupAnagramsDataset(ProceduralDataset):
answer_str = json.dumps(self._sort_nested_list(answer)) answer_str = json.dumps(self._sort_nested_list(answer))
oracle_str = json.dumps(self._sort_nested_list(oracle)) oracle_str = json.dumps(self._sort_nested_list(oracle))
if answer_str == oracle_str: if answer_str == oracle_str:
reward = 1 reward = 1.0
else: else:
reward = 0.01 reward = 0.01
except Exception: except Exception:
reward = 0 reward = 0.0
return reward return reward
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:

View file

@@ -93,7 +93,7 @@ class SentenceReorderingDataset(ProceduralDataset):
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
reward = 0 reward = 0.0
expected_answer = entry["answer"] expected_answer = entry["answer"]
if answer is not None: if answer is not None:
try: try:

View file

@@ -50,7 +50,7 @@ class SpellBackwardDataset(ProceduralDataset):
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
reward = 0 reward = 0.0
expected_answer = entry["answer"] expected_answer = entry["answer"]
if answer is not None: if answer is not None:
try: try: