mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-29 17:35:16 +00:00
updated configs
This commit is contained in:
parent
cc0bacd8e1
commit
7368d6d313
8 changed files with 158 additions and 315 deletions
|
|
@@ -19,6 +19,7 @@ class SpellBackwardConfig:
|
|||
min_word_len: int = 3 # Minimum word length
|
||||
max_word_len: int = 10 # Maximum word length
|
||||
seed: Optional[int] = None
|
||||
data_file: str = "words3to10.txt"
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self) -> None:
|
||||
|
|
@@ -34,7 +35,7 @@ class SpellBackwardDataset(ProceduralDataset):
|
|||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
# Load and preprocess text
|
||||
text = read_data_file("words3to10.txt")
|
||||
text = read_data_file(self.config.data_file)
|
||||
self.words = [
|
||||
word.strip()
|
||||
for word in text.splitlines()
|
||||
|
|
@@ -73,9 +74,9 @@ class SpellBackwardDataset(ProceduralDataset):
|
|||
if expected_answer == answer:
|
||||
reward = 1.0
|
||||
else:
|
||||
answer_len = len(answer)
|
||||
answer_len = len(expected_answer)
|
||||
for i in range(len(expected_answer)):
|
||||
if (i < len(expected_answer) and i < len(answer)) and expected_answer[i] == answer[i]:
|
||||
if i < len(expected_answer) and i < len(answer):
|
||||
if expected_answer[i] == answer[i]:
|
||||
reward += 1 / answer_len
|
||||
else:
|
||||
|
|
@@ -96,7 +97,7 @@ class SpellBackwardCurriculum(BaseCurriculum):
|
|||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="word_len",
|
||||
levels=list(range(3, 11)),
|
||||
levels=list(range(3, 10, 1)),
|
||||
description="Word length",
|
||||
lower_field_name="min_word_len",
|
||||
upper_field_name="max_word_len",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue