use get_data_file_path to read file contents

This commit is contained in:
Zafir Stojanovski 2025-02-06 10:12:51 +01:00
parent 6ec6876221
commit 256eb71555

View file

@ -12,6 +12,7 @@ from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset
MAX_ANAGRAM_GROUPS = 500
@ -55,7 +56,7 @@ class GroupAnagramsDataset(ProceduralDataset):
def __init__(self, config: GroupAnagramsConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
with open("reasoning_gym/data/anagrams.jsonl") as f:
with get_data_file_path("anagrams.jsonl").open() as f:
self.anagrams = [json.loads(line)["words"] for line in f]
def __len__(self) -> int:
@ -105,14 +106,17 @@ class GroupAnagramsDataset(ProceduralDataset):
"""Score a single Group Anagrams question"""
reward = 0
if answer is not None:
answer = json.loads(answer)
oracle = entry["metadata"]["solution"]
answer_str = json.dumps(self._sort_nested_list(answer))
oracle_str = json.dumps(self._sort_nested_list(oracle))
if answer_str == oracle_str:
reward = 1
else:
reward = 0.01
try:
answer = json.loads(answer)
oracle = entry["metadata"]["solution"]
answer_str = json.dumps(self._sort_nested_list(answer))
oracle_str = json.dumps(self._sort_nested_list(oracle))
if answer_str == oracle_str:
reward = 1
else:
reward = 0.01
except Exception:
reward = 0
return reward
def __getitem__(self, idx: int) -> dict: