diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py index 5c641ba2..fa4b0b34 100644 --- a/reasoning_gym/algebra/complex_arithmetic.py +++ b/reasoning_gym/algebra/complex_arithmetic.py @@ -7,6 +7,8 @@ from typing import Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "complex_arithmetic" + @dataclass class ComplexArithmeticConfig: @@ -90,6 +92,8 @@ class ComplexArithmeticDataset(ProceduralDataset): "question": question, "answer": self._format_complex(result), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "num1": (a.real, a.imag), "num2": (b.real, b.imag), "operation": op, @@ -220,4 +224,4 @@ class ComplexArithmeticCurriculum(BaseCurriculum): ) -register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig, ComplexArithmeticCurriculum) +register_dataset(DATASET_NAME, ComplexArithmeticDataset, ComplexArithmeticConfig, ComplexArithmeticCurriculum) diff --git a/reasoning_gym/algebra/intermediate_integration.py b/reasoning_gym/algebra/intermediate_integration.py index bed3cd67..78daad06 100644 --- a/reasoning_gym/algebra/intermediate_integration.py +++ b/reasoning_gym/algebra/intermediate_integration.py @@ -7,6 +7,8 @@ import sympy from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "intermediate_integration" + @dataclass class IntermediateIntegrationConfig: @@ -235,6 +237,8 @@ Use same variable symbols as given in the question "question": question, "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": index, "integrand": str(integrand), "problem_type": problem_type, "variable": str(x), @@ -292,7 +296,7 @@ class IntermediateIntegrationCurriculum(BaseCurriculum): register_dataset( - "intermediate_integration", + DATASET_NAME, IntermediateIntegrationDataset, IntermediateIntegrationConfig, IntermediateIntegrationCurriculum, diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index 07ee0b6a..05f593a7 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -8,6 +8,8 @@ from sympy import Eq, Symbol, expand, solve from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "polynomial_equations" + @dataclass class PolynomialEquationsConfig: @@ -120,6 +122,8 @@ In solving equations, please follow these instructions: "question": question, "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "polynomial_expr": str(polynomial_expanded), "variable": variable, "degree": degree, @@ -295,6 +299,4 @@ class PolynomialEquationsCurriculum(BaseCurriculum): ) -register_dataset( - "polynomial_equations", PolynomialEquationsDataset, PolynomialEquationsConfig, PolynomialEquationsCurriculum -) +register_dataset(DATASET_NAME, PolynomialEquationsDataset, PolynomialEquationsConfig, PolynomialEquationsCurriculum) diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py index d3ba1c12..4a388532 100644 --- a/reasoning_gym/algebra/polynomial_multiplication.py +++ b/reasoning_gym/algebra/polynomial_multiplication.py @@ -8,6 +8,8 @@ from sympy.polys.monomials import itermonomials from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "polynomial_multiplication" + @dataclass class PolynomialMultiplicationConfig: @@ -109,6 +111,8 @@ When performing calculations, please follow these guidelines: "question": question, "answer": str(product), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "polynomial_expr": str(polynomial_expr), "variables": list(product.free_symbols), "difficulty": { @@ -230,7 +234,7 @@ class PolynomialMultiplicationCurriculum(BaseCurriculum): register_dataset( - "polynomial_multiplication", + DATASET_NAME, PolynomialMultiplicationDataset, PolynomialMultiplicationConfig, PolynomialMultiplicationCurriculum, diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index 389e560d..ee9f81a8 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -8,6 +8,8 @@ from sympy import Symbol from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "simple_equations" + @dataclass class SimpleEquationsConfig: @@ -63,6 +65,8 @@ class SimpleEquationsDataset(ProceduralDataset): "question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation), "answer": str(solution), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "equation": equation, "variable": variable, "difficulty": { @@ -166,4 +170,4 @@ class SimpleEquationsCurriculum(BaseCurriculum): ) -register_dataset("simple_equations", SimpleEquationsDataset, SimpleEquationsConfig, SimpleEquationsCurriculum) +register_dataset(DATASET_NAME, SimpleEquationsDataset, SimpleEquationsConfig, SimpleEquationsCurriculum) diff --git a/reasoning_gym/algebra/simple_integration.py b/reasoning_gym/algebra/simple_integration.py index 16aa2d1b..78823ca8 100644 --- a/reasoning_gym/algebra/simple_integration.py +++ b/reasoning_gym/algebra/simple_integration.py @@ -8,6 +8,8 @@ import sympy from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "simple_integration" + @dataclass class SimpleIntegrationConfig: @@ -82,6 +84,8 @@ When performing calculations, please follow these guidelines: "question": question, "answer": str(polynomial) + " + C", "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "integrand": str(derivative), "variable": str(symbol), "expected_answer_expression": polynomial, @@ -128,4 +132,4 @@ class SimpleIntegrationCurriculum(BaseCurriculum): ) -register_dataset("simple_integration", SimpleIntegrationDataset, SimpleIntegrationConfig, SimpleIntegrationCurriculum) +register_dataset(DATASET_NAME, SimpleIntegrationDataset, SimpleIntegrationConfig, SimpleIntegrationCurriculum) diff --git a/reasoning_gym/algorithmic/ab.py b/reasoning_gym/algorithmic/ab.py index 9dad7f0f..b2d2e4e5 100644 --- a/reasoning_gym/algorithmic/ab.py +++ b/reasoning_gym/algorithmic/ab.py @@ -5,6 +5,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "ab" + def generate_program(length, rng): """Generates a random initial program of a given length.""" @@ -116,9 +118,11 @@ Return the final state of the program. "question": prompt, "answer": " ".join(steps[-1]), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "difficulty": { "length": self.config.length, - } + }, }, } @@ -158,4 +162,4 @@ class ABCurriculum(BaseCurriculum): # Register the dataset -register_dataset("ab", ABDataset, ABConfig, ABCurriculum) +register_dataset(DATASET_NAME, ABDataset, ABConfig, ABCurriculum) diff --git a/reasoning_gym/algorithmic/base_conversion.py b/reasoning_gym/algorithmic/base_conversion.py index 80c9bb29..81de935f 100644 --- a/reasoning_gym/algorithmic/base_conversion.py +++ b/reasoning_gym/algorithmic/base_conversion.py @@ -14,6 +14,8 @@ If the target base is > 10, use lowercase letters a-z for digits above 9. Now, convert the {source_name} number {source_repr} to {target_name} """ +DATASET_NAME = "base_conversion" + @dataclass class BaseConversionConfig: @@ -104,6 +106,8 @@ class BaseConversionDataset(ProceduralDataset): ), "answer": target_repr, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "decimal_value": value, "source_base": source_base, "target_base": target_base, @@ -142,4 +146,4 @@ class BaseConversionCurriculum(BaseCurriculum): ) -register_dataset("base_conversion", BaseConversionDataset, BaseConversionConfig, BaseConversionCurriculum) +register_dataset(DATASET_NAME, BaseConversionDataset, BaseConversionConfig, BaseConversionCurriculum) diff --git a/reasoning_gym/algorithmic/binary_alternation.py b/reasoning_gym/algorithmic/binary_alternation.py index 986f97f1..80e7637c 100644 --- a/reasoning_gym/algorithmic/binary_alternation.py +++ b/reasoning_gym/algorithmic/binary_alternation.py @@ -20,6 +20,9 @@ Now, determine the minimum number of swaps to make the following binary string a """ +DATASET_NAME = "binary_alternation" + + @dataclass class BinaryAlternationConfig: """Configuration for Count Bits dataset generation""" @@ -105,6 +108,8 @@ class BinaryAlternationDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(string=string), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "string": string, "solution": answer, "solvable": solvable, @@ -132,4 +137,4 @@ class BinaryAlternationCurriculum(BaseCurriculum): ) -register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum) +register_dataset(DATASET_NAME, BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum) diff --git a/reasoning_gym/algorithmic/binary_matrix.py b/reasoning_gym/algorithmic/binary_matrix.py index 1c7db03b..49b97382 100644 --- a/reasoning_gym/algorithmic/binary_matrix.py +++ b/reasoning_gym/algorithmic/binary_matrix.py @@ -20,6 +20,8 @@ Find the distance to the nearest 0 for each cell in the matrix below: {matrix} """ +DATASET_NAME = "binary_matrix" + @dataclass class BinaryMatrixConfig: @@ -128,6 +130,8 @@ class BinaryMatrixDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(matrix=matrix_str), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "matrix": matrix, "solution": answer, "n": n, @@ -160,4 +164,4 @@ class BinaryMatrixCurriculum(BaseCurriculum): ) -register_dataset("binary_matrix", BinaryMatrixDataset, BinaryMatrixConfig, BinaryMatrixCurriculum) +register_dataset(DATASET_NAME, BinaryMatrixDataset, BinaryMatrixConfig, BinaryMatrixCurriculum) diff --git a/reasoning_gym/algorithmic/caesar_cipher.py b/reasoning_gym/algorithmic/caesar_cipher.py index e28c1dc4..753dbd34 100644 --- a/reasoning_gym/algorithmic/caesar_cipher.py +++ b/reasoning_gym/algorithmic/caesar_cipher.py @@ -8,6 +8,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "caesar_cipher" + @dataclass class CaesarCipherConfig: @@ -77,6 +79,8 @@ class CaesarCipherDataset(ProceduralDataset): "question": f"Decrypt this Caesar cipher text: {cipher_text}. Provide only the decrypted text as your final answer.", "answer": sentence, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "rotation": rotation, "cipher_text": cipher_text, "clear_text": sentence, @@ -113,4 +117,4 @@ class CaesarCipherCurriculum(BaseCurriculum): ) -register_dataset("caesar_cipher", CaesarCipherDataset, CaesarCipherConfig, CaesarCipherCurriculum) +register_dataset(DATASET_NAME, CaesarCipherDataset, CaesarCipherConfig, CaesarCipherCurriculum) diff --git a/reasoning_gym/algorithmic/count_primes.py b/reasoning_gym/algorithmic/count_primes.py index 424b0884..066d603c 100644 --- a/reasoning_gym/algorithmic/count_primes.py +++ b/reasoning_gym/algorithmic/count_primes.py @@ -14,6 +14,8 @@ from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?""" +DATASET_NAME = "count_primes" + @dataclass class CountPrimesConfig: @@ -60,6 +62,8 @@ class CountPrimesDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(start=start, end=end), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "start": start, "end": end, "primes": primes, @@ -88,4 +92,4 @@ class CountPrimesCurriculum(BaseCurriculum): ) -register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig, CountPrimesCurriculum) +register_dataset(DATASET_NAME, CountPrimesDataset, CountPrimesConfig, CountPrimesCurriculum) diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index c989960d..03bbe9ff 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -18,6 +18,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "cryptarithm" + @dataclass class CryptarithmConfig: @@ -51,9 +53,9 @@ class CryptarithmDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: rng = Random(self.seed + idx) - return self._create_single_puzzle(rng) + return self._create_single_puzzle(rng, idx) - def _create_single_puzzle(self, rng: Random) -> dict: + def _create_single_puzzle(self, rng: Random, idx: int) -> dict: """ Creates one puzzle with N addends (2..3) plus a result. Ensures total distinct digits <= 10. @@ -179,6 +181,8 @@ class CryptarithmDataset(ProceduralDataset): "question": question_str, "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "letters": list(letter_to_digit.keys()), "word_values": words_numbers, "sum_number": total_sum, @@ -260,4 +264,4 @@ class CryptarithmCurriculum(BaseCurriculum): ) -register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig, CryptarithmCurriculum) +register_dataset(DATASET_NAME, CryptarithmDataset, CryptarithmConfig, CryptarithmCurriculum) diff --git a/reasoning_gym/algorithmic/game_of_life.py b/reasoning_gym/algorithmic/game_of_life.py index f1c2fa78..d2501019 100644 --- a/reasoning_gym/algorithmic/game_of_life.py +++ b/reasoning_gym/algorithmic/game_of_life.py @@ -8,6 +8,8 @@ import cellpylib as cpl from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "game_of_life" + @dataclass class GameOfLifeConfig: @@ -81,6 +83,8 @@ class GameOfLifeDataset(ProceduralDataset): ), "answer": result_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "grid_size_x": self.config.grid_size_x, "grid_size_y": self.config.grid_size_y, "filled_cells": self.config.filled_cells, @@ -187,4 +191,4 @@ class GameOfLifeCurriculum(BaseCurriculum): ) -register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig, GameOfLifeCurriculum) +register_dataset(DATASET_NAME, GameOfLifeDataset, GameOfLifeConfig, GameOfLifeCurriculum) diff --git a/reasoning_gym/algorithmic/game_of_life_halting.py b/reasoning_gym/algorithmic/game_of_life_halting.py index 51b2bc3a..85d35b8d 100644 --- a/reasoning_gym/algorithmic/game_of_life_halting.py +++ b/reasoning_gym/algorithmic/game_of_life_halting.py @@ -7,6 +7,8 @@ import cellpylib as cpl from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "game_of_life_halting" + @dataclass class GameOfLifeHaltingConfig: @@ -363,6 +365,8 @@ class GameOfLifeHaltingDataset(ProceduralDataset): "question": question, "answer": str(not should_oscillate), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "grid_size_x": grid_x, "grid_size_y": grid_y, "placed_patterns": placed_patterns, @@ -438,4 +442,4 @@ class GameOfLifeHaltingCurriculum(BaseCurriculum): ) -register_dataset("game_of_life_halting", GameOfLifeHaltingDataset, GameOfLifeHaltingConfig, GameOfLifeHaltingCurriculum) +register_dataset(DATASET_NAME, GameOfLifeHaltingDataset, GameOfLifeHaltingConfig, GameOfLifeHaltingCurriculum) diff --git a/reasoning_gym/algorithmic/graph_color.py b/reasoning_gym/algorithmic/graph_color.py index 749f3b45..dd1ed32f 100644 --- a/reasoning_gym/algorithmic/graph_color.py +++ b/reasoning_gym/algorithmic/graph_color.py @@ -6,6 +6,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "graph_color" + def generate_random_graph(rng, num_vertices, edge_probability=0.3): """ @@ -213,6 +215,8 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1 "question": question, "answer": None, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "possible_answer": solution, "puzzle": puzzle, "num_vertices": num_vertices, @@ -272,4 +276,4 @@ class GraphColorCurriculum(BaseCurriculum): ) -register_dataset("graph_color", GraphColorDataset, GraphColorConfig, GraphColorCurriculum) +register_dataset(DATASET_NAME, GraphColorDataset, GraphColorConfig, GraphColorCurriculum) diff --git a/reasoning_gym/algorithmic/group_anagrams.py b/reasoning_gym/algorithmic/group_anagrams.py index 6b485e9e..ee3672f0 100644 --- a/reasoning_gym/algorithmic/group_anagrams.py +++ b/reasoning_gym/algorithmic/group_anagrams.py @@ -26,6 +26,8 @@ Group the following list of words into anagrams: {words} """ +DATASET_NAME = "group_anagrams" + @dataclass class GroupAnagramsConfig: @@ -115,6 +117,8 @@ class GroupAnagramsDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(words=json.dumps(words)), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "words": words, "solution": answer, "anagram_groups": anagram_groups, @@ -149,4 +153,4 @@ class GroupAnagramsCurriculum(BaseCurriculum): ) -register_dataset("group_anagrams", GroupAnagramsDataset, GroupAnagramsConfig, GroupAnagramsCurriculum) +register_dataset(DATASET_NAME, GroupAnagramsDataset, GroupAnagramsConfig, GroupAnagramsCurriculum) diff --git a/reasoning_gym/algorithmic/isomorphic_strings.py b/reasoning_gym/algorithmic/isomorphic_strings.py index f2e2fb49..39b0cb8b 100644 --- a/reasoning_gym/algorithmic/isomorphic_strings.py +++ b/reasoning_gym/algorithmic/isomorphic_strings.py @@ -24,6 +24,9 @@ Return True if the following two strings are isomorphic, or False otherwise: """ +DATASET_NAME = "isomorphic_strings" + + @dataclass class IsomorphicStringsConfig: """Configuration for Isomorphic Strings dataset generation""" @@ -107,6 +110,8 @@ class IsomorphicStringsDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(s=s, t=t), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "words": [s, t], "solution": answer, "solvable": solvable, @@ -134,4 +139,4 @@ class IsomorphicStringsCurriculum(BaseCurriculum): ) -register_dataset("isomorphic_strings", IsomorphicStringsDataset, IsomorphicStringsConfig, IsomorphicStringsCurriculum) +register_dataset(DATASET_NAME, IsomorphicStringsDataset, IsomorphicStringsConfig, IsomorphicStringsCurriculum) diff --git a/reasoning_gym/algorithmic/jugs.py b/reasoning_gym/algorithmic/jugs.py index 3dbf84d4..9b652ccc 100644 --- a/reasoning_gym/algorithmic/jugs.py +++ b/reasoning_gym/algorithmic/jugs.py @@ -9,6 +9,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "jugs" + def min_moves_n(jug_capacities: list[int], target: int) -> Optional[int]: """ @@ -282,6 +284,8 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil "question": question, "answer": json.dumps(solution), # one possible solution "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "puzzle": puzzle, "difficulty": { "num_jugs": self.config.num_jugs, @@ -340,4 +344,4 @@ class JugsCurriculum(BaseCurriculum): ) -register_dataset("jugs", JugsDataset, JugsConfig, JugsCurriculum) +register_dataset(DATASET_NAME, JugsDataset, JugsConfig, JugsCurriculum) diff --git a/reasoning_gym/algorithmic/letter_counting.py b/reasoning_gym/algorithmic/letter_counting.py index 197596e5..5d95851c 100644 --- a/reasoning_gym/algorithmic/letter_counting.py +++ b/reasoning_gym/algorithmic/letter_counting.py @@ -10,6 +10,8 @@ from reasoning_gym.data import read_data_file from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "letter_counting" + @dataclass class LetterCountingConfig: @@ -64,6 +66,8 @@ class LetterCountingDataset(ProceduralDataset): "question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?', "answer": str(count), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "span_length": span_length, "target_letter": target_letter, "span": span, @@ -91,4 +95,4 @@ class LetterCountingCurriculum(BaseCurriculum): ) -register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig, LetterCountingCurriculum) +register_dataset(DATASET_NAME, LetterCountingDataset, LetterCountingConfig, LetterCountingCurriculum) diff --git a/reasoning_gym/algorithmic/letter_jumble.py b/reasoning_gym/algorithmic/letter_jumble.py index 7f56dfa8..83996512 100644 --- a/reasoning_gym/algorithmic/letter_jumble.py +++ b/reasoning_gym/algorithmic/letter_jumble.py @@ -22,6 +22,9 @@ Now, unscramble these words: {words} """ +DATASET_NAME = "letter_jumble" + + @dataclass class LetterJumbleConfig: """Configuration for letter jumbling task generation""" @@ -104,6 +107,8 @@ class LetterJumbleDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(words=" ".join(scrambled_words)), "answer": " ".join(selected_words), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "num_words": num_words, "corruption_level": corruption_level, "scrambled_words": scrambled_words, @@ -193,4 +198,4 @@ class LetterJumbleCurriculum(BaseCurriculum): ) -register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig, LetterJumbleCurriculum) +register_dataset(DATASET_NAME, LetterJumbleDataset, LetterJumbleConfig, LetterJumbleCurriculum) diff --git a/reasoning_gym/algorithmic/manipulate_matrix.py b/reasoning_gym/algorithmic/manipulate_matrix.py index 49fe4615..d703be8c 100644 --- a/reasoning_gym/algorithmic/manipulate_matrix.py +++ b/reasoning_gym/algorithmic/manipulate_matrix.py @@ -18,6 +18,8 @@ Perform the following series of operations in order: {operations} """ +DATASET_NAME = "manipulate_matrix" + def num_rows(matrix: list[list[int]]) -> int: return len(matrix) @@ -306,6 +308,8 @@ class ManipulateMatrixDataset(ProceduralDataset): ), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "matrix": matrix, "solution": answer, "operations": operations, @@ -351,4 +355,4 @@ class ManipulateMatrixCurriculum(BaseCurriculum): ) -register_dataset("manipulate_matrix", ManipulateMatrixDataset, ManipulateMatrixConfig, ManipulateMatrixCurriculum) +register_dataset(DATASET_NAME, ManipulateMatrixDataset, ManipulateMatrixConfig, ManipulateMatrixCurriculum) diff --git a/reasoning_gym/algorithmic/number_filtering.py b/reasoning_gym/algorithmic/number_filtering.py index 0561cd7d..8c766174 100644 --- a/reasoning_gym/algorithmic/number_filtering.py +++ b/reasoning_gym/algorithmic/number_filtering.py @@ -7,6 +7,8 @@ from typing import Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "number_filtering" + @dataclass class NumberFilteringConfig: @@ -91,6 +93,8 @@ class NumberFilteringDataset(ProceduralDataset): ), "answer": str(result_strs) if result_strs else "[]", "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "original_numbers": str_numbers, "filter_value": filter_str, "operation": f"{keep_remove}_{larger_smaller}", @@ -138,4 +142,4 @@ class NumberFilteringCurriculum(BaseCurriculum): ) -register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum) +register_dataset(DATASET_NAME, NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum) diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index 35664b8e..cc3e06d2 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -8,6 +8,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "number_sorting" + @dataclass class NumberSortingConfig: @@ -90,6 +92,8 @@ Please follow the instruction below: "question": question, "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "original_numbers": number_strs, "direction": direction, "sorted_numbers": answer, @@ -198,4 +202,4 @@ class NumberSortingCurriculum(BaseCurriculum): ) -register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum) +register_dataset(DATASET_NAME, NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum) diff --git a/reasoning_gym/algorithmic/palindrome_generation.py b/reasoning_gym/algorithmic/palindrome_generation.py index 92236b9a..9385c089 100644 --- a/reasoning_gym/algorithmic/palindrome_generation.py +++ b/reasoning_gym/algorithmic/palindrome_generation.py @@ -18,6 +18,9 @@ Now, form a valid palindrome using the following letters: {letters} """ +DATASET_NAME = "palindrome_generation" + + @dataclass class PalindromeConfig: """ @@ -67,6 +70,8 @@ class PalindromeDataset(ProceduralDataset): "question": QUESTION_TEMPALTE.format(letters=", ".join(scrambled_letters)), "answer": palindrome, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "letters": scrambled_letters, "generated_palindrome": palindrome, "length": length, @@ -138,4 +143,4 @@ class PalindromeCurriculum(BaseCurriculum): ) -register_dataset("palindrome_generation", PalindromeDataset, PalindromeConfig, PalindromeCurriculum) +register_dataset(DATASET_NAME, PalindromeDataset, PalindromeConfig, PalindromeCurriculum) diff --git a/reasoning_gym/algorithmic/palindrome_partitioning.py b/reasoning_gym/algorithmic/palindrome_partitioning.py index a7789a12..48f8a55c 100644 --- a/reasoning_gym/algorithmic/palindrome_partitioning.py +++ b/reasoning_gym/algorithmic/palindrome_partitioning.py @@ -24,6 +24,8 @@ Your output should be a list of lists, where each list represents a palindrome p Partition the following string into palindromes: {string} """ +DATASET_NAME = "palindrome_partitioning" + @dataclass class PalindromePartitioningConfig: @@ -138,6 +140,8 @@ class PalindromePartitioningDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(string=string), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "string": string, "solution": answer, "string_len": string_len, @@ -176,7 +180,7 @@ class PalindromePartitioningCurriculum(BaseCurriculum): register_dataset( - "palindrome_partitioning", + DATASET_NAME, PalindromePartitioningDataset, PalindromePartitioningConfig, PalindromePartitioningCurriculum, diff --git a/reasoning_gym/algorithmic/pool_matrix.py b/reasoning_gym/algorithmic/pool_matrix.py index 3a703d83..66a17fee 100644 --- a/reasoning_gym/algorithmic/pool_matrix.py +++ b/reasoning_gym/algorithmic/pool_matrix.py @@ -21,6 +21,9 @@ Perform {pool_type} pooling on the following matrix with a kernel size of {pool_ """ +DATASET_NAME = "pool_matrix" + + @dataclass class PoolMatrixConfig: """Configuration for Pool Matrix dataset generation""" @@ -113,6 +116,8 @@ class PoolMatrixDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(matrix=matrix_str, pool_type=pool_type, pool_size=pool_size), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "matrix": matrix.tolist(), "pool_type": pool_type, "pool_size": pool_size, @@ -158,4 +163,4 @@ class PoolMatrixCurriculum(BaseCurriculum): ) -register_dataset("pool_matrix", PoolMatrixDataset, PoolMatrixConfig, PoolMatrixCurriculum) +register_dataset(DATASET_NAME, PoolMatrixDataset, PoolMatrixConfig, PoolMatrixCurriculum) diff --git a/reasoning_gym/algorithmic/ransom_note.py b/reasoning_gym/algorithmic/ransom_note.py index 5e32c0e9..5c1aab6c 100644 --- a/reasoning_gym/algorithmic/ransom_note.py +++ b/reasoning_gym/algorithmic/ransom_note.py @@ -20,6 +20,8 @@ Ransom note: {ransom_note} Magazine: {magazine} """ +DATASET_NAME = "ransom_note" + @dataclass class RansomNoteConfig: @@ -99,6 +101,8 @@ class RansomNoteDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(ransom_note=ransom_note, magazine=magazine), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "ransom_note": ransom_note, "magazine": magazine, "solution": answer, @@ -136,4 +140,4 @@ class RansomNoteCurriculum(BaseCurriculum): ) -register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig, RansomNoteCurriculum) +register_dataset(DATASET_NAME, RansomNoteDataset, RansomNoteConfig, RansomNoteCurriculum) diff --git a/reasoning_gym/algorithmic/rotate_matrix.py b/reasoning_gym/algorithmic/rotate_matrix.py index 0716e656..8954d08a 100644 --- a/reasoning_gym/algorithmic/rotate_matrix.py +++ b/reasoning_gym/algorithmic/rotate_matrix.py @@ -20,6 +20,8 @@ Rotate the matrix below by {degrees} degrees clockwise: {matrix} """ +DATASET_NAME = "rotate_matrix" + @dataclass class RotateMatrixConfig: @@ -83,6 +85,8 @@ class RotateMatrixDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "matrix": matrix, "num_rotations": num_rotations, "solution": answer, @@ -118,4 +122,4 @@ class RotateMatrixCurriculum(BaseCurriculum): ) -register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig, RotateMatrixCurriculum) +register_dataset(DATASET_NAME, RotateMatrixDataset, RotateMatrixConfig, RotateMatrixCurriculum) diff --git a/reasoning_gym/algorithmic/rotten_oranges.py b/reasoning_gym/algorithmic/rotten_oranges.py index f2dabadf..f10fb53d 100644 --- a/reasoning_gym/algorithmic/rotten_oranges.py +++ b/reasoning_gym/algorithmic/rotten_oranges.py @@ -26,6 +26,8 @@ Now, determine the minimum number of minutes that must elapse until no cell in t {matrix} """ +DATASET_NAME = "rotten_oranges" + @dataclass class RottenOrangesConfig: @@ -120,6 +122,8 @@ class RottenOrangesDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(matrix=matrix_str), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "matrix": matrix, "solution": answer, "n": n, @@ -146,4 +150,4 @@ class RottenOrangesCurriculum(BaseCurriculum): ) -register_dataset("rotten_oranges", RottenOrangesDataset, RottenOrangesConfig, RottenOrangesCurriculum) +register_dataset(DATASET_NAME, RottenOrangesDataset, RottenOrangesConfig, RottenOrangesCurriculum) diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py index 29402b44..77d14d72 100644 --- a/reasoning_gym/algorithmic/sentence_reordering.py +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -9,6 +9,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "sentence_reordering" + @dataclass class SentenceReorderingConfig: @@ -91,6 +93,8 @@ class SentenceReorderingDataset(ProceduralDataset): "question": f"Restore the correct order of words in the following sentence: {question}", "answer": solved_sentence, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "word_count": word_count, "difficulty": { "words_in_sentence": (self.config.min_words_in_sentence, self.config.max_words_in_sentence), @@ -137,6 +141,4 @@ class SentenceReorderingCurriculum(BaseCurriculum): ) -register_dataset( - "sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig, SentenceReorderingCurriculum -) +register_dataset(DATASET_NAME, SentenceReorderingDataset, SentenceReorderingConfig, SentenceReorderingCurriculum) diff --git a/reasoning_gym/algorithmic/spell_backward.py b/reasoning_gym/algorithmic/spell_backward.py index e4008038..2fed1d22 100644 --- a/reasoning_gym/algorithmic/spell_backward.py +++ b/reasoning_gym/algorithmic/spell_backward.py @@ -9,6 +9,8 @@ from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "spell_backward" + @dataclass class SpellBackwardConfig: @@ -52,6 +54,8 @@ class SpellBackwardDataset(ProceduralDataset): "question": f"Spell this word backward (example: sun -> nus): {word}", "answer": answer, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "word": word, "word_len": len(word), "difficulty": { @@ -91,4 +95,4 @@ class SpellBackwardCurriculum(BaseCurriculum): ) -register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum) +register_dataset(DATASET_NAME, SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum) diff --git a/reasoning_gym/algorithmic/spiral_matrix.py b/reasoning_gym/algorithmic/spiral_matrix.py index 23fb6e37..651e5d69 100644 --- a/reasoning_gym/algorithmic/spiral_matrix.py +++ b/reasoning_gym/algorithmic/spiral_matrix.py @@ -27,6 +27,9 @@ For the matrix below, what is the list of elements in spiral order? """ +DATASET_NAME = "spiral_matrix" + + @dataclass class SpiralMatrixConfig: """Configuration for Spiral Matrix dataset generation""" @@ -111,6 +114,8 @@ class SpiralMatrixDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(matrix=matrix_str), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "matrix": matrix, "solution": answer, "n": n, @@ -158,4 +163,4 @@ class SpiralMatrixCurriculum(BaseCurriculum): ) -register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig, SpiralMatrixCurriculum) +register_dataset(DATASET_NAME, SpiralMatrixDataset, SpiralMatrixConfig, SpiralMatrixCurriculum) diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py index 3c707bf2..dd86c0fc 100644 --- a/reasoning_gym/algorithmic/string_insertion.py +++ b/reasoning_gym/algorithmic/string_insertion.py @@ -25,6 +25,9 @@ Given the following string, provide the answer after inserting the characters ac """ +DATASET_NAME = "string_insertion" + + @dataclass class StringInsertionConfig: """Configuration for String Insertion dataset generation""" @@ -102,6 +105,8 @@ class StringInsertionDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(string=string), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "string": string, "solution": answer, "string_length": string_length, @@ -129,4 +134,4 @@ class StringInsertionCurriculum(BaseCurriculum): ) -register_dataset("string_insertion", StringInsertionDataset, StringInsertionConfig, StringInsertionCurriculum) +register_dataset(DATASET_NAME, StringInsertionDataset, StringInsertionConfig, StringInsertionCurriculum) diff --git a/reasoning_gym/algorithmic/string_manipulation.py b/reasoning_gym/algorithmic/string_manipulation.py index fb6e3e33..c8310c8d 100644 --- a/reasoning_gym/algorithmic/string_manipulation.py +++ b/reasoning_gym/algorithmic/string_manipulation.py @@ -24,6 +24,8 @@ Transform the following string according to the above list of rules: {string} """ +DATASET_NAME = "string_manipulation" + @dataclass class StringManipulationConfig: @@ -179,6 +181,8 @@ class StringManipulationDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(string=string, rules=rules_str), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "string": string, "solution": answer, "states": states, @@ -216,6 +220,4 @@ class StringManipulationCurriculum(BaseCurriculum): ) -register_dataset( - "string_manipulation", StringManipulationDataset, StringManipulationConfig, StringManipulationCurriculum -) +register_dataset(DATASET_NAME, StringManipulationDataset, StringManipulationConfig, StringManipulationCurriculum) diff --git a/reasoning_gym/algorithmic/string_splitting.py b/reasoning_gym/algorithmic/string_splitting.py index 4ac1ed43..8bd8192d 100644 --- a/reasoning_gym/algorithmic/string_splitting.py +++ b/reasoning_gym/algorithmic/string_splitting.py @@ -28,6 +28,8 @@ Now, you have {A_machine} machine A, {B_machine} machine B, and {C_machine} mach Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts of each machine and part type. """ +DATASET_NAME = "string_splitting" + @dataclass class StringSplittingConfig: @@ -125,6 +127,8 @@ class StringSplittingDataset(ProceduralDataset): ), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "states": states, "solution": answer, "initial_machines": (A_machine, B_machine, C_machine), @@ -152,4 +156,4 @@ class StringSplittingCurriculum(BaseCurriculum): ) -register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum) +register_dataset(DATASET_NAME, StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum) diff --git a/reasoning_gym/algorithmic/string_synthesis.py b/reasoning_gym/algorithmic/string_synthesis.py index 2edc805b..f201854d 100644 --- a/reasoning_gym/algorithmic/string_synthesis.py +++ b/reasoning_gym/algorithmic/string_synthesis.py @@ -29,6 +29,9 @@ Note: Apply the rules at most {max_iterations} times. If the rules cannot be app """ +DATASET_NAME = "string_synthesis" + + @dataclass class StringSynthesisConfig: """Configuration for String Synthesis dataset generation""" @@ -130,6 +133,8 @@ class StringSynthesisDataset(ProceduralDataset): ), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "states": states, "solution": answer, "initial_blocks": (A_square, B_square, C_square), @@ -157,4 +162,4 @@ class StringSynthesisCurriculum(BaseCurriculum): ) -register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum) +register_dataset(DATASET_NAME, StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum) diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index 9d3d8429..02b4eb9f 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -14,6 +14,9 @@ Provide your answer as a comma-separated sequence of uppercase letters without s Each step must be a valid English word.""" +DATASET_NAME = "word_ladder" + + @dataclass class WordLadderConfig: """Configuration for word ladder task generation""" @@ -219,6 +222,8 @@ class WordLadderDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(start=start, end=end), "answer": ",".join(path), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "start_word": start, "end_word": end, "word_length": length, @@ -285,4 +290,4 @@ class WordLadderCurriculum(BaseCurriculum): ) -register_dataset("word_ladder", WordLadderDataset, WordLadderConfig, WordLadderCurriculum) +register_dataset(DATASET_NAME, WordLadderDataset, WordLadderConfig, WordLadderCurriculum) diff --git a/reasoning_gym/algorithmic/word_sequence_reversal.py b/reasoning_gym/algorithmic/word_sequence_reversal.py index b7b41b08..b9cec440 100644 --- a/reasoning_gym/algorithmic/word_sequence_reversal.py +++ b/reasoning_gym/algorithmic/word_sequence_reversal.py @@ -17,6 +17,9 @@ Reverse this list of words: {words} """ +DATASET_NAME = "word_sequence_reversal" + + @dataclass class WordSequenceReversalConfig: """Configuration for word sequence reversal task generation""" @@ -63,6 +66,8 @@ class WordSequenceReversalDataset(ProceduralDataset): "question": f"{QUESTION_TEMPLATE.format(words=words_str)}", "answer": answer, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "num_words": num_words, "words": words, "difficulty": { @@ -89,6 +94,4 @@ class WordSequenceReversalCurriculum(BaseCurriculum): ) -register_dataset( - "word_sequence_reversal", WordSequenceReversalDataset, WordSequenceReversalConfig, WordSequenceReversalCurriculum -) +register_dataset(DATASET_NAME, WordSequenceReversalDataset, WordSequenceReversalConfig, WordSequenceReversalCurriculum) diff --git a/reasoning_gym/algorithmic/word_sorting.py b/reasoning_gym/algorithmic/word_sorting.py index fd514cc7..1fc20e28 100644 --- a/reasoning_gym/algorithmic/word_sorting.py +++ b/reasoning_gym/algorithmic/word_sorting.py @@ -27,6 +27,8 @@ Your output should be a comma-separated list of words, e.g. word_1, word_2, word Now, sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list: {words} """ +DATASET_NAME = "word_sorting" + @dataclass class WordSortingConfig: @@ -106,6 +108,8 @@ class WordSortingDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)), "answer": ", ".join(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "original_words": original_words, "sorted_words": answer, "transformed_words": transformed_words, @@ -153,4 +157,4 @@ class WordSortingCurriculum(BaseCurriculum): ) -register_dataset("word_sorting", WordSortingDataset, WordSortingConfig) +register_dataset(DATASET_NAME, WordSortingDataset, WordSortingConfig) diff --git a/reasoning_gym/arc/arc_1d.py b/reasoning_gym/arc/arc_1d.py index 17412b81..9a7ee78c 100644 --- a/reasoning_gym/arc/arc_1d.py +++ b/reasoning_gym/arc/arc_1d.py @@ -5,6 +5,8 @@ from typing import Optional from ..dataset import ProceduralDataset from ..factory import register_dataset +DATASET_NAME = "arc_1d" + @dataclass class Arc1DConfig: @@ -100,6 +102,8 @@ class Arc1DDataset(ProceduralDataset): "question": question, "answer": " ".join(str(x) for x in test_example["output"]), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "task_name": task_name, "size": size, "train_examples": train_examples, @@ -109,4 +113,4 @@ class Arc1DDataset(ProceduralDataset): # Register the dataset -register_dataset("arc_1d", Arc1DDataset, Arc1DConfig) +register_dataset(DATASET_NAME, Arc1DDataset, Arc1DConfig) diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py index a19a6507..b46a7091 100644 --- a/reasoning_gym/arc/arc_agi.py +++ b/reasoning_gym/arc/arc_agi.py @@ -14,6 +14,8 @@ from reasoning_gym.arc.board_format import ( from reasoning_gym.dataset import ProceduralDataset from reasoning_gym.factory import register_dataset +DATASET_NAME = "arc_agi" + @dataclass class ArcAgiConfig: @@ -182,6 +184,8 @@ class ArcAgiDataset(ProceduralDataset): "question": input_prompt, "answer": test_output, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "input": totuple(augmented_test_input), "output": totuple(augmented_test_output), "task_id": task_id, @@ -203,4 +207,4 @@ class ArcAgiDataset(ProceduralDataset): return reward -register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig) +register_dataset(DATASET_NAME, ArcAgiDataset, ArcAgiConfig) diff --git a/reasoning_gym/arc/rearc.py b/reasoning_gym/arc/rearc.py index 83adedbf..e13cdbb2 100644 --- a/reasoning_gym/arc/rearc.py +++ b/reasoning_gym/arc/rearc.py @@ -16,6 +16,8 @@ PSO_DIFFICULTY_RANGES = [ (PSO_DIFFICULTY_LEVELS[i], PSO_DIFFICULTY_LEVELS[i + 1]) for i in range(len(PSO_DIFFICULTY_LEVELS) - 1) ] +DATASET_NAME = "rearc" + @dataclass class ReArcConfig: @@ -114,6 +116,8 @@ class ReArcDataset(ProceduralDataset): "question": input_prompt, "answer": answer, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "input": task["input"], "output": task["output"], "task_id": task_id, @@ -178,4 +182,4 @@ class ReArcCurriculum(BaseCurriculum): ) -register_dataset("rearc", ReArcDataset, ReArcConfig, ReArcCurriculum) +register_dataset(DATASET_NAME, ReArcDataset, ReArcConfig, ReArcCurriculum) diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 42dbcf02..036dea3d 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -5,6 +5,8 @@ from typing import Any, Literal, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "basic_arithmetic" + @dataclass class BasicArithmeticDatasetConfig: @@ -95,6 +97,8 @@ class BasicArithmeticDataset(ProceduralDataset): "question": question, "answer": str(result), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "expression": expression, "num_terms": num_terms, "num_digits": num_digits, @@ -260,4 +264,4 @@ class BasicArithmeticCurriculum(BaseCurriculum): # Register the dataset -register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig, BasicArithmeticCurriculum) +register_dataset(DATASET_NAME, BasicArithmeticDataset, BasicArithmeticDatasetConfig, BasicArithmeticCurriculum) diff --git a/reasoning_gym/arithmetic/bitwise_arithmetic.py b/reasoning_gym/arithmetic/bitwise_arithmetic.py index 7b181ecd..cee8087a 100644 --- a/reasoning_gym/arithmetic/bitwise_arithmetic.py +++ b/reasoning_gym/arithmetic/bitwise_arithmetic.py @@ -5,6 +5,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "bitwise_arithmetic" + @dataclass class BitwiseArithmeticConfig: @@ -155,7 +157,12 @@ class BitwiseArithmeticDataset(ProceduralDataset): return { "question": problem_str, "answer": answer, - "metadata": {"problem": problem, "difficulty": {"difficulty": self.config.difficulty}}, + "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, + "problem": problem, + "difficulty": {"difficulty": self.config.difficulty}, + }, } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -193,4 +200,4 @@ class BitwiseArithmeticCurriculum(BaseCurriculum): # Register the dataset with the factory. -register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum) +register_dataset(DATASET_NAME, BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum) diff --git a/reasoning_gym/arithmetic/calendar_arithmetic.py b/reasoning_gym/arithmetic/calendar_arithmetic.py index fcd1a7ce..10d56630 100644 --- a/reasoning_gym/arithmetic/calendar_arithmetic.py +++ b/reasoning_gym/arithmetic/calendar_arithmetic.py @@ -9,6 +9,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "calendar_arithmetic" + class Weekday(Enum): MONDAY = auto() @@ -126,6 +128,8 @@ class CalendarArithmeticDataset(ProceduralDataset): rng = random.Random(self.seed + idx) task = rng.choice(self.tasks) question, answer, metadata = task(rng) + metadata["source_dataset"] = DATASET_NAME + metadata["source_index"] = idx metadata["difficulty"] = { "task_complexity": self.tasks.index(task), "date_range": self.config.offset_upper_bound, @@ -523,6 +527,4 @@ class CalendarArithmeticCurriculum(BaseCurriculum): ) -register_dataset( - "calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig, CalendarArithmeticCurriculum -) +register_dataset(DATASET_NAME, CalendarArithmeticDataset, CalendarArithmeticConfig, CalendarArithmeticCurriculum) diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 8b0a3a19..d4778b7a 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -7,6 +7,8 @@ from reasoning_gym import utils from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "chain_sum" + @dataclass class ChainSumConfig: @@ -64,6 +66,8 @@ class ChainSumDataset(ProceduralDataset): "question": f"State the final answer to the following arithmetic problem: {expression} =", "answer": str(result), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "num_terms": num_terms, "num_digits": num_digits, "expression": expression, @@ -143,4 +147,4 @@ class ChainSumCurriculum(BaseCurriculum): # Register the dataset -register_dataset("chain_sum", ChainSumDataset, ChainSumConfig, ChainSumCurriculum) +register_dataset(DATASET_NAME, ChainSumDataset, ChainSumConfig, ChainSumCurriculum) diff --git a/reasoning_gym/arithmetic/count_bits.py b/reasoning_gym/arithmetic/count_bits.py index a3557f3c..0934e6a6 100644 --- a/reasoning_gym/arithmetic/count_bits.py +++ b/reasoning_gym/arithmetic/count_bits.py @@ -9,6 +9,8 @@ from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?""" +DATASET_NAME = "count_bits" + @dataclass class CountBitsConfig: @@ -43,6 +45,8 @@ class CountBitsDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(number=number), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "number": number, "solution": answer, "binary": binary, @@ -70,4 +74,4 @@ class CountBitsCurriculum(BaseCurriculum): ) -register_dataset("count_bits", CountBitsDataset, CountBitsConfig, CountBitsCurriculum) +register_dataset(DATASET_NAME, CountBitsDataset, CountBitsConfig, CountBitsCurriculum) diff --git a/reasoning_gym/arithmetic/decimal_arithmetic.py b/reasoning_gym/arithmetic/decimal_arithmetic.py index 34ce306c..a7338221 100644 --- a/reasoning_gym/arithmetic/decimal_arithmetic.py +++ b/reasoning_gym/arithmetic/decimal_arithmetic.py @@ -7,6 +7,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "decimal_arithmetic" + @dataclass class DecimalArithmeticConfig: @@ -189,6 +191,8 @@ class DecimalArithmeticDataset(ProceduralDataset): "question": problem_str, "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "decimal_places": decimal_places, "num_terms": terms, "difficulty": { @@ -249,4 +253,4 @@ class DecimalArithmeticCurriculum(BaseCurriculum): # Register the dataset with the factory. -register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum) +register_dataset(DATASET_NAME, DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum) diff --git a/reasoning_gym/arithmetic/decimal_chain_sum.py b/reasoning_gym/arithmetic/decimal_chain_sum.py index 057203ce..dbbe14b7 100644 --- a/reasoning_gym/arithmetic/decimal_chain_sum.py +++ b/reasoning_gym/arithmetic/decimal_chain_sum.py @@ -6,6 +6,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "decimal_chain_sum" + @dataclass class DecimalChainSumConfig: @@ -66,6 +68,8 @@ class DecimalChainSumDataset(ProceduralDataset): "question": f"State the final answer to the following arithmetic problem: {expression} =", "answer": str(result), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "num_terms": num_terms, "num_digits": num_digits, "expression": expression, @@ -195,4 +199,4 @@ class DecimalChainSumCurriculum(BaseCurriculum): ) -register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum) +register_dataset(DATASET_NAME, DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum) diff --git a/reasoning_gym/arithmetic/dice.py b/reasoning_gym/arithmetic/dice.py index 8bf19981..521a4205 100644 --- a/reasoning_gym/arithmetic/dice.py +++ b/reasoning_gym/arithmetic/dice.py @@ -7,6 +7,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "dice" + def compute_probability(dice, target): """ @@ -124,6 +126,8 @@ class DiceDataset(ProceduralDataset): "question": puzzle_str, "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "puzzle": puzzle, "difficulty": { "num_dice": self.config.num_dice, @@ -174,4 +178,4 @@ class DiceCurriculum(BaseCurriculum): ) -register_dataset("dice", DiceDataset, DiceConfig, DiceCurriculum) +register_dataset(DATASET_NAME, DiceDataset, DiceConfig, DiceCurriculum) diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index 2176cfef..75a19966 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -11,6 +11,8 @@ from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer." +DATASET_NAME = "fraction_simplification" + @dataclass class FractionSimplificationConfig: @@ -114,6 +116,8 @@ class FractionSimplificationDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(question_fraction=question_fraction), "answer": answer_fraction, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "numerator": num, "denominator": den, "simplified_numerator": simple_num, @@ -184,7 +188,7 @@ class FractionSimplificationCurriculum(BaseCurriculum): register_dataset( - "fraction_simplification", + DATASET_NAME, FractionSimplificationDataset, FractionSimplificationConfig, FractionSimplificationCurriculum, diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index 33b1e061..b7bf93e1 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -9,6 +9,8 @@ from typing import Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "gcd" + @dataclass class GCDConfig: @@ -62,6 +64,8 @@ class GCDDataset(ProceduralDataset): "question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.", "answer": str(result), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "numbers": numbers, "result": result, "num_terms": num_terms, @@ -96,4 +100,4 @@ class GCDCurriculum(BaseCurriculum): ) -register_dataset("gcd", GCDDataset, GCDConfig) +register_dataset(DATASET_NAME, GCDDataset, GCDConfig) diff --git a/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py b/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py index b99ef3a0..962f5f52 100644 --- a/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py +++ b/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py @@ -7,6 +7,8 @@ from typing import Any, Callable, Optional from reasoning_gym.factory import ProceduralDataset, register_dataset +DATASET_NAME = "gsm_symbolic" + tasks_ok = [ 0, 1, @@ -151,6 +153,8 @@ class GSMSymbolicDataset(ProceduralDataset): generator = self.generators[generator_idx] example = generator(rng, self.config.difficulty) example["question"] += " Give the result as your final answer. Do not include units." + example["metadata"]["source_dataset"] = DATASET_NAME + example["metadata"]["source_index"] = idx return example def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -174,4 +178,4 @@ class GSMSymbolicDataset(ProceduralDataset): return reward -register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig) +register_dataset(DATASET_NAME, GSMSymbolicDataset, GSMSymbolicDatasetConfig) diff --git a/reasoning_gym/arithmetic/lcm.py b/reasoning_gym/arithmetic/lcm.py index 7f121a31..22a25802 100644 --- a/reasoning_gym/arithmetic/lcm.py +++ b/reasoning_gym/arithmetic/lcm.py @@ -9,6 +9,8 @@ from typing import Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "lcm" + @dataclass class LCMConfig: @@ -64,6 +66,8 @@ class LCMDataset(ProceduralDataset): "question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}", "answer": str(result), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "numbers": numbers, "result": result, "difficulty": { @@ -98,4 +102,4 @@ class LCMCurriculum(BaseCurriculum): ) -register_dataset("lcm", LCMDataset, LCMConfig, LCMCurriculum) +register_dataset(DATASET_NAME, LCMDataset, LCMConfig, LCMCurriculum) diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index 91b65251..4d9f7c5a 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -60,6 +60,8 @@ QUESTION_TEMPLATE = """Your task is to count how many legs there are in total wh Now, how many legs are there in total if you have {animals}? """ +DATASET_NAME = "leg_counting" + @dataclass class LegCountingConfig: @@ -118,6 +120,8 @@ class LegCountingDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)), "answer": str(total_legs), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "animals": animals, "num_animals": len(animals), "total_legs": total_legs, @@ -152,4 +156,4 @@ class LegCountingCurriculum(BaseCurriculum): ) -register_dataset("leg_counting", LegCountingDataset, LegCountingConfig, LegCountingCurriculum) +register_dataset(DATASET_NAME, LegCountingDataset, LegCountingConfig, LegCountingCurriculum) diff --git a/reasoning_gym/arithmetic/number_format.py b/reasoning_gym/arithmetic/number_format.py index 955be7a8..64dda38a 100644 --- a/reasoning_gym/arithmetic/number_format.py +++ b/reasoning_gym/arithmetic/number_format.py @@ -14,6 +14,8 @@ Your output should be only the number of interest. Now, pick the {size} number of the following candidates: {numbers} """ +DATASET_NAME = "number_format" + @dataclass class NumberFormatConfig: @@ -94,6 +96,8 @@ class NumberFormatDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(numbers=" ".join(formatted_candidates), size=size), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "candidates": candidates, "solution": answer, "formatted_candidates": formatted_candidates, @@ -138,4 +142,4 @@ class NumberFormatCurriculum(BaseCurriculum): ) -register_dataset("number_format", NumberFormatDataset, NumberFormatConfig, NumberFormatCurriculum) +register_dataset(DATASET_NAME, NumberFormatDataset, NumberFormatConfig, NumberFormatCurriculum) diff --git a/reasoning_gym/arithmetic/power_function.py b/reasoning_gym/arithmetic/power_function.py index 931a7cbf..494fec94 100644 --- a/reasoning_gym/arithmetic/power_function.py +++ b/reasoning_gym/arithmetic/power_function.py @@ -15,6 +15,8 @@ Compute {base}^{exponent}. Return your final answer correct to 3 significant fig Provide your answer in scientific notation using 'e' notation (e.g., 1.23e+4). """ +DATASET_NAME = "power_function" + @dataclass class PowerFunctionConfig: @@ -74,6 +76,8 @@ class PowerFunctionDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(base=base, exponent=exponent), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "base": base, "exponent": exponent, "solution": answer, @@ -97,4 +101,4 @@ class PowerFunctionCurriculum(BaseCurriculum): ) -register_dataset("power_function", PowerFunctionDataset, PowerFunctionConfig, PowerFunctionCurriculum) +register_dataset(DATASET_NAME, PowerFunctionDataset, PowerFunctionConfig, PowerFunctionCurriculum) diff --git a/reasoning_gym/arithmetic/prime_factorization.py b/reasoning_gym/arithmetic/prime_factorization.py index b14d6b1b..ef56fe07 100644 --- a/reasoning_gym/arithmetic/prime_factorization.py +++ b/reasoning_gym/arithmetic/prime_factorization.py @@ -8,6 +8,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "prime_factorization" + @dataclass class PrimeFactorizationConfig: @@ -84,6 +86,8 @@ class PrimeFactorizationDataset(ProceduralDataset): ), "answer": answer, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "number": number, "factors": factors, "difficulty": { @@ -110,6 +114,4 @@ class PrimeFactorizationCurriculum(BaseCurriculum): ) -register_dataset( - "prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig, PrimeFactorizationCurriculum -) +register_dataset(DATASET_NAME, PrimeFactorizationDataset, PrimeFactorizationConfig, PrimeFactorizationCurriculum) diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index b9d38961..4a36ec20 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -7,6 +7,8 @@ from reasoning_gym import utils from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "products" + @dataclass class ProductsConfig: @@ -66,6 +68,8 @@ class ProductsDataset(ProceduralDataset): "question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.", "answer": str(result), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "expression": expression, "num_terms": num_terms, "num_digits": num_digits, @@ -135,4 +139,4 @@ class ProductsCurriculum(BaseCurriculum): # Register the dataset -register_dataset("products", ProductsDataset, ProductsConfig, ProductsCurriculum) +register_dataset(DATASET_NAME, ProductsDataset, ProductsConfig, ProductsCurriculum) diff --git a/reasoning_gym/arithmetic/time_intervals.py b/reasoning_gym/arithmetic/time_intervals.py index 42a6f343..df22c6ff 100644 --- a/reasoning_gym/arithmetic/time_intervals.py +++ b/reasoning_gym/arithmetic/time_intervals.py @@ -9,6 +9,8 @@ from dateutil import parser from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "time_intervals" + @dataclass class TimeIntervalsConfig: @@ -134,6 +136,8 @@ class TimeIntervalsDataset(ProceduralDataset): "question": question, "answer": answer, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "task_type": task_type, "start_time": start_dt, "end_time": end_dt, @@ -346,4 +350,4 @@ class TimeIntervalsCurriculum(BaseCurriculum): # Register the dataset -register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig, TimeIntervalsCurriculum) +register_dataset(DATASET_NAME, TimeIntervalsDataset, TimeIntervalsConfig, TimeIntervalsCurriculum) diff --git a/reasoning_gym/coaching/coach.py b/reasoning_gym/coaching/coach.py index 5142e666..f1bf39b9 100644 --- a/reasoning_gym/coaching/coach.py +++ b/reasoning_gym/coaching/coach.py @@ -147,11 +147,7 @@ class ScoreBoard: placed first in the tuple as ("source", dataset) and ("idx", index). """ # Start with empty list - key_items = [] - - # Add source info first if present - if "source_dataset" in metadata and "source_index" in metadata: - key_items.extend([("source", metadata["source_dataset"]), ("idx", metadata["source_index"])]) + key_items = [("source", metadata["source_dataset"]), ("idx", metadata["source_index"])] # Add difficulty parameters or other metadata if "difficulty" in metadata: diff --git a/reasoning_gym/code/bf.py b/reasoning_gym/code/bf.py index 0599325a..5ada391e 100644 --- a/reasoning_gym/code/bf.py +++ b/reasoning_gym/code/bf.py @@ -9,6 +9,8 @@ from ..data.wordle_words import wordle_words from ..factory import ProceduralDataset, register_dataset from .contrib.bfit.Compiler import Compiler, Minify +DATASET_NAME = "bf" + @dataclass class BFConfig: @@ -53,6 +55,8 @@ class BFDataset(ProceduralDataset): "question": rng.choice(self._prompt_templates).format(bf_program=bf_program), "answer": result, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "bfit_code": bfit_code, "bf_program": bf_program, "difficulty": {"difficulty": self.config.difficulty}, @@ -160,4 +164,4 @@ class BFCurriculum(BaseCurriculum): # Register the dataset -register_dataset("bf", BFDataset, BFConfig, BFCurriculum) +register_dataset(DATASET_NAME, BFDataset, BFConfig, BFCurriculum) diff --git a/reasoning_gym/code/codeio.py b/reasoning_gym/code/codeio.py index f71555f6..d8c36d51 100644 --- a/reasoning_gym/code/codeio.py +++ b/reasoning_gym/code/codeio.py @@ -49,6 +49,8 @@ Tip: Here is a reference code snippet for this question. You can refer to this c {3} """ +DATASET_NAME = "codeio" + @dataclass class CodeIOConfig: @@ -117,7 +119,12 @@ class CodeIODataset(ProceduralDataset): return { "question": question, "answer": solution, - "metadata": {"input_data": input_data, "output_data": output_data}, + "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, + "input_data": input_data, + "output_data": output_data, + }, } def _json_to_tree(self, data, label="root"): @@ -231,4 +238,4 @@ class CodeIODataset(ProceduralDataset): # Register the dataset -register_dataset("codeio", CodeIODataset, CodeIOConfig) +register_dataset(DATASET_NAME, CodeIODataset, CodeIOConfig) diff --git a/reasoning_gym/cognition/color_cube_rotation.py b/reasoning_gym/cognition/color_cube_rotation.py index 1b1e486c..335c4b7e 100644 --- a/reasoning_gym/cognition/color_cube_rotation.py +++ b/reasoning_gym/cognition/color_cube_rotation.py @@ -35,6 +35,9 @@ class Side(StrEnum): BOTTOM = "bottom" +DATASET_NAME = "color_cube_rotation" + + @dataclass class Cube: """Represents a cube with colored sides""" @@ -137,6 +140,8 @@ class ColorCubeRotationDataset(ProceduralDataset): "question": story, "answer": cube.colors[target_side], "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "initial_state": {k.value: v.value for k, v in initial_state.items()}, "rotations": [r.value for r in rotations], "target_side": target_side.value, @@ -225,4 +230,4 @@ class ColorCubeRotationCurriculum(BaseCurriculum): ) -register_dataset("color_cube_rotation", ColorCubeRotationDataset, ColorCubeRotationConfig, ColorCubeRotationCurriculum) +register_dataset(DATASET_NAME, ColorCubeRotationDataset, ColorCubeRotationConfig, ColorCubeRotationCurriculum) diff --git a/reasoning_gym/cognition/figlet_fonts.py b/reasoning_gym/cognition/figlet_fonts.py index 715fc79e..e54f5c56 100644 --- a/reasoning_gym/cognition/figlet_fonts.py +++ b/reasoning_gym/cognition/figlet_fonts.py @@ -120,6 +120,8 @@ BAD_FONTS = [ ALL_FONTS = pyfiglet.FigletFont.getFonts() OK_FONTS = list(filter(lambda x: x not in BAD_FONTS, ALL_FONTS)) +DATASET_NAME = "figlet_font" + @dataclass class FigletFontConfig: @@ -186,6 +188,8 @@ class FigletFontDataset(ProceduralDataset): "question": rng.choice(self._prompt_templates).format(figlet_render=figlet_render), "answer": word, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "font": chosen_font, "space_letters": self.config.space_letters, "difficulty": { @@ -248,4 +252,4 @@ class FigletFontCurriculum(BaseCurriculum): # Register the dataset -register_dataset("figlet_font", FigletFontDataset, FigletFontConfig, FigletFontCurriculum) +register_dataset(DATASET_NAME, FigletFontDataset, FigletFontConfig, FigletFontCurriculum) diff --git a/reasoning_gym/cognition/modulo_grid.py b/reasoning_gym/cognition/modulo_grid.py index 88debfab..a241dc78 100644 --- a/reasoning_gym/cognition/modulo_grid.py +++ b/reasoning_gym/cognition/modulo_grid.py @@ -6,6 +6,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "modulo_grid" + @dataclass class ModuloGridConfig: @@ -136,6 +138,8 @@ class ModuloGridDataset(ProceduralDataset): "question": question, "answer": flatten_grid(grid), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "divisor": divisor, "target": target, "operation": operation, @@ -190,4 +194,4 @@ class ModuloGridCurriculum(BaseCurriculum): # Register the dataset -register_dataset("modulo_grid", ModuloGridDataset, ModuloGridConfig, ModuloGridCurriculum) +register_dataset(DATASET_NAME, ModuloGridDataset, ModuloGridConfig, ModuloGridCurriculum) diff --git a/reasoning_gym/cognition/needle_haystack.py b/reasoning_gym/cognition/needle_haystack.py index 2eb8ee55..781fbba8 100644 --- a/reasoning_gym/cognition/needle_haystack.py +++ b/reasoning_gym/cognition/needle_haystack.py @@ -5,6 +5,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "needle_haystack" + @dataclass class NeedleHaystackConfig: @@ -104,6 +106,8 @@ class NeedleHaystackDataset(ProceduralDataset): "question": full_text, "answer": stack["needle"][0], "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "question": question, "num_statements": num_statements, "difficulty": { @@ -153,4 +157,4 @@ class NeedleHaystackCurriculum(BaseCurriculum): # Register the dataset -register_dataset("needle_haystack", NeedleHaystackDataset, NeedleHaystackConfig, NeedleHaystackCurriculum) +register_dataset(DATASET_NAME, NeedleHaystackDataset, NeedleHaystackConfig, NeedleHaystackCurriculum) diff --git a/reasoning_gym/cognition/number_sequences.py b/reasoning_gym/cognition/number_sequences.py index 6b9b4619..0bdc365c 100644 --- a/reasoning_gym/cognition/number_sequences.py +++ b/reasoning_gym/cognition/number_sequences.py @@ -6,6 +6,8 @@ from typing import Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "number_sequence" + class Operation(StrEnum): """Basic mathematical operations that can be composed""" @@ -196,6 +198,8 @@ class NumberSequenceDataset(ProceduralDataset): "question": ", ".join(map(str, visible_terms)) + ", ?", "answer": str(sequence[-1]), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "rule": rule.to_string(), "complexity": complexity, "sequence": sequence, @@ -220,4 +224,4 @@ class NumberSequenceCurriculum(BaseCurriculum): ) -register_dataset("number_sequence", NumberSequenceDataset, NumberSequenceConfig, NumberSequenceCurriculum) +register_dataset(DATASET_NAME, NumberSequenceDataset, NumberSequenceConfig, NumberSequenceCurriculum) diff --git a/reasoning_gym/cognition/rectangle_count.py b/reasoning_gym/cognition/rectangle_count.py index 2fb02dca..1b5fad37 100644 --- a/reasoning_gym/cognition/rectangle_count.py +++ b/reasoning_gym/cognition/rectangle_count.py @@ -15,6 +15,8 @@ Now, it's your turn. How many rectangles do you see in the grid below? {puzzle} """ +DATASET_NAME = "rectangle_count" + def draw_rectangles_with_overlap(n, width, height, rng): # Create a grid that holds a count of how many times a cell is drawn. @@ -118,6 +120,8 @@ class RectangleCountDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(puzzle=puzzle), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "puzzle": puzzle, "solution": answer, "num_rectangles": target, @@ -161,4 +165,4 @@ class RectangleCountCurriculum(BaseCurriculum): ) -register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig, RectangleCountCurriculum) +register_dataset(DATASET_NAME, RectangleCountDataset, RectangleCountConfig, RectangleCountCurriculum) diff --git a/reasoning_gym/cognition/rubiks_cube.py b/reasoning_gym/cognition/rubiks_cube.py index 25b236f6..f6869fd4 100644 --- a/reasoning_gym/cognition/rubiks_cube.py +++ b/reasoning_gym/cognition/rubiks_cube.py @@ -9,6 +9,8 @@ from magiccube.solver.basic.basic_solver import BasicSolver from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "rubiks_cube" + @dataclass class RubiksCubeConfig: @@ -105,6 +107,8 @@ class RubiksCubeDataset(ProceduralDataset): ), "answer": None, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "cube_size": self.config.cube_size, "scramble_steps": num_steps, "scramble_moves": " ".join([str(move) for move in scramble_moves]), @@ -188,4 +192,4 @@ class RubiksCubeCurriculum(BaseCurriculum): # Register the dataset -register_dataset("rubiks_cube", RubiksCubeDataset, RubiksCubeConfig, RubiksCubeCurriculum) +register_dataset(DATASET_NAME, RubiksCubeDataset, RubiksCubeConfig, RubiksCubeCurriculum) diff --git a/reasoning_gym/composite.py b/reasoning_gym/composite.py index 05700151..ff6a14ce 100644 --- a/reasoning_gym/composite.py +++ b/reasoning_gym/composite.py @@ -126,10 +126,6 @@ class CompositeDataset(ProceduralDataset): # Get item from selected dataset item = dataset[idx] - # Add source dataset info to metadata - item["metadata"]["source_dataset"] = dataset_name - item["metadata"]["source_index"] = idx - # Add version info if tracking enabled if self.version_manager is not None: version_id = self.dataset_versions[dataset_name] diff --git a/reasoning_gym/games/boxnet.py b/reasoning_gym/games/boxnet.py index 312cddca..7ef791d7 100644 --- a/reasoning_gym/games/boxnet.py +++ b/reasoning_gym/games/boxnet.py @@ -36,6 +36,8 @@ For example: Include an agent in the action plan only if it has a task to perform next. """ +DATASET_NAME = "boxnet" + def action_from_response(pg_dict_input, original_response_dict_list): pg_dict_current = copy.deepcopy(pg_dict_input) @@ -126,6 +128,8 @@ class BoxnetDataset(ProceduralDataset): "question": question, "answer": None, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "row_num": row_num, "column_num": column_num, "initial_state": pg_dict, @@ -248,4 +252,4 @@ class BoxnetCurriculum(BaseCurriculum): ) -register_dataset("boxnet", BoxnetDataset, BoxnetConfig, BoxnetCurriculum) +register_dataset(DATASET_NAME, BoxnetDataset, BoxnetConfig, BoxnetCurriculum) diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py index 270476be..ecafbc8f 100644 --- a/reasoning_gym/games/countdown.py +++ b/reasoning_gym/games/countdown.py @@ -20,6 +20,9 @@ Final answer format instructions: """ +DATASET_NAME = "countdown" + + @dataclass class CountdownConfig: """Configuration for Countdown Number Game task generation""" @@ -85,6 +88,8 @@ class CountdownDataset(ProceduralDataset): "question": QUESTION_FORMAT_TEMPLATE.format(question=question), "answer": expression, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "numbers": numbers, "target": target, "expression": expression, @@ -195,4 +200,4 @@ class CountdownDataset(ProceduralDataset): # Register the dataset -register_dataset("countdown", CountdownDataset, CountdownConfig) +register_dataset(DATASET_NAME, CountdownDataset, CountdownConfig) diff --git a/reasoning_gym/games/emoji_mystery.py b/reasoning_gym/games/emoji_mystery.py index d41b7156..0c004bb6 100644 --- a/reasoning_gym/games/emoji_mystery.py +++ b/reasoning_gym/games/emoji_mystery.py @@ -152,6 +152,8 @@ Decode the following sentence from the emoji: {sentence} Return the secret sentence as your final answer. """ +DATASET_NAME = "emoji_mystery" + @dataclass class EmojiMysteryConfig: @@ -193,6 +195,8 @@ class EmojiMysteryDataset(ProceduralDataset): "question": question, "answer": secret_sentence, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "emoji": secret_emoji, "num_words_in_sentence": len(re.findall(r"\b\w+\b", secret_sentence)), "difficulty": { @@ -259,4 +263,4 @@ class EmojiMysteryCurriculum(BaseCurriculum): ) -register_dataset("emoji_mystery", EmojiMysteryDataset, EmojiMysteryConfig, EmojiMysteryCurriculum) +register_dataset(DATASET_NAME, EmojiMysteryDataset, EmojiMysteryConfig, EmojiMysteryCurriculum) diff --git a/reasoning_gym/games/futoshiki.py b/reasoning_gym/games/futoshiki.py index 9b270213..b027dfb2 100644 --- a/reasoning_gym/games/futoshiki.py +++ b/reasoning_gym/games/futoshiki.py @@ -9,6 +9,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "futoshiki" + @dataclass class FutoshikiConfig: @@ -81,6 +83,8 @@ class FutoshikiDataset(ProceduralDataset): "question": question, "answer": solution_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "puzzle": puzzle, "constraints": constraints, "solution": solution, @@ -686,4 +690,4 @@ class FutoshikiCurriculum(BaseCurriculum): ) -register_dataset("futoshiki", FutoshikiDataset, FutoshikiConfig) +register_dataset(DATASET_NAME, FutoshikiDataset, FutoshikiConfig) diff --git a/reasoning_gym/games/knight_swap.py b/reasoning_gym/games/knight_swap.py index 2a54289d..d7e4c790 100644 --- a/reasoning_gym/games/knight_swap.py +++ b/reasoning_gym/games/knight_swap.py @@ -35,6 +35,8 @@ Answer Format: Example: ["w,A1,B3"] means white knight moves A1→B3 """ +DATASET_NAME = "knight_swap" + @dataclass class KnightSwapConfig: @@ -286,6 +288,8 @@ class KnightSwapDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(board=board_str, start_turn=start_turn), "answer": solution_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "board": board_copy, "pieces": pieces, "start_turn": start_turn, @@ -392,4 +396,4 @@ class KnightSwapDataset(ProceduralDataset): return 0.0 -register_dataset("knight_swap", KnightSwapDataset, KnightSwapConfig) +register_dataset(DATASET_NAME, KnightSwapDataset, KnightSwapConfig) diff --git a/reasoning_gym/games/mahjong.py b/reasoning_gym/games/mahjong.py index 1f380167..f8ff47c1 100644 --- a/reasoning_gym/games/mahjong.py +++ b/reasoning_gym/games/mahjong.py @@ -26,6 +26,8 @@ Now, given the initial cards {cards}, what is the result at the end of performin {operations} """ +DATASET_NAME = "mahjong_puzzle" + @dataclass class MahjongPuzzleConfig: @@ -120,6 +122,8 @@ class MahjongPuzzleDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(cards=cards, operations=operations), "answer": answer, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "rounds": rounds, "solution": answer, "difficulty": { @@ -145,4 +149,4 @@ class MahjongPuzzleCurriculum(BaseCurriculum): ) -register_dataset("mahjong_puzzle", MahjongPuzzleDataset, MahjongPuzzleConfig, MahjongPuzzleCurriculum) +register_dataset(DATASET_NAME, MahjongPuzzleDataset, MahjongPuzzleConfig, MahjongPuzzleCurriculum) diff --git a/reasoning_gym/games/maze.py b/reasoning_gym/games/maze.py index a6305cc3..ef70aa91 100644 --- a/reasoning_gym/games/maze.py +++ b/reasoning_gym/games/maze.py @@ -6,6 +6,8 @@ from typing import Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "maze" + @dataclass class MazeConfig: @@ -104,6 +106,8 @@ class MazeDataset(ProceduralDataset): "question": question_str, "answer": str(dist), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "grid_size": size, "grid": ["".join(row) for row in maze_grid], "shortest_path_length": dist, @@ -214,4 +218,4 @@ class MazeCurriculum(BaseCurriculum): ) -register_dataset("maze", MazeDataset, MazeConfig, MazeCurriculum) +register_dataset(DATASET_NAME, MazeDataset, MazeConfig, MazeCurriculum) diff --git a/reasoning_gym/games/mini_sudoku.py b/reasoning_gym/games/mini_sudoku.py index 780795b7..8a0b3901 100644 --- a/reasoning_gym/games/mini_sudoku.py +++ b/reasoning_gym/games/mini_sudoku.py @@ -8,6 +8,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "mini_sudoku" + @dataclass class MiniSudokuConfig: @@ -193,6 +195,8 @@ class MiniSudokuDataset(ProceduralDataset): "question": question, "answer": solution_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "puzzle": puzzle, "solution": solved_board, "num_empty": num_empty, @@ -257,4 +261,4 @@ class MiniSudokuCurriculum(BaseCurriculum): ) -register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig, MiniSudokuCurriculum) +register_dataset(DATASET_NAME, MiniSudokuDataset, MiniSudokuConfig, MiniSudokuCurriculum) diff --git a/reasoning_gym/games/n_queens.py b/reasoning_gym/games/n_queens.py index cef4d1a1..eb958376 100644 --- a/reasoning_gym/games/n_queens.py +++ b/reasoning_gym/games/n_queens.py @@ -27,6 +27,8 @@ Given the below board of size {n} x {n} your job is to place {num_removed} queen {puzzle} """ +DATASET_NAME = "n_queens" + @dataclass class NQueensConfig: @@ -131,6 +133,8 @@ class NQueensDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(puzzle=puzzle_str, n=len(puzzle), num_removed=num_removed), "answer": rng.choice(valid_solutions_str), # choose arbitary answer (e.g. for SFT) "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "puzzle": puzzle, "solutions": valid_solutions, "num_removed": num_removed, @@ -177,4 +181,4 @@ class NQueensCurriculum(BaseCurriculum): ) -register_dataset("n_queens", NQueensDataset, NQueensConfig, NQueensCurriculum) +register_dataset(DATASET_NAME, NQueensDataset, NQueensConfig, NQueensCurriculum) diff --git a/reasoning_gym/games/puzzle24.py b/reasoning_gym/games/puzzle24.py index 7a8a699b..a0b86d41 100644 --- a/reasoning_gym/games/puzzle24.py +++ b/reasoning_gym/games/puzzle24.py @@ -17,6 +17,8 @@ Final answer format instructions: 4. Use '/' for division. """ +DATASET_NAME = "puzzle24" + @dataclass class Puzzle24Config: @@ -101,6 +103,8 @@ class Puzzle24Dataset(ProceduralDataset): "question": question, "answer": expr_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "numbers": numbers, "expression": expr, }, @@ -127,4 +131,4 @@ class Puzzle24Dataset(ProceduralDataset): return reward -register_dataset("puzzle24", Puzzle24Dataset, Puzzle24Config) +register_dataset(DATASET_NAME, Puzzle24Dataset, Puzzle24Config) diff --git a/reasoning_gym/games/rush_hour.py b/reasoning_gym/games/rush_hour.py index ba3e49ff..1c84abc3 100644 --- a/reasoning_gym/games/rush_hour.py +++ b/reasoning_gym/games/rush_hour.py @@ -42,6 +42,9 @@ H = 1 # horizontal stride V = BOARD_SIZE # vertical stride +DATASET_NAME = "rush_hour" + + # board boundary limits def create_row_masks() -> list[int]: row_masks: list[int] = [] @@ -159,6 +162,8 @@ class RushHourDataset(ProceduralDataset): "question": f"{instructions}\n\nBoard:\n{board_display}", "answer": None, # Multiple valid solutions exist "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "board_config": board_config, "min_moves": min_moves, "difficulty": { @@ -387,4 +392,4 @@ class RushHourCurriculum(BaseCurriculum): # Register the dataset -register_dataset("rush_hour", RushHourDataset, RushHourConfig, RushHourCurriculum) +register_dataset(DATASET_NAME, RushHourDataset, RushHourConfig, RushHourCurriculum) diff --git a/reasoning_gym/games/sokoban.py b/reasoning_gym/games/sokoban.py index 857139f8..0d8da982 100644 --- a/reasoning_gym/games/sokoban.py +++ b/reasoning_gym/games/sokoban.py @@ -7,6 +7,8 @@ import numpy as np from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "sokoban" + @dataclass class SokobanConfig: @@ -98,6 +100,8 @@ Here is your puzzle: "width": puzzle_data["width"], "height": puzzle_data["height"], "difficulty": { + "source_dataset": DATASET_NAME, + "source_index": idx, "width": (self.config.min_w, self.config.max_w), "height": (self.config.min_h, self.config.max_h), }, @@ -160,4 +164,4 @@ class SokobanCurriculum(BaseCurriculum): ) -register_dataset("sokoban", SokobanDataset, SokobanConfig, SokobanCurriculum) +register_dataset(DATASET_NAME, SokobanDataset, SokobanConfig, SokobanCurriculum) diff --git a/reasoning_gym/games/sudoku.py b/reasoning_gym/games/sudoku.py index c3351a91..0aa7c6e3 100644 --- a/reasoning_gym/games/sudoku.py +++ b/reasoning_gym/games/sudoku.py @@ -8,6 +8,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "sudoku" + @dataclass class SudokuConfig: @@ -212,6 +214,8 @@ class SudokuDataset(ProceduralDataset): "question": question, "answer": solution_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "puzzle": puzzle, "solution": solved_board, "num_empty": num_empty, @@ -276,4 +280,4 @@ class SudokuCurriculum(BaseCurriculum): ) -register_dataset("sudoku", SudokuDataset, SudokuConfig, SudokuCurriculum) +register_dataset(DATASET_NAME, SudokuDataset, SudokuConfig, SudokuCurriculum) diff --git a/reasoning_gym/games/tower_of_hanoi.py b/reasoning_gym/games/tower_of_hanoi.py index f89f7f7c..cd458fe0 100644 --- a/reasoning_gym/games/tower_of_hanoi.py +++ b/reasoning_gym/games/tower_of_hanoi.py @@ -23,6 +23,8 @@ Formatting guidelines: - Do not include any other text or formatting. """ +DATASET_NAME = "tower_of_hanoi" + @dataclass class HanoiConfig: @@ -269,6 +271,8 @@ class HanoiDataset(ProceduralDataset): ), "answer": "\n".join(solution), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "num_disks": num_disks, "num_pegs": num_pegs, "start_peg": start_peg, @@ -452,4 +456,4 @@ class HanoiCurriculum(BaseCurriculum): # Register the dataset -register_dataset("tower_of_hanoi", HanoiDataset, HanoiConfig, HanoiCurriculum) +register_dataset(DATASET_NAME, HanoiDataset, HanoiConfig, HanoiCurriculum) diff --git a/reasoning_gym/games/tsumego.py b/reasoning_gym/games/tsumego.py index ba0475ee..cac14f37 100644 --- a/reasoning_gym/games/tsumego.py +++ b/reasoning_gym/games/tsumego.py @@ -27,6 +27,8 @@ from ..factory import ProceduralDataset, register_dataset # Added constant to avoid repetition of adjacent directions DIRECTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)] +DATASET_NAME = "tsumego" + @dataclass class TsumegoConfig: @@ -271,6 +273,8 @@ class TsumegoDataset(ProceduralDataset): ), "answer": solution_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "board": board, "board_size": size, "difficulty": { @@ -312,4 +316,4 @@ class TsumegoCurriculum(BaseCurriculum): # Register the dataset -register_dataset("tsumego", TsumegoDataset, TsumegoConfig, TsumegoCurriculum) +register_dataset(DATASET_NAME, TsumegoDataset, TsumegoConfig, TsumegoCurriculum) diff --git a/reasoning_gym/geometry/advanced_geometry.py b/reasoning_gym/geometry/advanced_geometry.py index 257cd359..9d661a23 100644 --- a/reasoning_gym/geometry/advanced_geometry.py +++ b/reasoning_gym/geometry/advanced_geometry.py @@ -9,6 +9,8 @@ from sympy.geometry import Point from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "advanced_geometry" + @dataclass class AdvancedGeometryConfig: @@ -87,8 +89,9 @@ class AdvancedGeometryDataset(ProceduralDataset): else: raise ValueError(f"Unknown task_type: {task_type}") + metadata["source_dataset"] = DATASET_NAME + metadata["source_index"] = idx metadata["task_type"] = task_type - metadata["difficulty"] = { "min_coord": self.config.min_coord, "max_coord": self.config.max_coord, @@ -296,4 +299,4 @@ class AdvancedGeometryCurriculum(BaseCurriculum): # Register the dataset -register_dataset("advanced_geometry", AdvancedGeometryDataset, AdvancedGeometryConfig, AdvancedGeometryCurriculum) +register_dataset(DATASET_NAME, AdvancedGeometryDataset, AdvancedGeometryConfig, AdvancedGeometryCurriculum) diff --git a/reasoning_gym/geometry/simple_geometry.py b/reasoning_gym/geometry/simple_geometry.py index 0f151d1f..4d53cc32 100644 --- a/reasoning_gym/geometry/simple_geometry.py +++ b/reasoning_gym/geometry/simple_geometry.py @@ -5,6 +5,8 @@ from typing import Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "simple_geometry" + @dataclass class SimpleGeometryConfig: @@ -109,6 +111,8 @@ class SimpleGeometryDataset(ProceduralDataset): "question": prompt, "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "n_sides": n_sides, "known_angles": known_angles, "sum_of_known_angles": sum(known_angles), @@ -164,4 +168,4 @@ class SimpleGeometryCurriculum(BaseCurriculum): # Register the dataset so it can be accessed similarly to the others -register_dataset("simple_geometry", SimpleGeometryDataset, SimpleGeometryConfig, SimpleGeometryCurriculum) +register_dataset(DATASET_NAME, SimpleGeometryDataset, SimpleGeometryConfig, SimpleGeometryCurriculum) diff --git a/reasoning_gym/graphs/course_schedule.py b/reasoning_gym/graphs/course_schedule.py index c18d89c0..e30af4b3 100644 --- a/reasoning_gym/graphs/course_schedule.py +++ b/reasoning_gym/graphs/course_schedule.py @@ -21,6 +21,8 @@ You are given the following list of prerequisites, where prerequisites[i] = (a_i Return True if you can finish all courses considering the prerequisites, or False otherwise. """ +DATASET_NAME = "course_schedule" + @dataclass class CourseScheduleConfig: @@ -132,6 +134,8 @@ class CourseScheduleDataset(ProceduralDataset): ), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "courses": courses, "prerequisites": prerequisites, "solution": answer, @@ -178,4 +182,4 @@ class CourseScheduleCurriculum(BaseCurriculum): ) -register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig, CourseScheduleCurriculum) +register_dataset(DATASET_NAME, CourseScheduleDataset, CourseScheduleConfig, CourseScheduleCurriculum) diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index a410a04b..5d52cb80 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -7,6 +7,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "family_relationships" + class Gender(StrEnum): MALE = "male" @@ -201,6 +203,8 @@ class FamilyRelationshipsDataset(ProceduralDataset): "question": f"{story}\n\n{question}", "answer": relationship.value, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "person1": person1.name, "person2": person2.name, "relationship": relationship.value, @@ -386,6 +390,4 @@ class FamilyRelationshipsCurriculum(BaseCurriculum): ) -register_dataset( - "family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig, FamilyRelationshipsCurriculum -) +register_dataset(DATASET_NAME, FamilyRelationshipsDataset, FamilyRelationshipsConfig, FamilyRelationshipsCurriculum) diff --git a/reasoning_gym/graphs/largest_island.py b/reasoning_gym/graphs/largest_island.py index d3077fba..beb36554 100644 --- a/reasoning_gym/graphs/largest_island.py +++ b/reasoning_gym/graphs/largest_island.py @@ -23,6 +23,8 @@ The area of an island is the number of cells with a value 1 in the island. Return the maximum area of an island in grid. If there is no island, return 0. """ +DATASET_NAME = "largest_island" + @dataclass class LargestIslandConfig: @@ -139,6 +141,8 @@ class LargestIslandDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str), "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "grid": grid, "solution": answer, "difficulty": { @@ -188,4 +192,4 @@ class LargestIslandCurriculum(BaseCurriculum): ) -register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig, LargestIslandCurriculum) +register_dataset(DATASET_NAME, LargestIslandDataset, LargestIslandConfig, LargestIslandCurriculum) diff --git a/reasoning_gym/graphs/quantum_lock.py b/reasoning_gym/graphs/quantum_lock.py index 5256502f..5aaa7546 100644 --- a/reasoning_gym/graphs/quantum_lock.py +++ b/reasoning_gym/graphs/quantum_lock.py @@ -7,6 +7,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "quantum_lock" + @dataclass class QuantumLockConfig: @@ -56,6 +58,8 @@ Buttons: "question": self.format_puzzle(rng.choice(self._prompt_templates), puzzle=puzzle_data), "answer": " → ".join(puzzle_data["solution"]), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "solution_path": puzzle_data["solution"], "target_value": puzzle_data["target_value"], "buttons": puzzle_data["buttons"], @@ -249,4 +253,4 @@ class QuantumLockCurriculum(BaseCurriculum): # Register the dataset -register_dataset("quantum_lock", QuantumLockDataset, QuantumLockConfig, QuantumLockCurriculum) +register_dataset(DATASET_NAME, QuantumLockDataset, QuantumLockConfig, QuantumLockCurriculum) diff --git a/reasoning_gym/graphs/shortest_path.py b/reasoning_gym/graphs/shortest_path.py index 93b91b6d..fbec984b 100644 --- a/reasoning_gym/graphs/shortest_path.py +++ b/reasoning_gym/graphs/shortest_path.py @@ -28,6 +28,8 @@ Now, find the length of the shortest path from * to # in the following grid: {grid} """ +DATASET_NAME = "shortest_path" + @dataclass class ShortestPathConfig: @@ -159,6 +161,8 @@ class ShortestPathDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(grid=matrix_str), "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "matrix": matrix, "solution": answer, "difficulty": { @@ -192,4 +196,4 @@ class ShortestPathCurriculum(BaseCurriculum): ) -register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig, ShortestPathCurriculum) +register_dataset(DATASET_NAME, ShortestPathDataset, ShortestPathConfig, ShortestPathCurriculum) diff --git a/reasoning_gym/induction/acre/acre.py b/reasoning_gym/induction/acre/acre.py index b6ef86a3..d2211a6b 100644 --- a/reasoning_gym/induction/acre/acre.py +++ b/reasoning_gym/induction/acre/acre.py @@ -12,6 +12,8 @@ from reasoning_gym.factory import ProceduralDataset, register_dataset from .blicket import config_control, dist_control, final_parse, serialize from .const import ALL_CONFIG_SIZE, ATTR_CONFIG_SIZE +DATASET_NAME = "acre" + # Create blicket questions @dataclass @@ -88,7 +90,14 @@ What is the detector light status?""" prompt_input = ", ".join(" ".join(x) for x in input["question"]["input"]) answer = input["question"]["output"] question = self.prompt_template.format(examples=formatted_examples, input=prompt_input) - return {"question": question, "answer": answer, "metadata": {}} + return { + "question": question, + "answer": answer, + "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, + }, + } -register_dataset("acre", ACREDataset, ACREDatasetConfig) +register_dataset(DATASET_NAME, ACREDataset, ACREDatasetConfig) diff --git a/reasoning_gym/induction/list_functions/list_functions.py b/reasoning_gym/induction/list_functions/list_functions.py index f91a683a..c1b0deda 100644 --- a/reasoning_gym/induction/list_functions/list_functions.py +++ b/reasoning_gym/induction/list_functions/list_functions.py @@ -6,6 +6,8 @@ from typing import Any, Callable, Optional from reasoning_gym.factory import ProceduralDataset, register_dataset +DATASET_NAME = "list_functions" + @dataclass class ListFunctionsDatasetConfig: @@ -75,7 +77,14 @@ Output: Output {index + 1}: {examples[key]} """ question = self.prompt_template.format(examples=formatted_examples, input=input) - return {"question": question, "answer": output, "metadata": {}} + return { + "question": question, + "answer": output, + "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, + }, + } -register_dataset("list_functions", ListFunctionsDataset, ListFunctionsDatasetConfig) +register_dataset(DATASET_NAME, ListFunctionsDataset, ListFunctionsDatasetConfig) diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index 95e034e1..8cb95f64 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -7,6 +7,8 @@ from typing import Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "aiw" + class TaskType(StrEnum): """Defines the type of task for the Alice in Wonderland dataset.""" @@ -134,7 +136,7 @@ class AliceInWonderlandDataset(ProceduralDataset): ], } - def _get_aiw(self, rng: Random) -> dict: + def _get_aiw(self, rng: Random, idx: int) -> dict: """Generates a single Alice in Wonderland question. Args: @@ -194,6 +196,8 @@ class AliceInWonderlandDataset(ProceduralDataset): "question": question, "answer": str(answer), "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "task_type": task_type.value, "difficulty": { "task_type_weight": self.config.task_type_weights, @@ -204,7 +208,7 @@ class AliceInWonderlandDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: rng = Random(self.seed + idx) - return self._get_aiw(rng) + return self._get_aiw(rng, idx) class AliceInWonderlandCurriculum(BaseCurriculum): @@ -238,4 +242,4 @@ class AliceInWonderlandCurriculum(BaseCurriculum): ) -register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig, AliceInWonderlandCurriculum) +register_dataset(DATASET_NAME, AliceInWonderlandDataset, AliceInWonderlandConfig, AliceInWonderlandCurriculum) diff --git a/reasoning_gym/logic/circuit_logic.py b/reasoning_gym/logic/circuit_logic.py index 45729768..a77f1a6e 100644 --- a/reasoning_gym/logic/circuit_logic.py +++ b/reasoning_gym/logic/circuit_logic.py @@ -13,6 +13,8 @@ LDOWN = "┐" RUP = "└" RDOWN = "┌" +DATASET_NAME = "circuit_logic" + def _repeat(s: str, n: int) -> str: return s * n @@ -381,6 +383,8 @@ class CircuitLogicDataset(ProceduralDataset): "question": question_str, "answer": answer_str, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "expression": expression_for_display, "assignments": assignments, "term_strings": term_strings, @@ -429,4 +433,4 @@ class CircuitLogicCurriculum(BaseCurriculum): ) -register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig, CircuitLogicCurriculum) +register_dataset(DATASET_NAME, CircuitLogicDataset, CircuitLogicConfig, CircuitLogicCurriculum) diff --git a/reasoning_gym/logic/knights_knaves.py b/reasoning_gym/logic/knights_knaves.py index 0e96d069..fe4f503f 100644 --- a/reasoning_gym/logic/knights_knaves.py +++ b/reasoning_gym/logic/knights_knaves.py @@ -8,6 +8,8 @@ import numpy as np from reasoning_gym.factory import ProceduralDataset, register_dataset +DATASET_NAME = "knights_knaves" + COMMON_NAMES = [ "Emma", "Liam", @@ -428,9 +430,9 @@ class KnightsKnavesDataset(ProceduralDataset): - metadata: dict (additional problem details) """ rng = Random(self.seed + idx if self.seed is not None else None) - return self.__generate_problem(rng) + return self.__generate_problem(rng, idx) - def __generate_problem(self, rng: Random) -> dict[str, Any]: + def __generate_problem(self, rng: Random, idx: int) -> dict[str, Any]: """ Generate a single knights and knaves problem with a unique solution. """ @@ -454,6 +456,8 @@ class KnightsKnavesDataset(ProceduralDataset): question = formatted["quiz"] answer = formatted["solution_text"] metadata = { + "source_dataset": DATASET_NAME, + "source_index": idx, "statements": problem["statements"], "solution": problem["solution"], "names": formatted["names"], @@ -511,4 +515,4 @@ class KnightsKnavesDataset(ProceduralDataset): return 0.0 -register_dataset("knights_knaves", KnightsKnavesDataset, KnightsKnavesConfig) +register_dataset(DATASET_NAME, KnightsKnavesDataset, KnightsKnavesConfig) diff --git a/reasoning_gym/logic/propositional_logic.py b/reasoning_gym/logic/propositional_logic.py index c67d3eb9..d3c6479c 100644 --- a/reasoning_gym/logic/propositional_logic.py +++ b/reasoning_gym/logic/propositional_logic.py @@ -9,6 +9,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "propositional_logic" + def parse_expr(expr: str): expr = expr.strip() @@ -216,6 +218,8 @@ class PropositionalLogicDataset(ProceduralDataset): "question": question, "answer": None, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "premises": [str(p) for p in premises], "variables": variables, "complexity": self._measure_complexity(conclusion), @@ -367,6 +371,4 @@ class PropositionalLogicCurriculum(BaseCurriculum): ) -register_dataset( - "propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig, PropositionalLogicCurriculum -) +register_dataset(DATASET_NAME, PropositionalLogicDataset, PropositionalLogicConfig, PropositionalLogicCurriculum) diff --git a/reasoning_gym/logic/self_reference.py b/reasoning_gym/logic/self_reference.py index 9568e388..d2c9d68c 100644 --- a/reasoning_gym/logic/self_reference.py +++ b/reasoning_gym/logic/self_reference.py @@ -5,6 +5,8 @@ from typing import Any, Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "self_reference" + def is_prime(n): """Return True if n is a prime number, False otherwise.""" @@ -347,6 +349,8 @@ class SelfReferenceDataset(ProceduralDataset): "question": puzz_s, "answer": answer, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "difficulty": {"difficulty": difficulty}, }, } @@ -383,4 +387,4 @@ class SelfReferenceCurriculum(BaseCurriculum): ) -register_dataset("self_reference", SelfReferenceDataset, SelfReferenceConfig, SelfReferenceCurriculum) +register_dataset(DATASET_NAME, SelfReferenceDataset, SelfReferenceConfig, SelfReferenceCurriculum) diff --git a/reasoning_gym/logic/syllogisms.py b/reasoning_gym/logic/syllogisms.py index 9546ad1d..ec1cff30 100644 --- a/reasoning_gym/logic/syllogisms.py +++ b/reasoning_gym/logic/syllogisms.py @@ -7,6 +7,8 @@ from typing import Optional from ..factory import ProceduralDataset, register_dataset +DATASET_NAME = "syllogism" + class Quantifier(StrEnum): ALL = "All" @@ -277,7 +279,7 @@ class SyllogismDataset(ProceduralDataset): return False - def _generate_syllogism(self, rng: Random) -> dict: + def _generate_syllogism(self, rng: Random, idx: int) -> dict: """Generate a single syllogism problem""" # Select three different terms terms = rng.sample(self.terms, 3) @@ -374,6 +376,8 @@ class SyllogismDataset(ProceduralDataset): "question": question, "answer": "Yes" if is_valid else "No", "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "premise1": premise1_text, "premise2": premise2_text, "selected_premise": selected_premise_num, @@ -437,7 +441,7 @@ class SyllogismDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single syllogism task""" rng = Random(self.seed + idx) - return self._generate_syllogism(rng) + return self._generate_syllogism(rng, idx) -register_dataset("syllogism", SyllogismDataset, SyllogismConfig) +register_dataset(DATASET_NAME, SyllogismDataset, SyllogismConfig) diff --git a/reasoning_gym/logic/zebra_puzzles.py b/reasoning_gym/logic/zebra_puzzles.py index 143f52c7..909b24ce 100644 --- a/reasoning_gym/logic/zebra_puzzles.py +++ b/reasoning_gym/logic/zebra_puzzles.py @@ -6,6 +6,8 @@ from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset from .contrib.logic_puzzle.generate import generate_puzzle +DATASET_NAME = "zebra_puzzles" + @dataclass class ZebraConfig: @@ -51,6 +53,8 @@ class ZebraDataset(ProceduralDataset): "question": question, "answer": answer, "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, "difficulty": {"num_people": K, "num_characteristics": M}, }, } @@ -93,4 +97,4 @@ class ZebraCurriculum(BaseCurriculum): ) -register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig, ZebraCurriculum) +register_dataset(DATASET_NAME, ZebraDataset, ZebraConfig, ZebraCurriculum) diff --git a/tests/test_coaching.py b/tests/test_coaching.py index 01db2bc9..1741e87a 100644 --- a/tests/test_coaching.py +++ b/tests/test_coaching.py @@ -56,12 +56,7 @@ def test_coach_with_chain_sum(): # Each inner tuple should be (param_name, value) or (param_name, (min_value, max_value)) for param in key: assert isinstance(param, tuple) - assert param[0] in ("num_terms", "num_digits") - assert ( - isinstance(param[1], int) - or (isinstance(param[1], tuple) and len(param[1]) == 2) - and all(isinstance(v, int) for v in param[1]) - ) + assert param[0] in ("source", "idx", "num_terms", "num_digits") # Test aggregation with last_n last_3 = coach.score_board.aggregate(last_n=3)