fix(envs): Add source dataset and index to metadata (#388)

* add source dataset and index to metadata

* fix typo

* fix coach class and its test
This commit is contained in:
Zafir Stojanovski 2025-03-20 12:12:14 +01:00 committed by GitHub
parent 7475a20700
commit ce0a6c4878
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
104 changed files with 549 additions and 146 deletions

View file

@ -5,6 +5,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "ab"
def generate_program(length, rng):
"""Generates a random initial program of a given length."""
@ -116,9 +118,11 @@ Return the final state of the program.
"question": prompt,
"answer": " ".join(steps[-1]),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"difficulty": {
"length": self.config.length,
}
},
},
}
@ -158,4 +162,4 @@ class ABCurriculum(BaseCurriculum):
# Register the dataset
register_dataset("ab", ABDataset, ABConfig, ABCurriculum)
register_dataset(DATASET_NAME, ABDataset, ABConfig, ABCurriculum)

View file

@ -14,6 +14,8 @@ If the target base is > 10, use lowercase letters a-z for digits above 9.
Now, convert the {source_name} number {source_repr} to {target_name}
"""
DATASET_NAME = "base_conversion"
@dataclass
class BaseConversionConfig:
@ -104,6 +106,8 @@ class BaseConversionDataset(ProceduralDataset):
),
"answer": target_repr,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"decimal_value": value,
"source_base": source_base,
"target_base": target_base,
@ -142,4 +146,4 @@ class BaseConversionCurriculum(BaseCurriculum):
)
register_dataset("base_conversion", BaseConversionDataset, BaseConversionConfig, BaseConversionCurriculum)
register_dataset(DATASET_NAME, BaseConversionDataset, BaseConversionConfig, BaseConversionCurriculum)

View file

@ -20,6 +20,9 @@ Now, determine the minimum number of swaps to make the following binary string a
"""
DATASET_NAME = "binary_alternation"
@dataclass
class BinaryAlternationConfig:
"""Configuration for Count Bits dataset generation"""
@ -105,6 +108,8 @@ class BinaryAlternationDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(string=string),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"string": string,
"solution": answer,
"solvable": solvable,
@ -132,4 +137,4 @@ class BinaryAlternationCurriculum(BaseCurriculum):
)
register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum)
register_dataset(DATASET_NAME, BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum)

View file

@ -20,6 +20,8 @@ Find the distance to the nearest 0 for each cell in the matrix below:
{matrix}
"""
DATASET_NAME = "binary_matrix"
@dataclass
class BinaryMatrixConfig:
@ -128,6 +130,8 @@ class BinaryMatrixDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"matrix": matrix,
"solution": answer,
"n": n,
@ -160,4 +164,4 @@ class BinaryMatrixCurriculum(BaseCurriculum):
)
register_dataset("binary_matrix", BinaryMatrixDataset, BinaryMatrixConfig, BinaryMatrixCurriculum)
register_dataset(DATASET_NAME, BinaryMatrixDataset, BinaryMatrixConfig, BinaryMatrixCurriculum)

View file

@ -8,6 +8,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "caesar_cipher"
@dataclass
class CaesarCipherConfig:
@ -77,6 +79,8 @@ class CaesarCipherDataset(ProceduralDataset):
"question": f"Decrypt this Caesar cipher text: {cipher_text}. Provide only the decrypted text as your final answer.",
"answer": sentence,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"rotation": rotation,
"cipher_text": cipher_text,
"clear_text": sentence,
@ -113,4 +117,4 @@ class CaesarCipherCurriculum(BaseCurriculum):
)
register_dataset("caesar_cipher", CaesarCipherDataset, CaesarCipherConfig, CaesarCipherCurriculum)
register_dataset(DATASET_NAME, CaesarCipherDataset, CaesarCipherConfig, CaesarCipherCurriculum)

View file

@ -14,6 +14,8 @@ from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?"""
DATASET_NAME = "count_primes"
@dataclass
class CountPrimesConfig:
@ -60,6 +62,8 @@ class CountPrimesDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(start=start, end=end),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"start": start,
"end": end,
"primes": primes,
@ -88,4 +92,4 @@ class CountPrimesCurriculum(BaseCurriculum):
)
register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig, CountPrimesCurriculum)
register_dataset(DATASET_NAME, CountPrimesDataset, CountPrimesConfig, CountPrimesCurriculum)

View file

@ -18,6 +18,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "cryptarithm"
@dataclass
class CryptarithmConfig:
@ -51,9 +53,9 @@ class CryptarithmDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict:
rng = Random(self.seed + idx)
return self._create_single_puzzle(rng)
return self._create_single_puzzle(rng, idx)
def _create_single_puzzle(self, rng: Random) -> dict:
def _create_single_puzzle(self, rng: Random, idx: int) -> dict:
"""
Creates one puzzle with N addends (2..3) plus a result.
Ensures total distinct digits <= 10.
@ -179,6 +181,8 @@ class CryptarithmDataset(ProceduralDataset):
"question": question_str,
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"letters": list(letter_to_digit.keys()),
"word_values": words_numbers,
"sum_number": total_sum,
@ -260,4 +264,4 @@ class CryptarithmCurriculum(BaseCurriculum):
)
register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig, CryptarithmCurriculum)
register_dataset(DATASET_NAME, CryptarithmDataset, CryptarithmConfig, CryptarithmCurriculum)

View file

@ -8,6 +8,8 @@ import cellpylib as cpl
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "game_of_life"
@dataclass
class GameOfLifeConfig:
@ -81,6 +83,8 @@ class GameOfLifeDataset(ProceduralDataset):
),
"answer": result_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"grid_size_x": self.config.grid_size_x,
"grid_size_y": self.config.grid_size_y,
"filled_cells": self.config.filled_cells,
@ -187,4 +191,4 @@ class GameOfLifeCurriculum(BaseCurriculum):
)
register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig, GameOfLifeCurriculum)
register_dataset(DATASET_NAME, GameOfLifeDataset, GameOfLifeConfig, GameOfLifeCurriculum)

View file

@ -7,6 +7,8 @@ import cellpylib as cpl
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "game_of_life_halting"
@dataclass
class GameOfLifeHaltingConfig:
@ -363,6 +365,8 @@ class GameOfLifeHaltingDataset(ProceduralDataset):
"question": question,
"answer": str(not should_oscillate),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"grid_size_x": grid_x,
"grid_size_y": grid_y,
"placed_patterns": placed_patterns,
@ -438,4 +442,4 @@ class GameOfLifeHaltingCurriculum(BaseCurriculum):
)
register_dataset("game_of_life_halting", GameOfLifeHaltingDataset, GameOfLifeHaltingConfig, GameOfLifeHaltingCurriculum)
register_dataset(DATASET_NAME, GameOfLifeHaltingDataset, GameOfLifeHaltingConfig, GameOfLifeHaltingCurriculum)

View file

@ -6,6 +6,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "graph_color"
def generate_random_graph(rng, num_vertices, edge_probability=0.3):
"""
@ -213,6 +215,8 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
"question": question,
"answer": None,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"possible_answer": solution,
"puzzle": puzzle,
"num_vertices": num_vertices,
@ -272,4 +276,4 @@ class GraphColorCurriculum(BaseCurriculum):
)
register_dataset("graph_color", GraphColorDataset, GraphColorConfig, GraphColorCurriculum)
register_dataset(DATASET_NAME, GraphColorDataset, GraphColorConfig, GraphColorCurriculum)

View file

@ -26,6 +26,8 @@ Group the following list of words into anagrams:
{words}
"""
DATASET_NAME = "group_anagrams"
@dataclass
class GroupAnagramsConfig:
@ -115,6 +117,8 @@ class GroupAnagramsDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(words=json.dumps(words)),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"words": words,
"solution": answer,
"anagram_groups": anagram_groups,
@ -149,4 +153,4 @@ class GroupAnagramsCurriculum(BaseCurriculum):
)
register_dataset("group_anagrams", GroupAnagramsDataset, GroupAnagramsConfig, GroupAnagramsCurriculum)
register_dataset(DATASET_NAME, GroupAnagramsDataset, GroupAnagramsConfig, GroupAnagramsCurriculum)

View file

@ -24,6 +24,9 @@ Return True if the following two strings are isomorphic, or False otherwise:
"""
DATASET_NAME = "isomorphic_strings"
@dataclass
class IsomorphicStringsConfig:
"""Configuration for Isomorphic Strings dataset generation"""
@ -107,6 +110,8 @@ class IsomorphicStringsDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(s=s, t=t),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"words": [s, t],
"solution": answer,
"solvable": solvable,
@ -134,4 +139,4 @@ class IsomorphicStringsCurriculum(BaseCurriculum):
)
register_dataset("isomorphic_strings", IsomorphicStringsDataset, IsomorphicStringsConfig, IsomorphicStringsCurriculum)
register_dataset(DATASET_NAME, IsomorphicStringsDataset, IsomorphicStringsConfig, IsomorphicStringsCurriculum)

View file

@ -9,6 +9,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "jugs"
def min_moves_n(jug_capacities: list[int], target: int) -> Optional[int]:
"""
@ -282,6 +284,8 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
"question": question,
"answer": json.dumps(solution), # one possible solution
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"puzzle": puzzle,
"difficulty": {
"num_jugs": self.config.num_jugs,
@ -340,4 +344,4 @@ class JugsCurriculum(BaseCurriculum):
)
register_dataset("jugs", JugsDataset, JugsConfig, JugsCurriculum)
register_dataset(DATASET_NAME, JugsDataset, JugsConfig, JugsCurriculum)

View file

@ -10,6 +10,8 @@ from reasoning_gym.data import read_data_file
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "letter_counting"
@dataclass
class LetterCountingConfig:
@ -64,6 +66,8 @@ class LetterCountingDataset(ProceduralDataset):
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
"answer": str(count),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"span_length": span_length,
"target_letter": target_letter,
"span": span,
@ -91,4 +95,4 @@ class LetterCountingCurriculum(BaseCurriculum):
)
register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig, LetterCountingCurriculum)
register_dataset(DATASET_NAME, LetterCountingDataset, LetterCountingConfig, LetterCountingCurriculum)

View file

@ -22,6 +22,9 @@ Now, unscramble these words: {words}
"""
DATASET_NAME = "letter_jumble"
@dataclass
class LetterJumbleConfig:
"""Configuration for letter jumbling task generation"""
@ -104,6 +107,8 @@ class LetterJumbleDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(words=" ".join(scrambled_words)),
"answer": " ".join(selected_words),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"num_words": num_words,
"corruption_level": corruption_level,
"scrambled_words": scrambled_words,
@ -193,4 +198,4 @@ class LetterJumbleCurriculum(BaseCurriculum):
)
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig, LetterJumbleCurriculum)
register_dataset(DATASET_NAME, LetterJumbleDataset, LetterJumbleConfig, LetterJumbleCurriculum)

View file

@ -18,6 +18,8 @@ Perform the following series of operations in order:
{operations}
"""
DATASET_NAME = "manipulate_matrix"
def num_rows(matrix: list[list[int]]) -> int:
return len(matrix)
@ -306,6 +308,8 @@ class ManipulateMatrixDataset(ProceduralDataset):
),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"matrix": matrix,
"solution": answer,
"operations": operations,
@ -351,4 +355,4 @@ class ManipulateMatrixCurriculum(BaseCurriculum):
)
register_dataset("manipulate_matrix", ManipulateMatrixDataset, ManipulateMatrixConfig, ManipulateMatrixCurriculum)
register_dataset(DATASET_NAME, ManipulateMatrixDataset, ManipulateMatrixConfig, ManipulateMatrixCurriculum)

View file

@ -7,6 +7,8 @@ from typing import Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "number_filtering"
@dataclass
class NumberFilteringConfig:
@ -91,6 +93,8 @@ class NumberFilteringDataset(ProceduralDataset):
),
"answer": str(result_strs) if result_strs else "[]",
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"original_numbers": str_numbers,
"filter_value": filter_str,
"operation": f"{keep_remove}_{larger_smaller}",
@ -138,4 +142,4 @@ class NumberFilteringCurriculum(BaseCurriculum):
)
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum)
register_dataset(DATASET_NAME, NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum)

View file

@ -8,6 +8,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "number_sorting"
@dataclass
class NumberSortingConfig:
@ -90,6 +92,8 @@ Please follow the instruction below:
"question": question,
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"original_numbers": number_strs,
"direction": direction,
"sorted_numbers": answer,
@ -198,4 +202,4 @@ class NumberSortingCurriculum(BaseCurriculum):
)
register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum)
register_dataset(DATASET_NAME, NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum)

View file

@ -18,6 +18,9 @@ Now, form a valid palindrome using the following letters: {letters}
"""
DATASET_NAME = "palindrome_generation"
@dataclass
class PalindromeConfig:
"""
@ -67,6 +70,8 @@ class PalindromeDataset(ProceduralDataset):
"question": QUESTION_TEMPALTE.format(letters=", ".join(scrambled_letters)),
"answer": palindrome,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"letters": scrambled_letters,
"generated_palindrome": palindrome,
"length": length,
@ -138,4 +143,4 @@ class PalindromeCurriculum(BaseCurriculum):
)
register_dataset("palindrome_generation", PalindromeDataset, PalindromeConfig, PalindromeCurriculum)
register_dataset(DATASET_NAME, PalindromeDataset, PalindromeConfig, PalindromeCurriculum)

View file

@ -24,6 +24,8 @@ Your output should be a list of lists, where each list represents a palindrome p
Partition the following string into palindromes: {string}
"""
DATASET_NAME = "palindrome_partitioning"
@dataclass
class PalindromePartitioningConfig:
@ -138,6 +140,8 @@ class PalindromePartitioningDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(string=string),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"string": string,
"solution": answer,
"string_len": string_len,
@ -176,7 +180,7 @@ class PalindromePartitioningCurriculum(BaseCurriculum):
register_dataset(
"palindrome_partitioning",
DATASET_NAME,
PalindromePartitioningDataset,
PalindromePartitioningConfig,
PalindromePartitioningCurriculum,

View file

@ -21,6 +21,9 @@ Perform {pool_type} pooling on the following matrix with a kernel size of {pool_
"""
DATASET_NAME = "pool_matrix"
@dataclass
class PoolMatrixConfig:
"""Configuration for Pool Matrix dataset generation"""
@ -113,6 +116,8 @@ class PoolMatrixDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, pool_type=pool_type, pool_size=pool_size),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"matrix": matrix.tolist(),
"pool_type": pool_type,
"pool_size": pool_size,
@ -158,4 +163,4 @@ class PoolMatrixCurriculum(BaseCurriculum):
)
register_dataset("pool_matrix", PoolMatrixDataset, PoolMatrixConfig, PoolMatrixCurriculum)
register_dataset(DATASET_NAME, PoolMatrixDataset, PoolMatrixConfig, PoolMatrixCurriculum)

View file

@ -20,6 +20,8 @@ Ransom note: {ransom_note}
Magazine: {magazine}
"""
DATASET_NAME = "ransom_note"
@dataclass
class RansomNoteConfig:
@ -99,6 +101,8 @@ class RansomNoteDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(ransom_note=ransom_note, magazine=magazine),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"ransom_note": ransom_note,
"magazine": magazine,
"solution": answer,
@ -136,4 +140,4 @@ class RansomNoteCurriculum(BaseCurriculum):
)
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig, RansomNoteCurriculum)
register_dataset(DATASET_NAME, RansomNoteDataset, RansomNoteConfig, RansomNoteCurriculum)

View file

@ -20,6 +20,8 @@ Rotate the matrix below by {degrees} degrees clockwise:
{matrix}
"""
DATASET_NAME = "rotate_matrix"
@dataclass
class RotateMatrixConfig:
@ -83,6 +85,8 @@ class RotateMatrixDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"matrix": matrix,
"num_rotations": num_rotations,
"solution": answer,
@ -118,4 +122,4 @@ class RotateMatrixCurriculum(BaseCurriculum):
)
register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig, RotateMatrixCurriculum)
register_dataset(DATASET_NAME, RotateMatrixDataset, RotateMatrixConfig, RotateMatrixCurriculum)

View file

@ -26,6 +26,8 @@ Now, determine the minimum number of minutes that must elapse until no cell in t
{matrix}
"""
DATASET_NAME = "rotten_oranges"
@dataclass
class RottenOrangesConfig:
@ -120,6 +122,8 @@ class RottenOrangesDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"matrix": matrix,
"solution": answer,
"n": n,
@ -146,4 +150,4 @@ class RottenOrangesCurriculum(BaseCurriculum):
)
register_dataset("rotten_oranges", RottenOrangesDataset, RottenOrangesConfig, RottenOrangesCurriculum)
register_dataset(DATASET_NAME, RottenOrangesDataset, RottenOrangesConfig, RottenOrangesCurriculum)

View file

@ -9,6 +9,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "sentence_reordering"
@dataclass
class SentenceReorderingConfig:
@ -91,6 +93,8 @@ class SentenceReorderingDataset(ProceduralDataset):
"question": f"Restore the correct order of words in the following sentence: {question}",
"answer": solved_sentence,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"word_count": word_count,
"difficulty": {
"words_in_sentence": (self.config.min_words_in_sentence, self.config.max_words_in_sentence),
@ -137,6 +141,4 @@ class SentenceReorderingCurriculum(BaseCurriculum):
)
register_dataset(
"sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig, SentenceReorderingCurriculum
)
register_dataset(DATASET_NAME, SentenceReorderingDataset, SentenceReorderingConfig, SentenceReorderingCurriculum)

View file

@ -9,6 +9,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "spell_backward"
@dataclass
class SpellBackwardConfig:
@ -52,6 +54,8 @@ class SpellBackwardDataset(ProceduralDataset):
"question": f"Spell this word backward (example: sun -> nus): {word}",
"answer": answer,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"word": word,
"word_len": len(word),
"difficulty": {
@ -91,4 +95,4 @@ class SpellBackwardCurriculum(BaseCurriculum):
)
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum)
register_dataset(DATASET_NAME, SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum)

View file

@ -27,6 +27,9 @@ For the matrix below, what is the list of elements in spiral order?
"""
DATASET_NAME = "spiral_matrix"
@dataclass
class SpiralMatrixConfig:
"""Configuration for Spiral Matrix dataset generation"""
@ -111,6 +114,8 @@ class SpiralMatrixDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"matrix": matrix,
"solution": answer,
"n": n,
@ -158,4 +163,4 @@ class SpiralMatrixCurriculum(BaseCurriculum):
)
register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig, SpiralMatrixCurriculum)
register_dataset(DATASET_NAME, SpiralMatrixDataset, SpiralMatrixConfig, SpiralMatrixCurriculum)

View file

@ -25,6 +25,9 @@ Given the following string, provide the answer after inserting the characters ac
"""
DATASET_NAME = "string_insertion"
@dataclass
class StringInsertionConfig:
"""Configuration for String Insertion dataset generation"""
@ -102,6 +105,8 @@ class StringInsertionDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(string=string),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"string": string,
"solution": answer,
"string_length": string_length,
@ -129,4 +134,4 @@ class StringInsertionCurriculum(BaseCurriculum):
)
register_dataset("string_insertion", StringInsertionDataset, StringInsertionConfig, StringInsertionCurriculum)
register_dataset(DATASET_NAME, StringInsertionDataset, StringInsertionConfig, StringInsertionCurriculum)

View file

@ -24,6 +24,8 @@ Transform the following string according to the above list of rules:
{string}
"""
DATASET_NAME = "string_manipulation"
@dataclass
class StringManipulationConfig:
@ -179,6 +181,8 @@ class StringManipulationDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(string=string, rules=rules_str),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"string": string,
"solution": answer,
"states": states,
@ -216,6 +220,4 @@ class StringManipulationCurriculum(BaseCurriculum):
)
register_dataset(
"string_manipulation", StringManipulationDataset, StringManipulationConfig, StringManipulationCurriculum
)
register_dataset(DATASET_NAME, StringManipulationDataset, StringManipulationConfig, StringManipulationCurriculum)

View file

@ -28,6 +28,8 @@ Now, you have {A_machine} machine A, {B_machine} machine B, and {C_machine} mach
Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts of each machine and part type.
"""
DATASET_NAME = "string_splitting"
@dataclass
class StringSplittingConfig:
@ -125,6 +127,8 @@ class StringSplittingDataset(ProceduralDataset):
),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"states": states,
"solution": answer,
"initial_machines": (A_machine, B_machine, C_machine),
@ -152,4 +156,4 @@ class StringSplittingCurriculum(BaseCurriculum):
)
register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum)
register_dataset(DATASET_NAME, StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum)

View file

@ -29,6 +29,9 @@ Note: Apply the rules at most {max_iterations} times. If the rules cannot be app
"""
DATASET_NAME = "string_synthesis"
@dataclass
class StringSynthesisConfig:
"""Configuration for String Synthesis dataset generation"""
@ -130,6 +133,8 @@ class StringSynthesisDataset(ProceduralDataset):
),
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"states": states,
"solution": answer,
"initial_blocks": (A_square, B_square, C_square),
@ -157,4 +162,4 @@ class StringSynthesisCurriculum(BaseCurriculum):
)
register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum)
register_dataset(DATASET_NAME, StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum)

View file

@ -14,6 +14,9 @@ Provide your answer as a comma-separated sequence of uppercase letters without s
Each step must be a valid English word."""
DATASET_NAME = "word_ladder"
@dataclass
class WordLadderConfig:
"""Configuration for word ladder task generation"""
@ -219,6 +222,8 @@ class WordLadderDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(start=start, end=end),
"answer": ",".join(path),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"start_word": start,
"end_word": end,
"word_length": length,
@ -285,4 +290,4 @@ class WordLadderCurriculum(BaseCurriculum):
)
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig, WordLadderCurriculum)
register_dataset(DATASET_NAME, WordLadderDataset, WordLadderConfig, WordLadderCurriculum)

View file

@ -17,6 +17,9 @@ Reverse this list of words: {words}
"""
DATASET_NAME = "word_sequence_reversal"
@dataclass
class WordSequenceReversalConfig:
"""Configuration for word sequence reversal task generation"""
@ -63,6 +66,8 @@ class WordSequenceReversalDataset(ProceduralDataset):
"question": f"{QUESTION_TEMPLATE.format(words=words_str)}",
"answer": answer,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"num_words": num_words,
"words": words,
"difficulty": {
@ -89,6 +94,4 @@ class WordSequenceReversalCurriculum(BaseCurriculum):
)
register_dataset(
"word_sequence_reversal", WordSequenceReversalDataset, WordSequenceReversalConfig, WordSequenceReversalCurriculum
)
register_dataset(DATASET_NAME, WordSequenceReversalDataset, WordSequenceReversalConfig, WordSequenceReversalCurriculum)

View file

@ -27,6 +27,8 @@ Your output should be a comma-separated list of words, e.g. word_1, word_2, word
Now, sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list: {words}
"""
DATASET_NAME = "word_sorting"
@dataclass
class WordSortingConfig:
@ -106,6 +108,8 @@ class WordSortingDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
"answer": ", ".join(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"original_words": original_words,
"sorted_words": answer,
"transformed_words": transformed_words,
@ -153,4 +157,4 @@ class WordSortingCurriculum(BaseCurriculum):
)
register_dataset("word_sorting", WordSortingDataset, WordSortingConfig)
# Pass the curriculum as well: WordSortingCurriculum is defined in this module
# but was previously left out of the registration call.
register_dataset(DATASET_NAME, WordSortingDataset, WordSortingConfig, WordSortingCurriculum)