mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
use get_data_file_path to read file contents
This commit is contained in:
parent
6ec6876221
commit
256eb71555
1 changed files with 13 additions and 9 deletions
|
|
@ -12,6 +12,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..data import get_data_file_path
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
MAX_ANAGRAM_GROUPS = 500
|
||||
|
|
@ -55,7 +56,7 @@ class GroupAnagramsDataset(ProceduralDataset):
|
|||
|
||||
def __init__(self, config: GroupAnagramsConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
with open("reasoning_gym/data/anagrams.jsonl") as f:
|
||||
with get_data_file_path("anagrams.jsonl").open() as f:
|
||||
self.anagrams = [json.loads(line)["words"] for line in f]
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
|
@ -105,14 +106,17 @@ class GroupAnagramsDataset(ProceduralDataset):
|
|||
"""Score a single Group Anagrams question"""
|
||||
reward = 0
|
||||
if answer is not None:
|
||||
answer = json.loads(answer)
|
||||
oracle = entry["metadata"]["solution"]
|
||||
answer_str = json.dumps(self._sort_nested_list(answer))
|
||||
oracle_str = json.dumps(self._sort_nested_list(oracle))
|
||||
if answer_str == oracle_str:
|
||||
reward = 1
|
||||
else:
|
||||
reward = 0.01
|
||||
try:
|
||||
answer = json.loads(answer)
|
||||
oracle = entry["metadata"]["solution"]
|
||||
answer_str = json.dumps(self._sort_nested_list(answer))
|
||||
oracle_str = json.dumps(self._sort_nested_list(oracle))
|
||||
if answer_str == oracle_str:
|
||||
reward = 1
|
||||
else:
|
||||
reward = 0.01
|
||||
except Exception:
|
||||
reward = 0
|
||||
return reward
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue