diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 8011e375..973de976 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -2,7 +2,7 @@ import random from dataclasses import dataclass, field from enum import StrEnum from itertools import count -from typing import Optional +from typing import Any, Optional from ..factory import ProceduralDataset, register_dataset @@ -356,5 +356,19 @@ class FamilyRelationshipsDataset(ProceduralDataset): return " ".join(story_parts) + def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: + reward = 0.0 + if answer is not None: + try: + answer_formatted = answer.strip().lower() + solved = answer_formatted == entry["answer"].strip().lower() + if solved: + reward = 1.0 + else: + reward = 0.01 + except: + reward = 0.01 + return reward + register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig)