mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
use get_data_file_path to read file contents
This commit is contained in:
parent
6ec6876221
commit
256eb71555
1 changed files with 13 additions and 9 deletions
|
|
@ -12,6 +12,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from ..data import get_data_file_path
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
MAX_ANAGRAM_GROUPS = 500
|
MAX_ANAGRAM_GROUPS = 500
|
||||||
|
|
@ -55,7 +56,7 @@ class GroupAnagramsDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __init__(self, config: GroupAnagramsConfig):
|
def __init__(self, config: GroupAnagramsConfig):
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
with open("reasoning_gym/data/anagrams.jsonl") as f:
|
with get_data_file_path("anagrams.jsonl").open() as f:
|
||||||
self.anagrams = [json.loads(line)["words"] for line in f]
|
self.anagrams = [json.loads(line)["words"] for line in f]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
@ -105,14 +106,17 @@ class GroupAnagramsDataset(ProceduralDataset):
|
||||||
"""Score a single Group Anagrams question"""
|
"""Score a single Group Anagrams question"""
|
||||||
reward = 0
|
reward = 0
|
||||||
if answer is not None:
|
if answer is not None:
|
||||||
answer = json.loads(answer)
|
try:
|
||||||
oracle = entry["metadata"]["solution"]
|
answer = json.loads(answer)
|
||||||
answer_str = json.dumps(self._sort_nested_list(answer))
|
oracle = entry["metadata"]["solution"]
|
||||||
oracle_str = json.dumps(self._sort_nested_list(oracle))
|
answer_str = json.dumps(self._sort_nested_list(answer))
|
||||||
if answer_str == oracle_str:
|
oracle_str = json.dumps(self._sort_nested_list(oracle))
|
||||||
reward = 1
|
if answer_str == oracle_str:
|
||||||
else:
|
reward = 1
|
||||||
reward = 0.01
|
else:
|
||||||
|
reward = 0.01
|
||||||
|
except Exception:
|
||||||
|
reward = 0
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue