mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fix(envs): Add source dataset and index to metadata (#388)
* add source dataset and index to metadata * fix typo * fix coach class and its test
This commit is contained in:
parent
7475a20700
commit
ce0a6c4878
104 changed files with 549 additions and 146 deletions
|
|
@ -5,6 +5,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "ab"
|
||||
|
||||
|
||||
def generate_program(length, rng):
|
||||
"""Generates a random initial program of a given length."""
|
||||
|
|
@ -116,9 +118,11 @@ Return the final state of the program.
|
|||
"question": prompt,
|
||||
"answer": " ".join(steps[-1]),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"difficulty": {
|
||||
"length": self.config.length,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -158,4 +162,4 @@ class ABCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("ab", ABDataset, ABConfig, ABCurriculum)
|
||||
register_dataset(DATASET_NAME, ABDataset, ABConfig, ABCurriculum)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ If the target base is > 10, use lowercase letters a-z for digits above 9.
|
|||
Now, convert the {source_name} number {source_repr} to {target_name}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "base_conversion"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConversionConfig:
|
||||
|
|
@ -104,6 +106,8 @@ class BaseConversionDataset(ProceduralDataset):
|
|||
),
|
||||
"answer": target_repr,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"decimal_value": value,
|
||||
"source_base": source_base,
|
||||
"target_base": target_base,
|
||||
|
|
@ -142,4 +146,4 @@ class BaseConversionCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("base_conversion", BaseConversionDataset, BaseConversionConfig, BaseConversionCurriculum)
|
||||
register_dataset(DATASET_NAME, BaseConversionDataset, BaseConversionConfig, BaseConversionCurriculum)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ Now, determine the minimum number of swaps to make the following binary string a
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "binary_alternation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryAlternationConfig:
|
||||
"""Configuration for Count Bits dataset generation"""
|
||||
|
|
@ -105,6 +108,8 @@ class BinaryAlternationDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(string=string),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"string": string,
|
||||
"solution": answer,
|
||||
"solvable": solvable,
|
||||
|
|
@ -132,4 +137,4 @@ class BinaryAlternationCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum)
|
||||
register_dataset(DATASET_NAME, BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ Find the distance to the nearest 0 for each cell in the matrix below:
|
|||
{matrix}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "binary_matrix"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryMatrixConfig:
|
||||
|
|
@ -128,6 +130,8 @@ class BinaryMatrixDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"n": n,
|
||||
|
|
@ -160,4 +164,4 @@ class BinaryMatrixCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("binary_matrix", BinaryMatrixDataset, BinaryMatrixConfig, BinaryMatrixCurriculum)
|
||||
register_dataset(DATASET_NAME, BinaryMatrixDataset, BinaryMatrixConfig, BinaryMatrixCurriculum)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
|||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "caesar_cipher"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaesarCipherConfig:
|
||||
|
|
@ -77,6 +79,8 @@ class CaesarCipherDataset(ProceduralDataset):
|
|||
"question": f"Decrypt this Caesar cipher text: {cipher_text}. Provide only the decrypted text as your final answer.",
|
||||
"answer": sentence,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"rotation": rotation,
|
||||
"cipher_text": cipher_text,
|
||||
"clear_text": sentence,
|
||||
|
|
@ -113,4 +117,4 @@ class CaesarCipherCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("caesar_cipher", CaesarCipherDataset, CaesarCipherConfig, CaesarCipherCurriculum)
|
||||
register_dataset(DATASET_NAME, CaesarCipherDataset, CaesarCipherConfig, CaesarCipherCurriculum)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
|
||||
QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?"""
|
||||
|
||||
DATASET_NAME = "count_primes"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountPrimesConfig:
|
||||
|
|
@ -60,6 +62,8 @@ class CountPrimesDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(start=start, end=end),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"primes": primes,
|
||||
|
|
@ -88,4 +92,4 @@ class CountPrimesCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig, CountPrimesCurriculum)
|
||||
register_dataset(DATASET_NAME, CountPrimesDataset, CountPrimesConfig, CountPrimesCurriculum)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "cryptarithm"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CryptarithmConfig:
|
||||
|
|
@ -51,9 +53,9 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
rng = Random(self.seed + idx)
|
||||
return self._create_single_puzzle(rng)
|
||||
return self._create_single_puzzle(rng, idx)
|
||||
|
||||
def _create_single_puzzle(self, rng: Random) -> dict:
|
||||
def _create_single_puzzle(self, rng: Random, idx: int) -> dict:
|
||||
"""
|
||||
Creates one puzzle with N addends (2..3) plus a result.
|
||||
Ensures total distinct digits <= 10.
|
||||
|
|
@ -179,6 +181,8 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
"question": question_str,
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"letters": list(letter_to_digit.keys()),
|
||||
"word_values": words_numbers,
|
||||
"sum_number": total_sum,
|
||||
|
|
@ -260,4 +264,4 @@ class CryptarithmCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig, CryptarithmCurriculum)
|
||||
register_dataset(DATASET_NAME, CryptarithmDataset, CryptarithmConfig, CryptarithmCurriculum)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import cellpylib as cpl
|
|||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "game_of_life"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GameOfLifeConfig:
|
||||
|
|
@ -81,6 +83,8 @@ class GameOfLifeDataset(ProceduralDataset):
|
|||
),
|
||||
"answer": result_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"grid_size_x": self.config.grid_size_x,
|
||||
"grid_size_y": self.config.grid_size_y,
|
||||
"filled_cells": self.config.filled_cells,
|
||||
|
|
@ -187,4 +191,4 @@ class GameOfLifeCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig, GameOfLifeCurriculum)
|
||||
register_dataset(DATASET_NAME, GameOfLifeDataset, GameOfLifeConfig, GameOfLifeCurriculum)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import cellpylib as cpl
|
|||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "game_of_life_halting"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GameOfLifeHaltingConfig:
|
||||
|
|
@ -363,6 +365,8 @@ class GameOfLifeHaltingDataset(ProceduralDataset):
|
|||
"question": question,
|
||||
"answer": str(not should_oscillate),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"grid_size_x": grid_x,
|
||||
"grid_size_y": grid_y,
|
||||
"placed_patterns": placed_patterns,
|
||||
|
|
@ -438,4 +442,4 @@ class GameOfLifeHaltingCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("game_of_life_halting", GameOfLifeHaltingDataset, GameOfLifeHaltingConfig, GameOfLifeHaltingCurriculum)
|
||||
register_dataset(DATASET_NAME, GameOfLifeHaltingDataset, GameOfLifeHaltingConfig, GameOfLifeHaltingCurriculum)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "graph_color"
|
||||
|
||||
|
||||
def generate_random_graph(rng, num_vertices, edge_probability=0.3):
|
||||
"""
|
||||
|
|
@ -213,6 +215,8 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
|
|||
"question": question,
|
||||
"answer": None,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"possible_answer": solution,
|
||||
"puzzle": puzzle,
|
||||
"num_vertices": num_vertices,
|
||||
|
|
@ -272,4 +276,4 @@ class GraphColorCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("graph_color", GraphColorDataset, GraphColorConfig, GraphColorCurriculum)
|
||||
register_dataset(DATASET_NAME, GraphColorDataset, GraphColorConfig, GraphColorCurriculum)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ Group the following list of words into anagrams:
|
|||
{words}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "group_anagrams"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupAnagramsConfig:
|
||||
|
|
@ -115,6 +117,8 @@ class GroupAnagramsDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(words=json.dumps(words)),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"words": words,
|
||||
"solution": answer,
|
||||
"anagram_groups": anagram_groups,
|
||||
|
|
@ -149,4 +153,4 @@ class GroupAnagramsCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("group_anagrams", GroupAnagramsDataset, GroupAnagramsConfig, GroupAnagramsCurriculum)
|
||||
register_dataset(DATASET_NAME, GroupAnagramsDataset, GroupAnagramsConfig, GroupAnagramsCurriculum)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ Return True if the following two strings are isomorphic, or False otherwise:
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "isomorphic_strings"
|
||||
|
||||
|
||||
@dataclass
|
||||
class IsomorphicStringsConfig:
|
||||
"""Configuration for Isomorphic Strings dataset generation"""
|
||||
|
|
@ -107,6 +110,8 @@ class IsomorphicStringsDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(s=s, t=t),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"words": [s, t],
|
||||
"solution": answer,
|
||||
"solvable": solvable,
|
||||
|
|
@ -134,4 +139,4 @@ class IsomorphicStringsCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("isomorphic_strings", IsomorphicStringsDataset, IsomorphicStringsConfig, IsomorphicStringsCurriculum)
|
||||
register_dataset(DATASET_NAME, IsomorphicStringsDataset, IsomorphicStringsConfig, IsomorphicStringsCurriculum)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "jugs"
|
||||
|
||||
|
||||
def min_moves_n(jug_capacities: list[int], target: int) -> Optional[int]:
|
||||
"""
|
||||
|
|
@ -282,6 +284,8 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
|
|||
"question": question,
|
||||
"answer": json.dumps(solution), # one possible solution
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"puzzle": puzzle,
|
||||
"difficulty": {
|
||||
"num_jugs": self.config.num_jugs,
|
||||
|
|
@ -340,4 +344,4 @@ class JugsCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("jugs", JugsDataset, JugsConfig, JugsCurriculum)
|
||||
register_dataset(DATASET_NAME, JugsDataset, JugsConfig, JugsCurriculum)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from reasoning_gym.data import read_data_file
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "letter_counting"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LetterCountingConfig:
|
||||
|
|
@ -64,6 +66,8 @@ class LetterCountingDataset(ProceduralDataset):
|
|||
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
|
||||
"answer": str(count),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"span_length": span_length,
|
||||
"target_letter": target_letter,
|
||||
"span": span,
|
||||
|
|
@ -91,4 +95,4 @@ class LetterCountingCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig, LetterCountingCurriculum)
|
||||
register_dataset(DATASET_NAME, LetterCountingDataset, LetterCountingConfig, LetterCountingCurriculum)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ Now, unscramble these words: {words}
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "letter_jumble"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LetterJumbleConfig:
|
||||
"""Configuration for letter jumbling task generation"""
|
||||
|
|
@ -104,6 +107,8 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(words=" ".join(scrambled_words)),
|
||||
"answer": " ".join(selected_words),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"num_words": num_words,
|
||||
"corruption_level": corruption_level,
|
||||
"scrambled_words": scrambled_words,
|
||||
|
|
@ -193,4 +198,4 @@ class LetterJumbleCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig, LetterJumbleCurriculum)
|
||||
register_dataset(DATASET_NAME, LetterJumbleDataset, LetterJumbleConfig, LetterJumbleCurriculum)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ Perform the following series of operations in order:
|
|||
{operations}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "manipulate_matrix"
|
||||
|
||||
|
||||
def num_rows(matrix: list[list[int]]) -> int:
|
||||
return len(matrix)
|
||||
|
|
@ -306,6 +308,8 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"operations": operations,
|
||||
|
|
@ -351,4 +355,4 @@ class ManipulateMatrixCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("manipulate_matrix", ManipulateMatrixDataset, ManipulateMatrixConfig, ManipulateMatrixCurriculum)
|
||||
register_dataset(DATASET_NAME, ManipulateMatrixDataset, ManipulateMatrixConfig, ManipulateMatrixCurriculum)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from typing import Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "number_filtering"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NumberFilteringConfig:
|
||||
|
|
@ -91,6 +93,8 @@ class NumberFilteringDataset(ProceduralDataset):
|
|||
),
|
||||
"answer": str(result_strs) if result_strs else "[]",
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"original_numbers": str_numbers,
|
||||
"filter_value": filter_str,
|
||||
"operation": f"{keep_remove}_{larger_smaller}",
|
||||
|
|
@ -138,4 +142,4 @@ class NumberFilteringCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum)
|
||||
register_dataset(DATASET_NAME, NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "number_sorting"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NumberSortingConfig:
|
||||
|
|
@ -90,6 +92,8 @@ Please follow the instruction below:
|
|||
"question": question,
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"original_numbers": number_strs,
|
||||
"direction": direction,
|
||||
"sorted_numbers": answer,
|
||||
|
|
@ -198,4 +202,4 @@ class NumberSortingCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum)
|
||||
register_dataset(DATASET_NAME, NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ Now, form a valid palindrome using the following letters: {letters}
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "palindrome_generation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PalindromeConfig:
|
||||
"""
|
||||
|
|
@ -67,6 +70,8 @@ class PalindromeDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPALTE.format(letters=", ".join(scrambled_letters)),
|
||||
"answer": palindrome,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"letters": scrambled_letters,
|
||||
"generated_palindrome": palindrome,
|
||||
"length": length,
|
||||
|
|
@ -138,4 +143,4 @@ class PalindromeCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("palindrome_generation", PalindromeDataset, PalindromeConfig, PalindromeCurriculum)
|
||||
register_dataset(DATASET_NAME, PalindromeDataset, PalindromeConfig, PalindromeCurriculum)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ Your output should be a list of lists, where each list represents a palindrome p
|
|||
Partition the following string into palindromes: {string}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "palindrome_partitioning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PalindromePartitioningConfig:
|
||||
|
|
@ -138,6 +140,8 @@ class PalindromePartitioningDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(string=string),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"string": string,
|
||||
"solution": answer,
|
||||
"string_len": string_len,
|
||||
|
|
@ -176,7 +180,7 @@ class PalindromePartitioningCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
register_dataset(
|
||||
"palindrome_partitioning",
|
||||
DATASET_NAME,
|
||||
PalindromePartitioningDataset,
|
||||
PalindromePartitioningConfig,
|
||||
PalindromePartitioningCurriculum,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ Perform {pool_type} pooling on the following matrix with a kernel size of {pool_
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "pool_matrix"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoolMatrixConfig:
|
||||
"""Configuration for Pool Matrix dataset generation"""
|
||||
|
|
@ -113,6 +116,8 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, pool_type=pool_type, pool_size=pool_size),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"matrix": matrix.tolist(),
|
||||
"pool_type": pool_type,
|
||||
"pool_size": pool_size,
|
||||
|
|
@ -158,4 +163,4 @@ class PoolMatrixCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("pool_matrix", PoolMatrixDataset, PoolMatrixConfig, PoolMatrixCurriculum)
|
||||
register_dataset(DATASET_NAME, PoolMatrixDataset, PoolMatrixConfig, PoolMatrixCurriculum)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ Ransom note: {ransom_note}
|
|||
Magazine: {magazine}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "ransom_note"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RansomNoteConfig:
|
||||
|
|
@ -99,6 +101,8 @@ class RansomNoteDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(ransom_note=ransom_note, magazine=magazine),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"ransom_note": ransom_note,
|
||||
"magazine": magazine,
|
||||
"solution": answer,
|
||||
|
|
@ -136,4 +140,4 @@ class RansomNoteCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig, RansomNoteCurriculum)
|
||||
register_dataset(DATASET_NAME, RansomNoteDataset, RansomNoteConfig, RansomNoteCurriculum)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ Rotate the matrix below by {degrees} degrees clockwise:
|
|||
{matrix}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "rotate_matrix"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RotateMatrixConfig:
|
||||
|
|
@ -83,6 +85,8 @@ class RotateMatrixDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"matrix": matrix,
|
||||
"num_rotations": num_rotations,
|
||||
"solution": answer,
|
||||
|
|
@ -118,4 +122,4 @@ class RotateMatrixCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig, RotateMatrixCurriculum)
|
||||
register_dataset(DATASET_NAME, RotateMatrixDataset, RotateMatrixConfig, RotateMatrixCurriculum)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ Now, determine the minimum number of minutes that must elapse until no cell in t
|
|||
{matrix}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "rotten_oranges"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RottenOrangesConfig:
|
||||
|
|
@ -120,6 +122,8 @@ class RottenOrangesDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"n": n,
|
||||
|
|
@ -146,4 +150,4 @@ class RottenOrangesCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("rotten_oranges", RottenOrangesDataset, RottenOrangesConfig, RottenOrangesCurriculum)
|
||||
register_dataset(DATASET_NAME, RottenOrangesDataset, RottenOrangesConfig, RottenOrangesCurriculum)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
|||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "sentence_reordering"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SentenceReorderingConfig:
|
||||
|
|
@ -91,6 +93,8 @@ class SentenceReorderingDataset(ProceduralDataset):
|
|||
"question": f"Restore the correct order of words in the following sentence: {question}",
|
||||
"answer": solved_sentence,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"word_count": word_count,
|
||||
"difficulty": {
|
||||
"words_in_sentence": (self.config.min_words_in_sentence, self.config.max_words_in_sentence),
|
||||
|
|
@ -137,6 +141,4 @@ class SentenceReorderingCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig, SentenceReorderingCurriculum
|
||||
)
|
||||
register_dataset(DATASET_NAME, SentenceReorderingDataset, SentenceReorderingConfig, SentenceReorderingCurriculum)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
|||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "spell_backward"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpellBackwardConfig:
|
||||
|
|
@ -52,6 +54,8 @@ class SpellBackwardDataset(ProceduralDataset):
|
|||
"question": f"Spell this word backward (example: sun -> nus): {word}",
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"word": word,
|
||||
"word_len": len(word),
|
||||
"difficulty": {
|
||||
|
|
@ -91,4 +95,4 @@ class SpellBackwardCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum)
|
||||
register_dataset(DATASET_NAME, SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ For the matrix below, what is the list of elements in spiral order?
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "spiral_matrix"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpiralMatrixConfig:
|
||||
"""Configuration for Spiral Matrix dataset generation"""
|
||||
|
|
@ -111,6 +114,8 @@ class SpiralMatrixDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"n": n,
|
||||
|
|
@ -158,4 +163,4 @@ class SpiralMatrixCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig, SpiralMatrixCurriculum)
|
||||
register_dataset(DATASET_NAME, SpiralMatrixDataset, SpiralMatrixConfig, SpiralMatrixCurriculum)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ Given the following string, provide the answer after inserting the characters ac
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "string_insertion"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringInsertionConfig:
|
||||
"""Configuration for String Insertion dataset generation"""
|
||||
|
|
@ -102,6 +105,8 @@ class StringInsertionDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(string=string),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"string": string,
|
||||
"solution": answer,
|
||||
"string_length": string_length,
|
||||
|
|
@ -129,4 +134,4 @@ class StringInsertionCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("string_insertion", StringInsertionDataset, StringInsertionConfig, StringInsertionCurriculum)
|
||||
register_dataset(DATASET_NAME, StringInsertionDataset, StringInsertionConfig, StringInsertionCurriculum)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ Transform the following string according to the above list of rules:
|
|||
{string}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "string_manipulation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringManipulationConfig:
|
||||
|
|
@ -179,6 +181,8 @@ class StringManipulationDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(string=string, rules=rules_str),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"string": string,
|
||||
"solution": answer,
|
||||
"states": states,
|
||||
|
|
@ -216,6 +220,4 @@ class StringManipulationCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"string_manipulation", StringManipulationDataset, StringManipulationConfig, StringManipulationCurriculum
|
||||
)
|
||||
register_dataset(DATASET_NAME, StringManipulationDataset, StringManipulationConfig, StringManipulationCurriculum)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ Now, you have {A_machine} machine A, {B_machine} machine B, and {C_machine} mach
|
|||
Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts of each machine and part type.
|
||||
"""
|
||||
|
||||
DATASET_NAME = "string_splitting"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringSplittingConfig:
|
||||
|
|
@ -125,6 +127,8 @@ class StringSplittingDataset(ProceduralDataset):
|
|||
),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"states": states,
|
||||
"solution": answer,
|
||||
"initial_machines": (A_machine, B_machine, C_machine),
|
||||
|
|
@ -152,4 +156,4 @@ class StringSplittingCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum)
|
||||
register_dataset(DATASET_NAME, StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,9 @@ Note: Apply the rules at most {max_iterations} times. If the rules cannot be app
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "string_synthesis"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringSynthesisConfig:
|
||||
"""Configuration for String Synthesis dataset generation"""
|
||||
|
|
@ -130,6 +133,8 @@ class StringSynthesisDataset(ProceduralDataset):
|
|||
),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"states": states,
|
||||
"solution": answer,
|
||||
"initial_blocks": (A_square, B_square, C_square),
|
||||
|
|
@ -157,4 +162,4 @@ class StringSynthesisCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum)
|
||||
register_dataset(DATASET_NAME, StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ Provide your answer as a comma-separated sequence of uppercase letters without s
|
|||
Each step must be a valid English word."""
|
||||
|
||||
|
||||
DATASET_NAME = "word_ladder"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordLadderConfig:
|
||||
"""Configuration for word ladder task generation"""
|
||||
|
|
@ -219,6 +222,8 @@ class WordLadderDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(start=start, end=end),
|
||||
"answer": ",".join(path),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"start_word": start,
|
||||
"end_word": end,
|
||||
"word_length": length,
|
||||
|
|
@ -285,4 +290,4 @@ class WordLadderCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig, WordLadderCurriculum)
|
||||
register_dataset(DATASET_NAME, WordLadderDataset, WordLadderConfig, WordLadderCurriculum)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ Reverse this list of words: {words}
|
|||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "word_sequence_reversal"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordSequenceReversalConfig:
|
||||
"""Configuration for word sequence reversal task generation"""
|
||||
|
|
@ -63,6 +66,8 @@ class WordSequenceReversalDataset(ProceduralDataset):
|
|||
"question": f"{QUESTION_TEMPLATE.format(words=words_str)}",
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"num_words": num_words,
|
||||
"words": words,
|
||||
"difficulty": {
|
||||
|
|
@ -89,6 +94,4 @@ class WordSequenceReversalCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"word_sequence_reversal", WordSequenceReversalDataset, WordSequenceReversalConfig, WordSequenceReversalCurriculum
|
||||
)
|
||||
register_dataset(DATASET_NAME, WordSequenceReversalDataset, WordSequenceReversalConfig, WordSequenceReversalCurriculum)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ Your output should be a comma-separated list of words, e.g. word_1, word_2, word
|
|||
Now, sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list: {words}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "word_sorting"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordSortingConfig:
|
||||
|
|
@ -106,6 +108,8 @@ class WordSortingDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
|
||||
"answer": ", ".join(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"original_words": original_words,
|
||||
"sorted_words": answer,
|
||||
"transformed_words": transformed_words,
|
||||
|
|
@ -153,4 +157,4 @@ class WordSortingCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("word_sorting", WordSortingDataset, WordSortingConfig)
|
||||
register_dataset(DATASET_NAME, WordSortingDataset, WordSortingConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue