mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
include ranges rather than sampled values in difficulty metadata dicts (#387)
* update difficulty metadata for logic datasets
* update difficulty metadata for graph datasets
* update difficulty metadata for geometry datasets
* update difficulty metadata for games datasets
* update difficulty metadata for cognition datasets
* update difficulty metadata for arithmetic datasets
* update difficulty metadata for arc datasets
* update difficulty metadata for algorithmic datasets
* update difficulty metadata for algebra datasets
* use tuples
* update tests
* update tests
This commit is contained in:
parent
b69c35818a
commit
7475a20700
80 changed files with 304 additions and 126 deletions
|
|
@ -110,8 +110,8 @@ class BaseConversionDataset(ProceduralDataset):
|
|||
"source_repr": source_repr,
|
||||
"target_repr": target_repr,
|
||||
"difficulty": {
|
||||
"value": value,
|
||||
"base": (source_base, target_base),
|
||||
"base": (self.config.min_base, self.config.max_base),
|
||||
"value": (self.config.min_value, self.config.max_value),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,7 +108,10 @@ class BinaryAlternationDataset(ProceduralDataset):
|
|||
"string": string,
|
||||
"solution": answer,
|
||||
"solvable": solvable,
|
||||
"difficulty": {"n": n},
|
||||
"n": n,
|
||||
"difficulty": {
|
||||
"n": (self.config.min_n, self.config.max_n),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -130,8 +130,9 @@ class BinaryMatrixDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"n": n,
|
||||
"difficulty": {
|
||||
"n": n,
|
||||
"n": (self.config.min_n, self.config.max_n),
|
||||
"p_zero": self.config.p_zero,
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -80,9 +80,10 @@ class CaesarCipherDataset(ProceduralDataset):
|
|||
"rotation": rotation,
|
||||
"cipher_text": cipher_text,
|
||||
"clear_text": sentence,
|
||||
"num_words": num_words,
|
||||
"difficulty": {
|
||||
"rotation": rotation,
|
||||
"words": num_words,
|
||||
"words": (self.config.min_words, self.config.max_words),
|
||||
"rotation": (self.config.min_rotation, self.config.max_rotation),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,8 +64,9 @@ class CountPrimesDataset(ProceduralDataset):
|
|||
"end": end,
|
||||
"primes": primes,
|
||||
"solution": answer,
|
||||
"n": (start, end),
|
||||
"difficulty": {
|
||||
"n": (start, end),
|
||||
"n": (self.config.min_n, self.config.max_n),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -187,8 +187,7 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
"digit_to_letter": digit_to_letter,
|
||||
"letter_to_digit": letter_to_digit,
|
||||
"difficulty": {
|
||||
"min_words": self.config.min_words,
|
||||
"max_words": self.config.max_words,
|
||||
"words": (self.config.min_words, self.config.max_words),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -368,6 +368,13 @@ class GameOfLifeHaltingDataset(ProceduralDataset):
|
|||
"placed_patterns": placed_patterns,
|
||||
"simulation_steps": self.config.max_simulation_steps,
|
||||
"should_oscillate": should_oscillate,
|
||||
"difficulty": {
|
||||
"grid_size_x": self.config.grid_size_x,
|
||||
"grid_size_y": self.config.grid_size_y,
|
||||
"difficulty": self.config.difficulty,
|
||||
"num_oscillators": self.config.num_oscillators,
|
||||
"max_simulation_steps": self.config.max_simulation_steps,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -215,7 +215,11 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
|
|||
"metadata": {
|
||||
"possible_answer": solution,
|
||||
"puzzle": puzzle,
|
||||
"difficulty": {"num_vertices": num_vertices, "num_colors": num_colors},
|
||||
"num_vertices": num_vertices,
|
||||
"difficulty": {
|
||||
"num_vertices": (self.config.min_num_vertices, self.config.max_num_vertices),
|
||||
"num_colors": num_colors,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -117,8 +117,10 @@ class GroupAnagramsDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"words": words,
|
||||
"solution": answer,
|
||||
"anagram_groups": anagram_groups,
|
||||
"difficulty": {
|
||||
"anagram_groups": anagram_groups,
|
||||
"anagram_groups": (self.config.min_anagram_groups, self.config.max_anagram_groups),
|
||||
"words_per_group": (self.config.min_words_per_group, self.config.max_words_per_group),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -110,8 +110,9 @@ class IsomorphicStringsDataset(ProceduralDataset):
|
|||
"words": [s, t],
|
||||
"solution": answer,
|
||||
"solvable": solvable,
|
||||
"string_length": string_length,
|
||||
"difficulty": {
|
||||
"string_length": string_length,
|
||||
"string_length": (self.config.min_string_length, self.config.max_string_length),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,7 +67,9 @@ class LetterCountingDataset(ProceduralDataset):
|
|||
"span_length": span_length,
|
||||
"target_letter": target_letter,
|
||||
"span": span,
|
||||
"difficulty": {"words": span_length},
|
||||
"difficulty": {
|
||||
"words": (self.config.min_words, self.config.max_words),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -110,8 +110,8 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
"original_words": selected_words,
|
||||
"difficulty": {
|
||||
"word_len": (self.config.min_word_len, self.config.max_word_len),
|
||||
"words": num_words,
|
||||
"corruption_level": corruption_level,
|
||||
"words": (self.config.min_words, self.config.max_words),
|
||||
"corruption_level": (self.config.min_corruption_level, self.config.max_corruption_level),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -309,10 +309,13 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"operations": operations,
|
||||
"rows": rows,
|
||||
"cols": cols,
|
||||
"num_transforms": num_transforms,
|
||||
"difficulty": {
|
||||
"rows": rows,
|
||||
"cols": cols,
|
||||
"num_transforms": num_transforms,
|
||||
"rows": (self.config.min_rows, self.config.max_rows),
|
||||
"cols": (self.config.min_cols, self.config.max_cols),
|
||||
"num_transforms": (self.config.min_transforms, self.config.max_transforms),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -95,8 +95,9 @@ class NumberFilteringDataset(ProceduralDataset):
|
|||
"filter_value": filter_str,
|
||||
"operation": f"{keep_remove}_{larger_smaller}",
|
||||
"result": result_strs,
|
||||
"numbers": len(numbers),
|
||||
"difficulty": {
|
||||
"numbers": len(numbers),
|
||||
"numbers": (self.config.min_numbers, self.config.max_numbers),
|
||||
"decimals": (self.config.min_decimals, self.config.max_decimals),
|
||||
"value": (self.config.min_value, self.config.max_value),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -93,8 +93,9 @@ Please follow the instruction below:
|
|||
"original_numbers": number_strs,
|
||||
"direction": direction,
|
||||
"sorted_numbers": answer,
|
||||
"numbers": count,
|
||||
"difficulty": {
|
||||
"numbers": count,
|
||||
"numbers": (self.config.min_numbers, self.config.max_numbers),
|
||||
"decimals": (self.config.min_decimals, self.config.max_decimals),
|
||||
"value": (self.config.min_value, self.config.max_value),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -69,8 +69,9 @@ class PalindromeDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"letters": scrambled_letters,
|
||||
"generated_palindrome": palindrome,
|
||||
"length": length,
|
||||
"difficulty": {
|
||||
"length": length,
|
||||
"length": (self.config.min_length, self.config.max_length),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -140,8 +140,13 @@ class PalindromePartitioningDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"string": string,
|
||||
"solution": answer,
|
||||
"string_len": string_len,
|
||||
"difficulty": {
|
||||
"string_len": string_len,
|
||||
"string_len": (self.config.min_string_len, self.config.max_string_len),
|
||||
"substring_palindrome_len": (
|
||||
self.config.min_substring_palindrome_len,
|
||||
self.config.max_substring_palindrome_len,
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -117,10 +117,13 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
"pool_type": pool_type,
|
||||
"pool_size": pool_size,
|
||||
"solution": answer.tolist(),
|
||||
"rows": rows,
|
||||
"cols": cols,
|
||||
"pool_size": pool_size,
|
||||
"difficulty": {
|
||||
"rows": rows,
|
||||
"cols": cols,
|
||||
"pool_size": pool_size,
|
||||
"rows": (self.config.min_rows, self.config.max_rows),
|
||||
"cols": (self.config.min_cols, self.config.max_cols),
|
||||
"pool_size": (self.config.min_pool_size, self.config.max_pool_size),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,9 +103,11 @@ class RansomNoteDataset(ProceduralDataset):
|
|||
"magazine": magazine,
|
||||
"solution": answer,
|
||||
"solvable": solvable,
|
||||
"note_length": note_length,
|
||||
"magazine_length": magazine_length,
|
||||
"difficulty": {
|
||||
"note_length": note_length,
|
||||
"magazine_length": magazine_length,
|
||||
"note_length": (self.config.min_note_length, self.config.max_note_length),
|
||||
"magazine_length": (self.config.min_magazine_length, self.config.max_magazine_length),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,9 +86,10 @@ class RotateMatrixDataset(ProceduralDataset):
|
|||
"matrix": matrix,
|
||||
"num_rotations": num_rotations,
|
||||
"solution": answer,
|
||||
"n": n,
|
||||
"difficulty": {
|
||||
"n": n,
|
||||
"num_rotations": num_rotations,
|
||||
"n": (self.config.min_n, self.config.max_n),
|
||||
"num_rotations": (self.config.min_rotations, self.config.max_rotations),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -122,7 +122,10 @@ class RottenOrangesDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"difficulty": {"n": n},
|
||||
"n": n,
|
||||
"difficulty": {
|
||||
"n": (self.config.min_n, self.config.max_n),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,12 @@ class SentenceReorderingDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": f"Restore the correct order of words in the following sentence: {question}",
|
||||
"answer": solved_sentence,
|
||||
"metadata": {"word_count": word_count, "difficulty": {"words_in_sentence": word_count}},
|
||||
"metadata": {
|
||||
"word_count": word_count,
|
||||
"difficulty": {
|
||||
"words_in_sentence": (self.config.min_words_in_sentence, self.config.max_words_in_sentence),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
|
|
|
|||
|
|
@ -54,7 +54,9 @@ class SpellBackwardDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"word": word,
|
||||
"word_len": len(word),
|
||||
"difficulty": {"word_len": (self.config.min_word_len, self.config.max_word_len)},
|
||||
"difficulty": {
|
||||
"word_len": (self.config.min_word_len, self.config.max_word_len),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class SpiralMatrixDataset(ProceduralDataset):
|
|||
"""Generate a single Spiral Matrix question"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
n = rng.randint(2, self.config.max_n)
|
||||
n = rng.randint(self.config.min_n, self.config.max_n)
|
||||
matrix = self._get_matrix(rng, n)
|
||||
matrix_str = self._matrix_to_str(matrix)
|
||||
answer = self._get_spiral(matrix)
|
||||
|
|
@ -113,7 +113,10 @@ class SpiralMatrixDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"difficulty": {"n": n},
|
||||
"n": n,
|
||||
"difficulty": {
|
||||
"n": (self.config.min_n, self.config.max_n),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -104,8 +104,9 @@ class StringInsertionDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"string": string,
|
||||
"solution": answer,
|
||||
"string_length": string_length,
|
||||
"difficulty": {
|
||||
"string_length": string_length,
|
||||
"string_length": (self.config.min_string_length, self.config.max_string_length),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -183,9 +183,11 @@ class StringManipulationDataset(ProceduralDataset):
|
|||
"solution": answer,
|
||||
"states": states,
|
||||
"selected_rules": [rule for rule, _ in selected_rules],
|
||||
"string_length": string_length,
|
||||
"num_rules": num_rules,
|
||||
"difficulty": {
|
||||
"string_length": string_length,
|
||||
"num_rules": num_rules,
|
||||
"string_length": (self.config.min_string_length, self.config.max_string_length),
|
||||
"num_rules": (self.config.min_num_rules, self.config.max_num_rules),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -127,8 +127,9 @@ class StringSplittingDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"states": states,
|
||||
"solution": answer,
|
||||
"initial_machines": (A_machine, B_machine, C_machine),
|
||||
"difficulty": {
|
||||
"initial_machines": (A_machine, B_machine, C_machine),
|
||||
"initial_machines": (self.config.min_initial_machines, self.config.max_initial_machines),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -132,8 +132,9 @@ class StringSynthesisDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"states": states,
|
||||
"solution": answer,
|
||||
"initial_blocks": (A_square, B_square, C_square),
|
||||
"difficulty": {
|
||||
"initial_blocks": (A_square, B_square, C_square),
|
||||
"initial_blocks": (self.config.min_initial_blocks, self.config.max_initial_blocks),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -224,8 +224,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
"word_length": length,
|
||||
"chain_length": len(path),
|
||||
"difficulty": {
|
||||
"word_length": length,
|
||||
"chain_length": len(path),
|
||||
"word_length": (self.config.min_word_length, self.config.max_word_length),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,13 @@ class WordSequenceReversalDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": f"{QUESTION_TEMPLATE.format(words=words_str)}",
|
||||
"answer": answer,
|
||||
"metadata": {"num_words": num_words, "words": words, "difficulty": {"words": num_words}},
|
||||
"metadata": {
|
||||
"num_words": num_words,
|
||||
"words": words,
|
||||
"difficulty": {
|
||||
"words": (self.config.min_words, self.config.max_words),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -106,14 +106,16 @@ class WordSortingDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
|
||||
"answer": ", ".join(answer),
|
||||
"metadata": {
|
||||
"difficulty": {
|
||||
"num_words": len(original_words),
|
||||
"word_length": max(len(word) for word in original_words),
|
||||
},
|
||||
"original_words": original_words,
|
||||
"sorted_words": answer,
|
||||
"transformed_words": transformed_words,
|
||||
"direction": direction,
|
||||
"num_words": len(original_words),
|
||||
"word_length": max(len(word) for word in original_words),
|
||||
"difficulty": {
|
||||
"num_words": (self.config.min_words, self.config.max_words),
|
||||
"word_length": (self.config.min_word_length, self.config.max_word_length),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue