mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
feat(env): Word Sequence Reversal curriculum (#313)
* word sequence reversal curriculum * metadata
This commit is contained in:
parent
1f9ef02d4f
commit
6aa7547abd
3 changed files with 69 additions and 8 deletions
|
|
@ -5,10 +5,16 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Solve the following problem. Provide you answer as a comma-separated list of words with a space after the comma. Reverse this list of words: {question}"""
|
||||
QUESTION_TEMPLATE = """Solve the following problem.
|
||||
|
||||
Provide you answer as a comma-separated list of words with a space after the comma.
|
||||
|
||||
Reverse this list of words: {words}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -42,19 +48,43 @@ class WordSequenceReversalDataset(ProceduralDataset):
|
|||
rng = Random(self.seed + idx)
|
||||
|
||||
# Select random words
|
||||
num_words = rng.randint(self.config.min_words, self.config.max_words)
|
||||
num_words = min(
|
||||
rng.randint(self.config.min_words, self.config.max_words),
|
||||
len(self.words),
|
||||
)
|
||||
word_indices = rng.sample(range(len(self.words)), num_words)
|
||||
words = [self.words[i] for i in word_indices]
|
||||
|
||||
# Create question and answer
|
||||
question = ", ".join(words)
|
||||
words_str = ", ".join(words)
|
||||
answer = ", ".join(reversed(words))
|
||||
|
||||
return {
|
||||
"question": f"{QUESTION_TEMPLATE.format(question=question)}",
|
||||
"question": f"{QUESTION_TEMPLATE.format(words=words_str)}",
|
||||
"answer": answer,
|
||||
"metadata": {"num_words": num_words, "words": words},
|
||||
"metadata": {"num_words": num_words, "words": words, "difficulty": {"words": num_words}},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("word_sequence_reversal", WordSequenceReversalDataset, WordSequenceReversalConfig)
|
||||
class WordSequenceReversalCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(WordSequenceReversalCurriculum.__name__, WordSequenceReversalConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="words",
|
||||
levels=[10, 50, 100, 500],
|
||||
default_level=1,
|
||||
description="Number of words in the list",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_words",
|
||||
upper_field_name="max_words",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"word_sequence_reversal", WordSequenceReversalDataset, WordSequenceReversalConfig, WordSequenceReversalCurriculum
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue