From 256eb7155518b72c8a67bc5130fd86431b5aa575 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Thu, 6 Feb 2025 10:12:51 +0100 Subject: [PATCH] use `get_data_file_path` to read file contents --- reasoning_gym/algorithmic/group_anagrams.py | 22 ++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/reasoning_gym/algorithmic/group_anagrams.py b/reasoning_gym/algorithmic/group_anagrams.py index d76fb186..dba80777 100644 --- a/reasoning_gym/algorithmic/group_anagrams.py +++ b/reasoning_gym/algorithmic/group_anagrams.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from random import Random from typing import Dict, Optional +from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset MAX_ANAGRAM_GROUPS = 500 @@ -55,7 +56,7 @@ class GroupAnagramsDataset(ProceduralDataset): def __init__(self, config: GroupAnagramsConfig): super().__init__(config=config, seed=config.seed, size=config.size) - with open("reasoning_gym/data/anagrams.jsonl") as f: + with get_data_file_path("anagrams.jsonl").open() as f: self.anagrams = [json.loads(line)["words"] for line in f] def __len__(self) -> int: @@ -105,14 +106,17 @@ class GroupAnagramsDataset(ProceduralDataset): """Score a single Group Anagrams question""" reward = 0 if answer is not None: - answer = json.loads(answer) - oracle = entry["metadata"]["solution"] - answer_str = json.dumps(self._sort_nested_list(answer)) - oracle_str = json.dumps(self._sort_nested_list(oracle)) - if answer_str == oracle_str: - reward = 1 - else: - reward = 0.01 + try: + answer = json.loads(answer) + oracle = entry["metadata"]["solution"] + answer_str = json.dumps(self._sort_nested_list(answer)) + oracle_str = json.dumps(self._sort_nested_list(oracle)) + if answer_str == oracle_str: + reward = 1 + else: + reward = 0.01 + except Exception: + reward = 0 return reward def __getitem__(self, idx: int) -> dict: