include ranges rather than sampled values in difficulty metadata dicts (#387)

* update difficulty metadata for logic datasets

* update difficulty metadata for graph datasets

* update difficulty metadata for geometry datasets

* update difficulty metadata for games datasets

* update difficulty metadata for cognition datasets

* update difficulty metadata for arithmetic datasets

* update difficulty metadata for arc datasets

* update difficulty metadata for algorithmic datasets

* update difficulty metadata for algebra datasets

* use tuples

* update tests

* update tests
This commit is contained in:
Oliver Stanley 2025-03-20 09:27:03 +00:00 committed by GitHub
parent b69c35818a
commit 7475a20700
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 304 additions and 126 deletions

View file

@ -124,7 +124,11 @@ In solving equations, please follow these instructions:
"variable": variable, "variable": variable,
"degree": degree, "degree": degree,
"real_solutions": real_solutions, "real_solutions": real_solutions,
"difficulty": {"terms": num_terms, "degree": degree}, "num_terms": num_terms,
"difficulty": {
"terms": (self.config.min_terms, self.config.max_terms),
"degree": (self.config.min_degree, self.config.max_degree),
},
}, },
} }

View file

@ -85,7 +85,10 @@ When performing calculations, please follow these guidelines:
"integrand": str(derivative), "integrand": str(derivative),
"variable": str(symbol), "variable": str(symbol),
"expected_answer_expression": polynomial, "expected_answer_expression": polynomial,
"difficulty": {"terms": num_terms}, "num_terms": num_terms,
"difficulty": {
"terms": (self.config.min_terms, self.config.max_terms),
},
}, },
} }

View file

@ -110,8 +110,8 @@ class BaseConversionDataset(ProceduralDataset):
"source_repr": source_repr, "source_repr": source_repr,
"target_repr": target_repr, "target_repr": target_repr,
"difficulty": { "difficulty": {
"value": value, "base": (self.config.min_base, self.config.max_base),
"base": (source_base, target_base), "value": (self.config.min_value, self.config.max_value),
}, },
}, },
} }

View file

@ -108,7 +108,10 @@ class BinaryAlternationDataset(ProceduralDataset):
"string": string, "string": string,
"solution": answer, "solution": answer,
"solvable": solvable, "solvable": solvable,
"difficulty": {"n": n}, "n": n,
"difficulty": {
"n": (self.config.min_n, self.config.max_n),
},
}, },
} }

View file

@ -130,8 +130,9 @@ class BinaryMatrixDataset(ProceduralDataset):
"metadata": { "metadata": {
"matrix": matrix, "matrix": matrix,
"solution": answer, "solution": answer,
"n": n,
"difficulty": { "difficulty": {
"n": n, "n": (self.config.min_n, self.config.max_n),
"p_zero": self.config.p_zero, "p_zero": self.config.p_zero,
}, },
}, },

View file

@ -80,9 +80,10 @@ class CaesarCipherDataset(ProceduralDataset):
"rotation": rotation, "rotation": rotation,
"cipher_text": cipher_text, "cipher_text": cipher_text,
"clear_text": sentence, "clear_text": sentence,
"num_words": num_words,
"difficulty": { "difficulty": {
"rotation": rotation, "words": (self.config.min_words, self.config.max_words),
"words": num_words, "rotation": (self.config.min_rotation, self.config.max_rotation),
}, },
}, },
} }

View file

@ -64,8 +64,9 @@ class CountPrimesDataset(ProceduralDataset):
"end": end, "end": end,
"primes": primes, "primes": primes,
"solution": answer, "solution": answer,
"n": (start, end),
"difficulty": { "difficulty": {
"n": (start, end), "n": (self.config.min_n, self.config.max_n),
}, },
}, },
} }

View file

@ -187,8 +187,7 @@ class CryptarithmDataset(ProceduralDataset):
"digit_to_letter": digit_to_letter, "digit_to_letter": digit_to_letter,
"letter_to_digit": letter_to_digit, "letter_to_digit": letter_to_digit,
"difficulty": { "difficulty": {
"min_words": self.config.min_words, "words": (self.config.min_words, self.config.max_words),
"max_words": self.config.max_words,
}, },
}, },
} }

View file

@ -368,6 +368,13 @@ class GameOfLifeHaltingDataset(ProceduralDataset):
"placed_patterns": placed_patterns, "placed_patterns": placed_patterns,
"simulation_steps": self.config.max_simulation_steps, "simulation_steps": self.config.max_simulation_steps,
"should_oscillate": should_oscillate, "should_oscillate": should_oscillate,
"difficulty": {
"grid_size_x": self.config.grid_size_x,
"grid_size_y": self.config.grid_size_y,
"difficulty": self.config.difficulty,
"num_oscillators": self.config.num_oscillators,
"max_simulation_steps": self.config.max_simulation_steps,
},
}, },
} }

View file

@ -215,7 +215,11 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
"metadata": { "metadata": {
"possible_answer": solution, "possible_answer": solution,
"puzzle": puzzle, "puzzle": puzzle,
"difficulty": {"num_vertices": num_vertices, "num_colors": num_colors}, "num_vertices": num_vertices,
"difficulty": {
"num_vertices": (self.config.min_num_vertices, self.config.max_num_vertices),
"num_colors": num_colors,
},
}, },
} }

View file

@ -117,8 +117,10 @@ class GroupAnagramsDataset(ProceduralDataset):
"metadata": { "metadata": {
"words": words, "words": words,
"solution": answer, "solution": answer,
"anagram_groups": anagram_groups,
"difficulty": { "difficulty": {
"anagram_groups": anagram_groups, "anagram_groups": (self.config.min_anagram_groups, self.config.max_anagram_groups),
"words_per_group": (self.config.min_words_per_group, self.config.max_words_per_group),
}, },
}, },
} }

View file

@ -110,8 +110,9 @@ class IsomorphicStringsDataset(ProceduralDataset):
"words": [s, t], "words": [s, t],
"solution": answer, "solution": answer,
"solvable": solvable, "solvable": solvable,
"string_length": string_length,
"difficulty": { "difficulty": {
"string_length": string_length, "string_length": (self.config.min_string_length, self.config.max_string_length),
}, },
}, },
} }

View file

@ -67,7 +67,9 @@ class LetterCountingDataset(ProceduralDataset):
"span_length": span_length, "span_length": span_length,
"target_letter": target_letter, "target_letter": target_letter,
"span": span, "span": span,
"difficulty": {"words": span_length}, "difficulty": {
"words": (self.config.min_words, self.config.max_words),
},
}, },
} }

View file

@ -110,8 +110,8 @@ class LetterJumbleDataset(ProceduralDataset):
"original_words": selected_words, "original_words": selected_words,
"difficulty": { "difficulty": {
"word_len": (self.config.min_word_len, self.config.max_word_len), "word_len": (self.config.min_word_len, self.config.max_word_len),
"words": num_words, "words": (self.config.min_words, self.config.max_words),
"corruption_level": corruption_level, "corruption_level": (self.config.min_corruption_level, self.config.max_corruption_level),
}, },
}, },
} }

View file

@ -309,10 +309,13 @@ class ManipulateMatrixDataset(ProceduralDataset):
"matrix": matrix, "matrix": matrix,
"solution": answer, "solution": answer,
"operations": operations, "operations": operations,
"rows": rows,
"cols": cols,
"num_transforms": num_transforms,
"difficulty": { "difficulty": {
"rows": rows, "rows": (self.config.min_rows, self.config.max_rows),
"cols": cols, "cols": (self.config.min_cols, self.config.max_cols),
"num_transforms": num_transforms, "num_transforms": (self.config.min_transforms, self.config.max_transforms),
}, },
}, },
} }

View file

@ -95,8 +95,9 @@ class NumberFilteringDataset(ProceduralDataset):
"filter_value": filter_str, "filter_value": filter_str,
"operation": f"{keep_remove}_{larger_smaller}", "operation": f"{keep_remove}_{larger_smaller}",
"result": result_strs, "result": result_strs,
"numbers": len(numbers),
"difficulty": { "difficulty": {
"numbers": len(numbers), "numbers": (self.config.min_numbers, self.config.max_numbers),
"decimals": (self.config.min_decimals, self.config.max_decimals), "decimals": (self.config.min_decimals, self.config.max_decimals),
"value": (self.config.min_value, self.config.max_value), "value": (self.config.min_value, self.config.max_value),
}, },

View file

@ -93,8 +93,9 @@ Please follow the instruction below:
"original_numbers": number_strs, "original_numbers": number_strs,
"direction": direction, "direction": direction,
"sorted_numbers": answer, "sorted_numbers": answer,
"numbers": count,
"difficulty": { "difficulty": {
"numbers": count, "numbers": (self.config.min_numbers, self.config.max_numbers),
"decimals": (self.config.min_decimals, self.config.max_decimals), "decimals": (self.config.min_decimals, self.config.max_decimals),
"value": (self.config.min_value, self.config.max_value), "value": (self.config.min_value, self.config.max_value),
}, },

View file

@ -69,8 +69,9 @@ class PalindromeDataset(ProceduralDataset):
"metadata": { "metadata": {
"letters": scrambled_letters, "letters": scrambled_letters,
"generated_palindrome": palindrome, "generated_palindrome": palindrome,
"length": length,
"difficulty": { "difficulty": {
"length": length, "length": (self.config.min_length, self.config.max_length),
}, },
}, },
} }

View file

@ -140,8 +140,13 @@ class PalindromePartitioningDataset(ProceduralDataset):
"metadata": { "metadata": {
"string": string, "string": string,
"solution": answer, "solution": answer,
"string_len": string_len,
"difficulty": { "difficulty": {
"string_len": string_len, "string_len": (self.config.min_string_len, self.config.max_string_len),
"substring_palindrome_len": (
self.config.min_substring_palindrome_len,
self.config.max_substring_palindrome_len,
),
}, },
}, },
} }

View file

@ -117,10 +117,13 @@ class PoolMatrixDataset(ProceduralDataset):
"pool_type": pool_type, "pool_type": pool_type,
"pool_size": pool_size, "pool_size": pool_size,
"solution": answer.tolist(), "solution": answer.tolist(),
"rows": rows,
"cols": cols,
"pool_size": pool_size,
"difficulty": { "difficulty": {
"rows": rows, "rows": (self.config.min_rows, self.config.max_rows),
"cols": cols, "cols": (self.config.min_cols, self.config.max_cols),
"pool_size": pool_size, "pool_size": (self.config.min_pool_size, self.config.max_pool_size),
}, },
}, },
} }

View file

@ -103,9 +103,11 @@ class RansomNoteDataset(ProceduralDataset):
"magazine": magazine, "magazine": magazine,
"solution": answer, "solution": answer,
"solvable": solvable, "solvable": solvable,
"note_length": note_length,
"magazine_length": magazine_length,
"difficulty": { "difficulty": {
"note_length": note_length, "note_length": (self.config.min_note_length, self.config.max_note_length),
"magazine_length": magazine_length, "magazine_length": (self.config.min_magazine_length, self.config.max_magazine_length),
}, },
}, },
} }

View file

@ -86,9 +86,10 @@ class RotateMatrixDataset(ProceduralDataset):
"matrix": matrix, "matrix": matrix,
"num_rotations": num_rotations, "num_rotations": num_rotations,
"solution": answer, "solution": answer,
"n": n,
"difficulty": { "difficulty": {
"n": n, "n": (self.config.min_n, self.config.max_n),
"num_rotations": num_rotations, "num_rotations": (self.config.min_rotations, self.config.max_rotations),
}, },
}, },
} }

View file

@ -122,7 +122,10 @@ class RottenOrangesDataset(ProceduralDataset):
"metadata": { "metadata": {
"matrix": matrix, "matrix": matrix,
"solution": answer, "solution": answer,
"difficulty": {"n": n}, "n": n,
"difficulty": {
"n": (self.config.min_n, self.config.max_n),
},
}, },
} }

View file

@ -90,7 +90,12 @@ class SentenceReorderingDataset(ProceduralDataset):
return { return {
"question": f"Restore the correct order of words in the following sentence: {question}", "question": f"Restore the correct order of words in the following sentence: {question}",
"answer": solved_sentence, "answer": solved_sentence,
"metadata": {"word_count": word_count, "difficulty": {"words_in_sentence": word_count}}, "metadata": {
"word_count": word_count,
"difficulty": {
"words_in_sentence": (self.config.min_words_in_sentence, self.config.max_words_in_sentence),
},
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -54,7 +54,9 @@ class SpellBackwardDataset(ProceduralDataset):
"metadata": { "metadata": {
"word": word, "word": word,
"word_len": len(word), "word_len": len(word),
"difficulty": {"word_len": (self.config.min_word_len, self.config.max_word_len)}, "difficulty": {
"word_len": (self.config.min_word_len, self.config.max_word_len),
},
}, },
} }

View file

@ -101,7 +101,7 @@ class SpiralMatrixDataset(ProceduralDataset):
"""Generate a single Spiral Matrix question""" """Generate a single Spiral Matrix question"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
n = rng.randint(2, self.config.max_n) n = rng.randint(self.config.min_n, self.config.max_n)
matrix = self._get_matrix(rng, n) matrix = self._get_matrix(rng, n)
matrix_str = self._matrix_to_str(matrix) matrix_str = self._matrix_to_str(matrix)
answer = self._get_spiral(matrix) answer = self._get_spiral(matrix)
@ -113,7 +113,10 @@ class SpiralMatrixDataset(ProceduralDataset):
"metadata": { "metadata": {
"matrix": matrix, "matrix": matrix,
"solution": answer, "solution": answer,
"difficulty": {"n": n}, "n": n,
"difficulty": {
"n": (self.config.min_n, self.config.max_n),
},
}, },
} }

View file

@ -104,8 +104,9 @@ class StringInsertionDataset(ProceduralDataset):
"metadata": { "metadata": {
"string": string, "string": string,
"solution": answer, "solution": answer,
"string_length": string_length,
"difficulty": { "difficulty": {
"string_length": string_length, "string_length": (self.config.min_string_length, self.config.max_string_length),
}, },
}, },
} }

View file

@ -183,9 +183,11 @@ class StringManipulationDataset(ProceduralDataset):
"solution": answer, "solution": answer,
"states": states, "states": states,
"selected_rules": [rule for rule, _ in selected_rules], "selected_rules": [rule for rule, _ in selected_rules],
"string_length": string_length,
"num_rules": num_rules,
"difficulty": { "difficulty": {
"string_length": string_length, "string_length": (self.config.min_string_length, self.config.max_string_length),
"num_rules": num_rules, "num_rules": (self.config.min_num_rules, self.config.max_num_rules),
}, },
}, },
} }

View file

@ -127,8 +127,9 @@ class StringSplittingDataset(ProceduralDataset):
"metadata": { "metadata": {
"states": states, "states": states,
"solution": answer, "solution": answer,
"initial_machines": (A_machine, B_machine, C_machine),
"difficulty": { "difficulty": {
"initial_machines": (A_machine, B_machine, C_machine), "initial_machines": (self.config.min_initial_machines, self.config.max_initial_machines),
}, },
}, },
} }

View file

@ -132,8 +132,9 @@ class StringSynthesisDataset(ProceduralDataset):
"metadata": { "metadata": {
"states": states, "states": states,
"solution": answer, "solution": answer,
"initial_blocks": (A_square, B_square, C_square),
"difficulty": { "difficulty": {
"initial_blocks": (A_square, B_square, C_square), "initial_blocks": (self.config.min_initial_blocks, self.config.max_initial_blocks),
}, },
}, },
} }

View file

@ -224,8 +224,7 @@ class WordLadderDataset(ProceduralDataset):
"word_length": length, "word_length": length,
"chain_length": len(path), "chain_length": len(path),
"difficulty": { "difficulty": {
"word_length": length, "word_length": (self.config.min_word_length, self.config.max_word_length),
"chain_length": len(path),
}, },
}, },
} }

View file

@ -62,7 +62,13 @@ class WordSequenceReversalDataset(ProceduralDataset):
return { return {
"question": f"{QUESTION_TEMPLATE.format(words=words_str)}", "question": f"{QUESTION_TEMPLATE.format(words=words_str)}",
"answer": answer, "answer": answer,
"metadata": {"num_words": num_words, "words": words, "difficulty": {"words": num_words}}, "metadata": {
"num_words": num_words,
"words": words,
"difficulty": {
"words": (self.config.min_words, self.config.max_words),
},
},
} }

View file

@ -106,14 +106,16 @@ class WordSortingDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)), "question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
"answer": ", ".join(answer), "answer": ", ".join(answer),
"metadata": { "metadata": {
"difficulty": {
"num_words": len(original_words),
"word_length": max(len(word) for word in original_words),
},
"original_words": original_words, "original_words": original_words,
"sorted_words": answer, "sorted_words": answer,
"transformed_words": transformed_words, "transformed_words": transformed_words,
"direction": direction, "direction": direction,
"num_words": len(original_words),
"word_length": max(len(word) for word in original_words),
"difficulty": {
"num_words": (self.config.min_words, self.config.max_words),
"word_length": (self.config.min_word_length, self.config.max_word_length),
},
}, },
} }

View file

@ -117,9 +117,11 @@ class ReArcDataset(ProceduralDataset):
"input": task["input"], "input": task["input"],
"output": task["output"], "output": task["output"],
"task_id": task_id, "task_id": task_id,
"rng": rng_difficulty,
"pso": pso_difficulty,
"difficulty": { "difficulty": {
"rng": rng_difficulty, "rng_difficulty": self.config.rng_difficulty_weights,
"pso": pso_difficulty, "pso_difficulty": self.config.pso_difficulty_weights,
}, },
}, },
} }

View file

@ -96,7 +96,12 @@ class BasicArithmeticDataset(ProceduralDataset):
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {
"expression": expression, "expression": expression,
"difficulty": {"num_terms": num_terms, "num_digits": num_digits}, "num_terms": num_terms,
"num_digits": num_digits,
"difficulty": {
"num_terms": (self.config.min_terms, self.config.max_terms),
"num_digits": (self.config.min_digits, self.config.max_digits),
},
}, },
} }

View file

@ -64,11 +64,13 @@ class ChainSumDataset(ProceduralDataset):
"question": f"State the final answer to the following arithmetic problem: {expression} =", "question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {
"difficulty": { "num_terms": num_terms,
"num_terms": num_terms, "num_digits": num_digits,
"num_digits": num_digits,
},
"expression": expression, "expression": expression,
"difficulty": {
"num_terms": (self.config.min_terms, self.config.max_terms),
"num_digits": (self.config.min_digits, self.config.max_digits),
},
}, },
} }

View file

@ -46,7 +46,10 @@ class CountBitsDataset(ProceduralDataset):
"number": number, "number": number,
"solution": answer, "solution": answer,
"binary": binary, "binary": binary,
"difficulty": {"n": number}, "n": number,
"difficulty": {
"n": (self.config.min_n, self.config.max_n),
},
}, },
} }

View file

@ -189,9 +189,11 @@ class DecimalArithmeticDataset(ProceduralDataset):
"question": problem_str, "question": problem_str,
"answer": str(answer), "answer": str(answer),
"metadata": { "metadata": {
"decimal_places": decimal_places,
"num_terms": terms,
"difficulty": { "difficulty": {
"decimal_places": decimal_places, "decimal_places": (self.config.min_num_decimal_places, self.config.max_num_decimal_places),
"num_terms": terms, "num_terms": (self.config.min_terms, self.config.max_terms),
}, },
}, },
} }

View file

@ -66,11 +66,14 @@ class DecimalChainSumDataset(ProceduralDataset):
"question": f"State the final answer to the following arithmetic problem: {expression} =", "question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {
"difficulty": { "num_terms": num_terms,
"num_terms": num_terms, "num_digits": num_digits,
"num_digits": num_digits,
},
"expression": expression, "expression": expression,
"difficulty": {
"num_terms": (self.config.min_terms, self.config.max_terms),
"num_digits": (self.config.min_digits, self.config.max_digits),
"decimal_places": (self.config.min_decimal_places, self.config.max_decimal_places),
},
}, },
} }

View file

@ -124,11 +124,11 @@ class DiceDataset(ProceduralDataset):
"question": puzzle_str, "question": puzzle_str,
"answer": answer_str, "answer": answer_str,
"metadata": { "metadata": {
"puzzle": puzzle,
"difficulty": { "difficulty": {
"num_dice": self.config.num_dice, "num_dice": self.config.num_dice,
"max_dice_size": self.config.max_dice_size, "max_dice_size": self.config.max_dice_size,
}, },
"puzzle": puzzle,
}, },
} }

View file

@ -120,9 +120,10 @@ class FractionSimplificationDataset(ProceduralDataset):
"simplified_denominator": simple_den, "simplified_denominator": simple_den,
"reduction_factor": num // simple_num, # Will be same as den // simple_den "reduction_factor": num // simple_num, # Will be same as den // simple_den
"style": style, "style": style,
"factor": factor,
"difficulty": { "difficulty": {
"factor": factor, "value": (self.config.min_value, self.config.max_value),
"value": (simple_num, simple_den), "factor": (self.config.min_factor, self.config.max_factor),
}, },
}, },
} }

View file

@ -64,9 +64,10 @@ class GCDDataset(ProceduralDataset):
"metadata": { "metadata": {
"numbers": numbers, "numbers": numbers,
"result": result, "result": result,
"num_terms": num_terms,
"difficulty": { "difficulty": {
"num_terms": num_terms, "num_terms": (self.config.min_numbers, self.config.max_numbers),
"max_value": self.config.max_value, "max_value": (self.config.min_value, self.config.max_value),
}, },
}, },
} }

View file

@ -67,7 +67,7 @@ class LCMDataset(ProceduralDataset):
"numbers": numbers, "numbers": numbers,
"result": result, "result": result,
"difficulty": { "difficulty": {
"numbers": len(numbers), "numbers": (self.config.min_numbers, self.config.max_numbers),
"value": (self.config.min_value, self.config.max_value), "value": (self.config.min_value, self.config.max_value),
}, },
}, },

View file

@ -118,11 +118,13 @@ class LegCountingDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)), "question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
"answer": str(total_legs), "answer": str(total_legs),
"metadata": { "metadata": {
"difficulty": {
"num_animals": len(animals),
},
"animals": animals, "animals": animals,
"num_animals": len(animals),
"total_legs": total_legs, "total_legs": total_legs,
"difficulty": {
"num_animals": (self.config.min_animals, self.config.max_animals),
"num_instances": (self.config.min_instances, self.config.max_instances),
},
}, },
} }

View file

@ -98,8 +98,9 @@ class NumberFormatDataset(ProceduralDataset):
"solution": answer, "solution": answer,
"formatted_candidates": formatted_candidates, "formatted_candidates": formatted_candidates,
"size": size, "size": size,
"num_candidates": num_candidates,
"difficulty": { "difficulty": {
"num_candidates": num_candidates, "num_candidates": (self.config.min_num_candidates, self.config.max_num_candidates),
"n": (self.config.min_n, self.config.max_n), "n": (self.config.min_n, self.config.max_n),
"min_delta": self.config.max_delta, "min_delta": self.config.max_delta,
}, },

View file

@ -73,7 +73,14 @@ class PowerFunctionDataset(ProceduralDataset):
return { return {
"question": QUESTION_TEMPLATE.format(base=base, exponent=exponent), "question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
"answer": str(answer), "answer": str(answer),
"metadata": {"base": base, "exponent": exponent, "solution": answer, "difficulty": {"exponent": exponent}}, "metadata": {
"base": base,
"exponent": exponent,
"solution": answer,
"difficulty": {
"exponent": (self.config.min_exponent, self.config.max_exponent),
},
},
} }

View file

@ -83,7 +83,13 @@ class PrimeFactorizationDataset(ProceduralDataset):
f"(Example: for 12 the answer would be: 2 × 2 × 3)" f"(Example: for 12 the answer would be: 2 × 2 × 3)"
), ),
"answer": answer, "answer": answer,
"metadata": {"number": number, "factors": factors, "difficulty": {"value": number}}, "metadata": {
"number": number,
"factors": factors,
"difficulty": {
"value": (self.config.min_value, self.config.max_value),
},
},
} }

View file

@ -66,11 +66,13 @@ class ProductsDataset(ProceduralDataset):
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.", "question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {
"difficulty": {
"num_terms": num_terms,
"num_digits": num_digits,
},
"expression": expression, "expression": expression,
"num_terms": num_terms,
"num_digits": num_digits,
"difficulty": {
"num_terms": (self.config.min_terms, self.config.max_terms),
"num_digits": (self.config.min_digits, self.config.max_digits),
},
}, },
} }

View file

@ -141,7 +141,9 @@ class ColorCubeRotationDataset(ProceduralDataset):
"rotations": [r.value for r in rotations], "rotations": [r.value for r in rotations],
"target_side": target_side.value, "target_side": target_side.value,
"num_rotations": num_rotations, "num_rotations": num_rotations,
"difficulty": {"rotations": num_rotations}, "difficulty": {
"rotations": (self.config.min_rotations, self.config.max_rotations),
},
}, },
} }

View file

@ -188,7 +188,9 @@ class FigletFontDataset(ProceduralDataset):
"metadata": { "metadata": {
"font": chosen_font, "font": chosen_font,
"space_letters": self.config.space_letters, "space_letters": self.config.space_letters,
"difficulty": {"word_len": len(word)}, "difficulty": {
"word_len": (self.config.min_word_len, self.config.max_word_len),
},
}, },
} }

View file

@ -140,9 +140,11 @@ class ModuloGridDataset(ProceduralDataset):
"target": target, "target": target,
"operation": operation, "operation": operation,
"difficulty": { "difficulty": {
"holes": self.config.max_holes,
"size_x": self.config.size_x, "size_x": self.config.size_x,
"size_y": self.config.size_y, "size_y": self.config.size_y,
"holes": self.config.max_holes,
"divisor": self.config.max_divisor,
"target": self.config.max_target,
}, },
}, },
} }

View file

@ -103,7 +103,13 @@ class NeedleHaystackDataset(ProceduralDataset):
return { return {
"question": full_text, "question": full_text,
"answer": stack["needle"][0], "answer": stack["needle"][0],
"metadata": {"question": question, "difficulty": {"num_statements": num_statements}}, "metadata": {
"question": question,
"num_statements": num_statements,
"difficulty": {
"num_statements": (self.config.min_num_statements, self.config.max_num_statements),
},
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -195,7 +195,14 @@ class NumberSequenceDataset(ProceduralDataset):
return { return {
"question": ", ".join(map(str, visible_terms)) + ", ?", "question": ", ".join(map(str, visible_terms)) + ", ?",
"answer": str(sequence[-1]), "answer": str(sequence[-1]),
"metadata": {"rule": rule.to_string(), "complexity": complexity, "sequence": sequence}, "metadata": {
"rule": rule.to_string(),
"complexity": complexity,
"sequence": sequence,
"difficulty": {
"max_complexity": self.config.max_complexity,
},
},
} }

View file

@ -117,7 +117,14 @@ class RectangleCountDataset(ProceduralDataset):
return { return {
"question": QUESTION_TEMPLATE.format(puzzle=puzzle), "question": QUESTION_TEMPLATE.format(puzzle=puzzle),
"answer": str(answer), "answer": str(answer),
"metadata": {"puzzle": puzzle, "solution": answer, "difficulty": {"max_rectangles": target}}, "metadata": {
"puzzle": puzzle,
"solution": answer,
"num_rectangles": target,
"difficulty": {
"max_rectangles": self.config.max_rectangles,
},
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -110,8 +110,8 @@ class RubiksCubeDataset(ProceduralDataset):
"scramble_moves": " ".join([str(move) for move in scramble_moves]), "scramble_moves": " ".join([str(move) for move in scramble_moves]),
"example_correct_answer": actions_string, "example_correct_answer": actions_string,
"difficulty": { "difficulty": {
"scramble_steps": num_steps,
"cube_size": self.config.cube_size, "cube_size": self.config.cube_size,
"scramble_steps": (self.config.min_scramble_steps, self.config.max_scramble_steps),
}, },
}, },
} }

View file

@ -126,11 +126,14 @@ class BoxnetDataset(ProceduralDataset):
"question": question, "question": question,
"answer": None, "answer": None,
"metadata": { "metadata": {
"difficulty": { "row_num": row_num,
"row_num": row_num, "column_num": column_num,
"column_num": column_num,
},
"initial_state": pg_dict, "initial_state": pg_dict,
"difficulty": {
"row_num": (self.config.min_row_num, self.config.max_row_num),
"column_num": (self.config.min_column_num, self.config.max_column_num),
"box_num": (self.config.min_box_num, self.config.max_box_num),
},
}, },
} }

View file

@ -194,7 +194,10 @@ class EmojiMysteryDataset(ProceduralDataset):
"answer": secret_sentence, "answer": secret_sentence,
"metadata": { "metadata": {
"emoji": secret_emoji, "emoji": secret_emoji,
"difficulty": {"num_words_in_sentence": len(re.findall(r"\b\w+\b", secret_sentence))}, "num_words_in_sentence": len(re.findall(r"\b\w+\b", secret_sentence)),
"difficulty": {
"num_words_in_sentence": (self.config.min_words_in_sentence, self.config.max_words_in_sentence),
},
}, },
} }

View file

@ -84,7 +84,12 @@ class FutoshikiDataset(ProceduralDataset):
"puzzle": puzzle, "puzzle": puzzle,
"constraints": constraints, "constraints": constraints,
"solution": solution, "solution": solution,
"difficulty": {"board_size": board_size, "difficulty": difficulty}, "board_size": board_size,
"difficulty_rating": difficulty,
"difficulty": {
"board_size": (self.config.min_board_size, self.config.max_board_size),
"difficulty": (self.config.min_difficulty, self.config.max_difficulty),
},
}, },
} }

View file

@ -122,7 +122,9 @@ class MahjongPuzzleDataset(ProceduralDataset):
"metadata": { "metadata": {
"rounds": rounds, "rounds": rounds,
"solution": answer, "solution": answer,
"difficulty": {"num_rounds": num_rounds}, "difficulty": {
"num_rounds": (self.config.min_num_rounds, self.config.max_num_rounds),
},
}, },
} }

View file

@ -112,8 +112,8 @@ class MazeDataset(ProceduralDataset):
"wall": self.wall_char, "wall": self.wall_char,
"path": self.path_char, "path": self.path_char,
"difficulty": { "difficulty": {
"dist": dist, "dist": (self.config.min_dist, self.config.max_dist),
"grid_size": size, "grid_size": (self.config.min_grid_size, self.config.max_grid_size),
}, },
}, },
} }

View file

@ -197,7 +197,7 @@ class MiniSudokuDataset(ProceduralDataset):
"solution": solved_board, "solution": solved_board,
"num_empty": num_empty, "num_empty": num_empty,
"difficulty": { "difficulty": {
"empty": num_empty, "empty": (self.config.min_empty, self.config.max_empty),
}, },
}, },
} }

View file

@ -137,7 +137,7 @@ class NQueensDataset(ProceduralDataset):
"valid_answers": valid_solutions_str, "valid_answers": valid_solutions_str,
"difficulty": { "difficulty": {
"n": self.config.n, "n": self.config.n,
"num_removed": num_removed, "num_removed": (self.config.min_remove, self.config.max_remove),
}, },
}, },
} }

View file

@ -161,7 +161,9 @@ class RushHourDataset(ProceduralDataset):
"metadata": { "metadata": {
"board_config": board_config, "board_config": board_config,
"min_moves": min_moves, "min_moves": min_moves,
"difficulty": {"min_moves": min_moves}, "difficulty": {
"min_moves": (self.config.min_moves, self.config.max_moves),
},
}, },
} }

View file

@ -65,7 +65,7 @@ class SokobanDataset(ProceduralDataset):
# Make the Sokoban! # Make the Sokoban!
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
gamestr, solution, difficulty = self._generate( gamestr, solution, puzzle_data = self._generate(
rng=rng, rng=rng,
min_w=self.config.min_w, min_w=self.config.min_w,
min_h=self.config.min_h, min_h=self.config.min_h,
@ -93,7 +93,15 @@ Here is your puzzle:
""" """
+ gamestr, + gamestr,
"answer": solution, "answer": solution,
"metadata": {"gamestr": gamestr, "difficulty": difficulty}, "metadata": {
"gamestr": gamestr,
"width": puzzle_data["width"],
"height": puzzle_data["height"],
"difficulty": {
"width": (self.config.min_w, self.config.max_w),
"height": (self.config.min_h, self.config.max_h),
},
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -216,7 +216,7 @@ class SudokuDataset(ProceduralDataset):
"solution": solved_board, "solution": solved_board,
"num_empty": num_empty, "num_empty": num_empty,
"difficulty": { "difficulty": {
"num_empty": num_empty, "empty": (self.config.min_empty, self.config.max_empty),
}, },
}, },
} }

View file

@ -275,6 +275,9 @@ class HanoiDataset(ProceduralDataset):
"target_peg": target_peg, "target_peg": target_peg,
"auxiliary_pegs": auxiliary_pegs, "auxiliary_pegs": auxiliary_pegs,
"solution_length": len(solution), "solution_length": len(solution),
"difficulty": {
"num_disks": (self.min_disks, self.max_disks),
},
}, },
} }

View file

@ -270,7 +270,13 @@ class TsumegoDataset(ProceduralDataset):
"Specify your move in coordinates (e.g. 'C4' for column C, row 4)" "Specify your move in coordinates (e.g. 'C4' for column C, row 4)"
), ),
"answer": solution_str, "answer": solution_str,
"metadata": {"difficulty": {"board_size": size}, "board": board}, "metadata": {
"board": board,
"board_size": size,
"difficulty": {
"board_size": (self.config.min_board_size, self.config.max_board_size),
},
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -115,7 +115,9 @@ class SimpleGeometryDataset(ProceduralDataset):
"missing_angle_raw": missing_angle, "missing_angle_raw": missing_angle,
"missing_angle_rounded": missing_angle_rounded, "missing_angle_rounded": missing_angle_rounded,
"total_interior_sum": total_sum, "total_interior_sum": total_sum,
"difficulty": {"sides": n_sides}, "difficulty": {
"sides": (self.config.min_sides, self.config.max_sides),
},
}, },
} }

View file

@ -136,7 +136,11 @@ class CourseScheduleDataset(ProceduralDataset):
"prerequisites": prerequisites, "prerequisites": prerequisites,
"solution": answer, "solution": answer,
"solvable": solvable, "solvable": solvable,
"difficulty": {"num_courses": num_courses}, "difficulty": {
"num_courses": (self.config.min_num_courses, self.config.max_num_courses),
"num_prerequisites": (self.config.min_num_prerequisites, self.config.max_num_prerequisites),
"cycle_length": (self.config.min_cycle_length, self.config.max_cycle_length),
},
}, },
} }

View file

@ -206,7 +206,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
"relationship": relationship.value, "relationship": relationship.value,
"family_size": len(family), "family_size": len(family),
"difficulty": { "difficulty": {
"family_size": len(family), "family_size": (self.config.min_family_size, self.config.max_family_size),
}, },
}, },
} }

View file

@ -142,9 +142,10 @@ class LargestIslandDataset(ProceduralDataset):
"grid": grid, "grid": grid,
"solution": answer, "solution": answer,
"difficulty": { "difficulty": {
"rows": rows, "rows": (self.config.min_rows, self.config.max_rows),
"cols": cols, "cols": (self.config.min_cols, self.config.max_cols),
"num_islands": num_islands, "num_islands": (self.config.min_num_islands, self.config.max_num_islands),
"island_size": (self.config.min_island_size, self.config.max_island_size),
}, },
}, },
} }

View file

@ -56,12 +56,12 @@ Buttons:
"question": self.format_puzzle(rng.choice(self._prompt_templates), puzzle=puzzle_data), "question": self.format_puzzle(rng.choice(self._prompt_templates), puzzle=puzzle_data),
"answer": "".join(puzzle_data["solution"]), "answer": "".join(puzzle_data["solution"]),
"metadata": { "metadata": {
"metadata": {"difficulty": difficulty},
"solution_path": puzzle_data["solution"], "solution_path": puzzle_data["solution"],
"target_value": puzzle_data["target_value"], "target_value": puzzle_data["target_value"],
"buttons": puzzle_data["buttons"], "buttons": puzzle_data["buttons"],
"initial_state": puzzle_data["initial_state"], "initial_state": puzzle_data["initial_state"],
"initial_value": puzzle_data["initial_value"], "initial_value": puzzle_data["initial_value"],
"difficulty": {"difficulty": difficulty},
}, },
} }

View file

@ -162,8 +162,8 @@ class ShortestPathDataset(ProceduralDataset):
"matrix": matrix, "matrix": matrix,
"solution": answer, "solution": answer,
"difficulty": { "difficulty": {
"rows": rows, "rows": (self.config.min_rows, self.config.max_rows),
"cols": cols, "cols": (self.config.min_cols, self.config.max_cols),
}, },
}, },
} }

View file

@ -387,7 +387,7 @@ class CircuitLogicDataset(ProceduralDataset):
"final_gate": final_gate_name, "final_gate": final_gate_name,
"inputs": inputs_list, "inputs": inputs_list,
"difficulty": { "difficulty": {
"terms": num_terms, "terms": (self.config.min_terms, self.config.max_terms),
"inputs": (self.config.min_inputs, self.config.max_inputs), "inputs": (self.config.min_inputs, self.config.max_inputs),
}, },
}, },

View file

@ -221,8 +221,8 @@ class PropositionalLogicDataset(ProceduralDataset):
"complexity": self._measure_complexity(conclusion), "complexity": self._measure_complexity(conclusion),
"example_answer": str(conclusion), "example_answer": str(conclusion),
"difficulty": { "difficulty": {
"vars": num_vars, "vars": (self.config.min_vars, self.config.max_vars),
"statements": num_statements, "statements": (self.config.min_statements, self.config.max_statements),
"complexity": (self.config.min_complexity, self.config.max_complexity), "complexity": (self.config.min_complexity, self.config.max_complexity),
}, },
}, },

View file

@ -346,7 +346,9 @@ class SelfReferenceDataset(ProceduralDataset):
return { return {
"question": puzz_s, "question": puzz_s,
"answer": answer, "answer": answer,
"metadata": {"difficulty": difficulty}, "metadata": {
"difficulty": {"difficulty": difficulty},
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -43,8 +43,8 @@ def test_boxnet_items():
assert "initial_state" in item["metadata"] assert "initial_state" in item["metadata"]
# Verify row_num and column_num are within limits # Verify row_num and column_num are within limits
row_num = item["metadata"]["difficulty"]["row_num"] row_num = item["metadata"]["row_num"]
column_num = item["metadata"]["difficulty"]["column_num"] column_num = item["metadata"]["column_num"]
assert 1 <= row_num <= 2, f"row_num {row_num} outside valid range" assert 1 <= row_num <= 2, f"row_num {row_num} outside valid range"
assert 1 <= column_num <= 2, f"column_num {column_num} outside valid range" assert 1 <= column_num <= 2, f"column_num {column_num} outside valid range"
@ -78,8 +78,8 @@ def test_boxnet_grid_sizes():
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
row_num = item["metadata"]["difficulty"]["row_num"] row_num = item["metadata"]["row_num"]
column_num = item["metadata"]["difficulty"]["column_num"] column_num = item["metadata"]["column_num"]
rows_set.add(row_num) rows_set.add(row_num)
columns_set.add(column_num) columns_set.add(column_num)

View file

@ -53,11 +53,15 @@ def test_coach_with_chain_sum():
# Each key should be a tuple of tuples containing difficulty parameters # Each key should be a tuple of tuples containing difficulty parameters
for key in aggregated.scores: for key in aggregated.scores:
assert isinstance(key, tuple) assert isinstance(key, tuple)
# Each inner tuple should be (param_name, value) # Each inner tuple should be (param_name, value) or (param_name, (min_value, max_value))
for param in key: for param in key:
assert isinstance(param, tuple) assert isinstance(param, tuple)
assert param[0] in ("num_terms", "num_digits") assert param[0] in ("num_terms", "num_digits")
assert isinstance(param[1], int) assert (
isinstance(param[1], int)
or (isinstance(param[1], tuple) and len(param[1]) == 2)
and all(isinstance(v, int) for v in param[1])
)
# Test aggregation with last_n # Test aggregation with last_n
last_3 = coach.score_board.aggregate(last_n=3) last_3 = coach.score_board.aggregate(last_n=3)
@ -171,7 +175,7 @@ def test_coach_with_composite():
item = coach[i + 5] # Use different indices item = coach[i + 5] # Use different indices
if "chain_sum" in item["metadata"]["source_dataset"]: if "chain_sum" in item["metadata"]["source_dataset"]:
metadata = item["metadata"] metadata = item["metadata"]
assert metadata["difficulty"]["num_terms"] >= 4 assert metadata["num_terms"] >= 4
def test_grouped_scores_str(): def test_grouped_scores_str():

View file

@ -38,12 +38,12 @@ def test_rearc_items():
assert "input" in meta assert "input" in meta
assert "output" in meta assert "output" in meta
assert "task_id" in meta assert "task_id" in meta
assert "rng" in meta["difficulty"] assert "rng" in meta
assert "pso" in meta["difficulty"] assert "pso" in meta
# Validate difficulty bounds # Validate difficulty bounds
assert config.diff_lb <= meta["difficulty"]["rng"] <= config.diff_ub assert config.diff_lb <= meta["rng"] <= config.diff_ub
assert config.diff_lb <= meta["difficulty"]["pso"] <= config.diff_ub assert config.diff_lb <= meta["pso"] <= config.diff_ub
def test_rearc_solution_validation(): def test_rearc_solution_validation():

View file

@ -124,7 +124,7 @@ def test_score_answer():
# test optimal score for answers, patching each entry # test optimal score for answers, patching each entry
for x in dataset: for x in dataset:
assert len(x["metadata"]["board"]) == x["metadata"]["difficulty"]["board_size"] assert len(x["metadata"]["board"]) == x["metadata"]["board_size"]
assert dataset.score_answer(x["answer"], entry=x) == 1.0 assert dataset.score_answer(x["answer"], entry=x) == 1.0