include ranges rather than sampled values in difficulty metadata dicts (#387)

* update difficulty metadata for logic datasets

* update difficulty metadata for graph datasets

* update difficulty metadata for geometry datasets

* update difficulty metadata for games datasets

* update difficulty metadata for cognition datasets

* update difficulty metadata for arithmetic datasets

* update difficulty metadata for arc datasets

* update difficulty metadata for algorithmic datasets

* update difficulty metadata for algebra datasets

* use tuples

* update tests

* update tests
This commit is contained in:
Oliver Stanley 2025-03-20 09:27:03 +00:00 committed by GitHub
parent b69c35818a
commit 7475a20700
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 304 additions and 126 deletions

View file

@ -124,7 +124,11 @@ In solving equations, please follow these instructions:
"variable": variable,
"degree": degree,
"real_solutions": real_solutions,
"difficulty": {"terms": num_terms, "degree": degree},
"num_terms": num_terms,
"difficulty": {
"terms": (self.config.min_terms, self.config.max_terms),
"degree": (self.config.min_degree, self.config.max_degree),
},
},
}

View file

@ -85,7 +85,10 @@ When performing calculations, please follow these guidelines:
"integrand": str(derivative),
"variable": str(symbol),
"expected_answer_expression": polynomial,
"difficulty": {"terms": num_terms},
"num_terms": num_terms,
"difficulty": {
"terms": (self.config.min_terms, self.config.max_terms),
},
},
}

View file

@ -110,8 +110,8 @@ class BaseConversionDataset(ProceduralDataset):
"source_repr": source_repr,
"target_repr": target_repr,
"difficulty": {
"value": value,
"base": (source_base, target_base),
"base": (self.config.min_base, self.config.max_base),
"value": (self.config.min_value, self.config.max_value),
},
},
}

View file

@ -108,7 +108,10 @@ class BinaryAlternationDataset(ProceduralDataset):
"string": string,
"solution": answer,
"solvable": solvable,
"difficulty": {"n": n},
"n": n,
"difficulty": {
"n": (self.config.min_n, self.config.max_n),
},
},
}

View file

@ -130,8 +130,9 @@ class BinaryMatrixDataset(ProceduralDataset):
"metadata": {
"matrix": matrix,
"solution": answer,
"n": n,
"difficulty": {
"n": n,
"n": (self.config.min_n, self.config.max_n),
"p_zero": self.config.p_zero,
},
},

View file

@ -80,9 +80,10 @@ class CaesarCipherDataset(ProceduralDataset):
"rotation": rotation,
"cipher_text": cipher_text,
"clear_text": sentence,
"num_words": num_words,
"difficulty": {
"rotation": rotation,
"words": num_words,
"words": (self.config.min_words, self.config.max_words),
"rotation": (self.config.min_rotation, self.config.max_rotation),
},
},
}

View file

@ -64,8 +64,9 @@ class CountPrimesDataset(ProceduralDataset):
"end": end,
"primes": primes,
"solution": answer,
"n": (start, end),
"difficulty": {
"n": (start, end),
"n": (self.config.min_n, self.config.max_n),
},
},
}

View file

@ -187,8 +187,7 @@ class CryptarithmDataset(ProceduralDataset):
"digit_to_letter": digit_to_letter,
"letter_to_digit": letter_to_digit,
"difficulty": {
"min_words": self.config.min_words,
"max_words": self.config.max_words,
"words": (self.config.min_words, self.config.max_words),
},
},
}

View file

@ -368,6 +368,13 @@ class GameOfLifeHaltingDataset(ProceduralDataset):
"placed_patterns": placed_patterns,
"simulation_steps": self.config.max_simulation_steps,
"should_oscillate": should_oscillate,
"difficulty": {
"grid_size_x": self.config.grid_size_x,
"grid_size_y": self.config.grid_size_y,
"difficulty": self.config.difficulty,
"num_oscillators": self.config.num_oscillators,
"max_simulation_steps": self.config.max_simulation_steps,
},
},
}

View file

@ -215,7 +215,11 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
"metadata": {
"possible_answer": solution,
"puzzle": puzzle,
"difficulty": {"num_vertices": num_vertices, "num_colors": num_colors},
"num_vertices": num_vertices,
"difficulty": {
"num_vertices": (self.config.min_num_vertices, self.config.max_num_vertices),
"num_colors": num_colors,
},
},
}

View file

@ -117,8 +117,10 @@ class GroupAnagramsDataset(ProceduralDataset):
"metadata": {
"words": words,
"solution": answer,
"anagram_groups": anagram_groups,
"difficulty": {
"anagram_groups": anagram_groups,
"anagram_groups": (self.config.min_anagram_groups, self.config.max_anagram_groups),
"words_per_group": (self.config.min_words_per_group, self.config.max_words_per_group),
},
},
}

View file

@ -110,8 +110,9 @@ class IsomorphicStringsDataset(ProceduralDataset):
"words": [s, t],
"solution": answer,
"solvable": solvable,
"string_length": string_length,
"difficulty": {
"string_length": string_length,
"string_length": (self.config.min_string_length, self.config.max_string_length),
},
},
}

View file

@ -67,7 +67,9 @@ class LetterCountingDataset(ProceduralDataset):
"span_length": span_length,
"target_letter": target_letter,
"span": span,
"difficulty": {"words": span_length},
"difficulty": {
"words": (self.config.min_words, self.config.max_words),
},
},
}

View file

@ -110,8 +110,8 @@ class LetterJumbleDataset(ProceduralDataset):
"original_words": selected_words,
"difficulty": {
"word_len": (self.config.min_word_len, self.config.max_word_len),
"words": num_words,
"corruption_level": corruption_level,
"words": (self.config.min_words, self.config.max_words),
"corruption_level": (self.config.min_corruption_level, self.config.max_corruption_level),
},
},
}

View file

@ -309,10 +309,13 @@ class ManipulateMatrixDataset(ProceduralDataset):
"matrix": matrix,
"solution": answer,
"operations": operations,
"rows": rows,
"cols": cols,
"num_transforms": num_transforms,
"difficulty": {
"rows": rows,
"cols": cols,
"num_transforms": num_transforms,
"rows": (self.config.min_rows, self.config.max_rows),
"cols": (self.config.min_cols, self.config.max_cols),
"num_transforms": (self.config.min_transforms, self.config.max_transforms),
},
},
}

View file

@ -95,8 +95,9 @@ class NumberFilteringDataset(ProceduralDataset):
"filter_value": filter_str,
"operation": f"{keep_remove}_{larger_smaller}",
"result": result_strs,
"numbers": len(numbers),
"difficulty": {
"numbers": len(numbers),
"numbers": (self.config.min_numbers, self.config.max_numbers),
"decimals": (self.config.min_decimals, self.config.max_decimals),
"value": (self.config.min_value, self.config.max_value),
},

View file

@ -93,8 +93,9 @@ Please follow the instruction below:
"original_numbers": number_strs,
"direction": direction,
"sorted_numbers": answer,
"numbers": count,
"difficulty": {
"numbers": count,
"numbers": (self.config.min_numbers, self.config.max_numbers),
"decimals": (self.config.min_decimals, self.config.max_decimals),
"value": (self.config.min_value, self.config.max_value),
},

View file

@ -69,8 +69,9 @@ class PalindromeDataset(ProceduralDataset):
"metadata": {
"letters": scrambled_letters,
"generated_palindrome": palindrome,
"length": length,
"difficulty": {
"length": length,
"length": (self.config.min_length, self.config.max_length),
},
},
}

View file

@ -140,8 +140,13 @@ class PalindromePartitioningDataset(ProceduralDataset):
"metadata": {
"string": string,
"solution": answer,
"string_len": string_len,
"difficulty": {
"string_len": string_len,
"string_len": (self.config.min_string_len, self.config.max_string_len),
"substring_palindrome_len": (
self.config.min_substring_palindrome_len,
self.config.max_substring_palindrome_len,
),
},
},
}

View file

@ -117,10 +117,13 @@ class PoolMatrixDataset(ProceduralDataset):
"pool_type": pool_type,
"pool_size": pool_size,
"solution": answer.tolist(),
"rows": rows,
"cols": cols,
"pool_size": pool_size,
"difficulty": {
"rows": rows,
"cols": cols,
"pool_size": pool_size,
"rows": (self.config.min_rows, self.config.max_rows),
"cols": (self.config.min_cols, self.config.max_cols),
"pool_size": (self.config.min_pool_size, self.config.max_pool_size),
},
},
}

View file

@ -103,9 +103,11 @@ class RansomNoteDataset(ProceduralDataset):
"magazine": magazine,
"solution": answer,
"solvable": solvable,
"note_length": note_length,
"magazine_length": magazine_length,
"difficulty": {
"note_length": note_length,
"magazine_length": magazine_length,
"note_length": (self.config.min_note_length, self.config.max_note_length),
"magazine_length": (self.config.min_magazine_length, self.config.max_magazine_length),
},
},
}

View file

@ -86,9 +86,10 @@ class RotateMatrixDataset(ProceduralDataset):
"matrix": matrix,
"num_rotations": num_rotations,
"solution": answer,
"n": n,
"difficulty": {
"n": n,
"num_rotations": num_rotations,
"n": (self.config.min_n, self.config.max_n),
"num_rotations": (self.config.min_rotations, self.config.max_rotations),
},
},
}

View file

@ -122,7 +122,10 @@ class RottenOrangesDataset(ProceduralDataset):
"metadata": {
"matrix": matrix,
"solution": answer,
"difficulty": {"n": n},
"n": n,
"difficulty": {
"n": (self.config.min_n, self.config.max_n),
},
},
}

View file

@ -90,7 +90,12 @@ class SentenceReorderingDataset(ProceduralDataset):
return {
"question": f"Restore the correct order of words in the following sentence: {question}",
"answer": solved_sentence,
"metadata": {"word_count": word_count, "difficulty": {"words_in_sentence": word_count}},
"metadata": {
"word_count": word_count,
"difficulty": {
"words_in_sentence": (self.config.min_words_in_sentence, self.config.max_words_in_sentence),
},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -54,7 +54,9 @@ class SpellBackwardDataset(ProceduralDataset):
"metadata": {
"word": word,
"word_len": len(word),
"difficulty": {"word_len": (self.config.min_word_len, self.config.max_word_len)},
"difficulty": {
"word_len": (self.config.min_word_len, self.config.max_word_len),
},
},
}

View file

@ -101,7 +101,7 @@ class SpiralMatrixDataset(ProceduralDataset):
"""Generate a single Spiral Matrix question"""
rng = Random(self.seed + idx)
n = rng.randint(2, self.config.max_n)
n = rng.randint(self.config.min_n, self.config.max_n)
matrix = self._get_matrix(rng, n)
matrix_str = self._matrix_to_str(matrix)
answer = self._get_spiral(matrix)
@ -113,7 +113,10 @@ class SpiralMatrixDataset(ProceduralDataset):
"metadata": {
"matrix": matrix,
"solution": answer,
"difficulty": {"n": n},
"n": n,
"difficulty": {
"n": (self.config.min_n, self.config.max_n),
},
},
}

View file

@ -104,8 +104,9 @@ class StringInsertionDataset(ProceduralDataset):
"metadata": {
"string": string,
"solution": answer,
"string_length": string_length,
"difficulty": {
"string_length": string_length,
"string_length": (self.config.min_string_length, self.config.max_string_length),
},
},
}

View file

@ -183,9 +183,11 @@ class StringManipulationDataset(ProceduralDataset):
"solution": answer,
"states": states,
"selected_rules": [rule for rule, _ in selected_rules],
"string_length": string_length,
"num_rules": num_rules,
"difficulty": {
"string_length": string_length,
"num_rules": num_rules,
"string_length": (self.config.min_string_length, self.config.max_string_length),
"num_rules": (self.config.min_num_rules, self.config.max_num_rules),
},
},
}

View file

@ -127,8 +127,9 @@ class StringSplittingDataset(ProceduralDataset):
"metadata": {
"states": states,
"solution": answer,
"initial_machines": (A_machine, B_machine, C_machine),
"difficulty": {
"initial_machines": (A_machine, B_machine, C_machine),
"initial_machines": (self.config.min_initial_machines, self.config.max_initial_machines),
},
},
}

View file

@ -132,8 +132,9 @@ class StringSynthesisDataset(ProceduralDataset):
"metadata": {
"states": states,
"solution": answer,
"initial_blocks": (A_square, B_square, C_square),
"difficulty": {
"initial_blocks": (A_square, B_square, C_square),
"initial_blocks": (self.config.min_initial_blocks, self.config.max_initial_blocks),
},
},
}

View file

@ -224,8 +224,7 @@ class WordLadderDataset(ProceduralDataset):
"word_length": length,
"chain_length": len(path),
"difficulty": {
"word_length": length,
"chain_length": len(path),
"word_length": (self.config.min_word_length, self.config.max_word_length),
},
},
}

View file

@ -62,7 +62,13 @@ class WordSequenceReversalDataset(ProceduralDataset):
return {
"question": f"{QUESTION_TEMPLATE.format(words=words_str)}",
"answer": answer,
"metadata": {"num_words": num_words, "words": words, "difficulty": {"words": num_words}},
"metadata": {
"num_words": num_words,
"words": words,
"difficulty": {
"words": (self.config.min_words, self.config.max_words),
},
},
}

View file

@ -106,14 +106,16 @@ class WordSortingDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
"answer": ", ".join(answer),
"metadata": {
"difficulty": {
"num_words": len(original_words),
"word_length": max(len(word) for word in original_words),
},
"original_words": original_words,
"sorted_words": answer,
"transformed_words": transformed_words,
"direction": direction,
"num_words": len(original_words),
"word_length": max(len(word) for word in original_words),
"difficulty": {
"num_words": (self.config.min_words, self.config.max_words),
"word_length": (self.config.min_word_length, self.config.max_word_length),
},
},
}

View file

@ -117,9 +117,11 @@ class ReArcDataset(ProceduralDataset):
"input": task["input"],
"output": task["output"],
"task_id": task_id,
"rng": rng_difficulty,
"pso": pso_difficulty,
"difficulty": {
"rng": rng_difficulty,
"pso": pso_difficulty,
"rng_difficulty": self.config.rng_difficulty_weights,
"pso_difficulty": self.config.pso_difficulty_weights,
},
},
}

View file

@ -96,7 +96,12 @@ class BasicArithmeticDataset(ProceduralDataset):
"answer": str(result),
"metadata": {
"expression": expression,
"difficulty": {"num_terms": num_terms, "num_digits": num_digits},
"num_terms": num_terms,
"num_digits": num_digits,
"difficulty": {
"num_terms": (self.config.min_terms, self.config.max_terms),
"num_digits": (self.config.min_digits, self.config.max_digits),
},
},
}

View file

@ -64,11 +64,13 @@ class ChainSumDataset(ProceduralDataset):
"question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result),
"metadata": {
"difficulty": {
"num_terms": num_terms,
"num_digits": num_digits,
},
"num_terms": num_terms,
"num_digits": num_digits,
"expression": expression,
"difficulty": {
"num_terms": (self.config.min_terms, self.config.max_terms),
"num_digits": (self.config.min_digits, self.config.max_digits),
},
},
}

View file

@ -46,7 +46,10 @@ class CountBitsDataset(ProceduralDataset):
"number": number,
"solution": answer,
"binary": binary,
"difficulty": {"n": number},
"n": number,
"difficulty": {
"n": (self.config.min_n, self.config.max_n),
},
},
}

View file

@ -189,9 +189,11 @@ class DecimalArithmeticDataset(ProceduralDataset):
"question": problem_str,
"answer": str(answer),
"metadata": {
"decimal_places": decimal_places,
"num_terms": terms,
"difficulty": {
"decimal_places": decimal_places,
"num_terms": terms,
"decimal_places": (self.config.min_num_decimal_places, self.config.max_num_decimal_places),
"num_terms": (self.config.min_terms, self.config.max_terms),
},
},
}

View file

@ -66,11 +66,14 @@ class DecimalChainSumDataset(ProceduralDataset):
"question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result),
"metadata": {
"difficulty": {
"num_terms": num_terms,
"num_digits": num_digits,
},
"num_terms": num_terms,
"num_digits": num_digits,
"expression": expression,
"difficulty": {
"num_terms": (self.config.min_terms, self.config.max_terms),
"num_digits": (self.config.min_digits, self.config.max_digits),
"decimal_places": (self.config.min_decimal_places, self.config.max_decimal_places),
},
},
}

View file

@ -124,11 +124,11 @@ class DiceDataset(ProceduralDataset):
"question": puzzle_str,
"answer": answer_str,
"metadata": {
"puzzle": puzzle,
"difficulty": {
"num_dice": self.config.num_dice,
"max_dice_size": self.config.max_dice_size,
},
"puzzle": puzzle,
},
}

View file

@ -120,9 +120,10 @@ class FractionSimplificationDataset(ProceduralDataset):
"simplified_denominator": simple_den,
"reduction_factor": num // simple_num, # Will be same as den // simple_den
"style": style,
"factor": factor,
"difficulty": {
"factor": factor,
"value": (simple_num, simple_den),
"value": (self.config.min_value, self.config.max_value),
"factor": (self.config.min_factor, self.config.max_factor),
},
},
}

View file

@ -64,9 +64,10 @@ class GCDDataset(ProceduralDataset):
"metadata": {
"numbers": numbers,
"result": result,
"num_terms": num_terms,
"difficulty": {
"num_terms": num_terms,
"max_value": self.config.max_value,
"num_terms": (self.config.min_numbers, self.config.max_numbers),
"max_value": (self.config.min_value, self.config.max_value),
},
},
}

View file

@ -67,7 +67,7 @@ class LCMDataset(ProceduralDataset):
"numbers": numbers,
"result": result,
"difficulty": {
"numbers": len(numbers),
"numbers": (self.config.min_numbers, self.config.max_numbers),
"value": (self.config.min_value, self.config.max_value),
},
},

View file

@ -118,11 +118,13 @@ class LegCountingDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
"answer": str(total_legs),
"metadata": {
"difficulty": {
"num_animals": len(animals),
},
"animals": animals,
"num_animals": len(animals),
"total_legs": total_legs,
"difficulty": {
"num_animals": (self.config.min_animals, self.config.max_animals),
"num_instances": (self.config.min_instances, self.config.max_instances),
},
},
}

View file

@ -98,8 +98,9 @@ class NumberFormatDataset(ProceduralDataset):
"solution": answer,
"formatted_candidates": formatted_candidates,
"size": size,
"num_candidates": num_candidates,
"difficulty": {
"num_candidates": num_candidates,
"num_candidates": (self.config.min_num_candidates, self.config.max_num_candidates),
"n": (self.config.min_n, self.config.max_n),
"min_delta": self.config.max_delta,
},

View file

@ -73,7 +73,14 @@ class PowerFunctionDataset(ProceduralDataset):
return {
"question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
"answer": str(answer),
"metadata": {"base": base, "exponent": exponent, "solution": answer, "difficulty": {"exponent": exponent}},
"metadata": {
"base": base,
"exponent": exponent,
"solution": answer,
"difficulty": {
"exponent": (self.config.min_exponent, self.config.max_exponent),
},
},
}

View file

@ -83,7 +83,13 @@ class PrimeFactorizationDataset(ProceduralDataset):
f"(Example: for 12 the answer would be: 2 × 2 × 3)"
),
"answer": answer,
"metadata": {"number": number, "factors": factors, "difficulty": {"value": number}},
"metadata": {
"number": number,
"factors": factors,
"difficulty": {
"value": (self.config.min_value, self.config.max_value),
},
},
}

View file

@ -66,11 +66,13 @@ class ProductsDataset(ProceduralDataset):
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
"answer": str(result),
"metadata": {
"difficulty": {
"num_terms": num_terms,
"num_digits": num_digits,
},
"expression": expression,
"num_terms": num_terms,
"num_digits": num_digits,
"difficulty": {
"num_terms": (self.config.min_terms, self.config.max_terms),
"num_digits": (self.config.min_digits, self.config.max_digits),
},
},
}

View file

@ -141,7 +141,9 @@ class ColorCubeRotationDataset(ProceduralDataset):
"rotations": [r.value for r in rotations],
"target_side": target_side.value,
"num_rotations": num_rotations,
"difficulty": {"rotations": num_rotations},
"difficulty": {
"rotations": (self.config.min_rotations, self.config.max_rotations),
},
},
}

View file

@ -188,7 +188,9 @@ class FigletFontDataset(ProceduralDataset):
"metadata": {
"font": chosen_font,
"space_letters": self.config.space_letters,
"difficulty": {"word_len": len(word)},
"difficulty": {
"word_len": (self.config.min_word_len, self.config.max_word_len),
},
},
}

View file

@ -140,9 +140,11 @@ class ModuloGridDataset(ProceduralDataset):
"target": target,
"operation": operation,
"difficulty": {
"holes": self.config.max_holes,
"size_x": self.config.size_x,
"size_y": self.config.size_y,
"holes": self.config.max_holes,
"divisor": self.config.max_divisor,
"target": self.config.max_target,
},
},
}

View file

@ -103,7 +103,13 @@ class NeedleHaystackDataset(ProceduralDataset):
return {
"question": full_text,
"answer": stack["needle"][0],
"metadata": {"question": question, "difficulty": {"num_statements": num_statements}},
"metadata": {
"question": question,
"num_statements": num_statements,
"difficulty": {
"num_statements": (self.config.min_num_statements, self.config.max_num_statements),
},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -195,7 +195,14 @@ class NumberSequenceDataset(ProceduralDataset):
return {
"question": ", ".join(map(str, visible_terms)) + ", ?",
"answer": str(sequence[-1]),
"metadata": {"rule": rule.to_string(), "complexity": complexity, "sequence": sequence},
"metadata": {
"rule": rule.to_string(),
"complexity": complexity,
"sequence": sequence,
"difficulty": {
"max_complexity": self.config.max_complexity,
},
},
}

View file

@ -117,7 +117,14 @@ class RectangleCountDataset(ProceduralDataset):
return {
"question": QUESTION_TEMPLATE.format(puzzle=puzzle),
"answer": str(answer),
"metadata": {"puzzle": puzzle, "solution": answer, "difficulty": {"max_rectangles": target}},
"metadata": {
"puzzle": puzzle,
"solution": answer,
"num_rectangles": target,
"difficulty": {
"max_rectangles": self.config.max_rectangles,
},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -110,8 +110,8 @@ class RubiksCubeDataset(ProceduralDataset):
"scramble_moves": " ".join([str(move) for move in scramble_moves]),
"example_correct_answer": actions_string,
"difficulty": {
"scramble_steps": num_steps,
"cube_size": self.config.cube_size,
"scramble_steps": (self.config.min_scramble_steps, self.config.max_scramble_steps),
},
},
}

View file

@ -126,11 +126,14 @@ class BoxnetDataset(ProceduralDataset):
"question": question,
"answer": None,
"metadata": {
"difficulty": {
"row_num": row_num,
"column_num": column_num,
},
"row_num": row_num,
"column_num": column_num,
"initial_state": pg_dict,
"difficulty": {
"row_num": (self.config.min_row_num, self.config.max_row_num),
"column_num": (self.config.min_column_num, self.config.max_column_num),
"box_num": (self.config.min_box_num, self.config.max_box_num),
},
},
}

View file

@ -194,7 +194,10 @@ class EmojiMysteryDataset(ProceduralDataset):
"answer": secret_sentence,
"metadata": {
"emoji": secret_emoji,
"difficulty": {"num_words_in_sentence": len(re.findall(r"\b\w+\b", secret_sentence))},
"num_words_in_sentence": len(re.findall(r"\b\w+\b", secret_sentence)),
"difficulty": {
"num_words_in_sentence": (self.config.min_words_in_sentence, self.config.max_words_in_sentence),
},
},
}

View file

@ -84,7 +84,12 @@ class FutoshikiDataset(ProceduralDataset):
"puzzle": puzzle,
"constraints": constraints,
"solution": solution,
"difficulty": {"board_size": board_size, "difficulty": difficulty},
"board_size": board_size,
"difficulty_rating": difficulty,
"difficulty": {
"board_size": (self.config.min_board_size, self.config.max_board_size),
"difficulty": (self.config.min_difficulty, self.config.max_difficulty),
},
},
}

View file

@ -122,7 +122,9 @@ class MahjongPuzzleDataset(ProceduralDataset):
"metadata": {
"rounds": rounds,
"solution": answer,
"difficulty": {"num_rounds": num_rounds},
"difficulty": {
"num_rounds": (self.config.min_num_rounds, self.config.max_num_rounds),
},
},
}

View file

@ -112,8 +112,8 @@ class MazeDataset(ProceduralDataset):
"wall": self.wall_char,
"path": self.path_char,
"difficulty": {
"dist": dist,
"grid_size": size,
"dist": (self.config.min_dist, self.config.max_dist),
"grid_size": (self.config.min_grid_size, self.config.max_grid_size),
},
},
}

View file

@ -197,7 +197,7 @@ class MiniSudokuDataset(ProceduralDataset):
"solution": solved_board,
"num_empty": num_empty,
"difficulty": {
"empty": num_empty,
"empty": (self.config.min_empty, self.config.max_empty),
},
},
}

View file

@ -137,7 +137,7 @@ class NQueensDataset(ProceduralDataset):
"valid_answers": valid_solutions_str,
"difficulty": {
"n": self.config.n,
"num_removed": num_removed,
"num_removed": (self.config.min_remove, self.config.max_remove),
},
},
}

View file

@ -161,7 +161,9 @@ class RushHourDataset(ProceduralDataset):
"metadata": {
"board_config": board_config,
"min_moves": min_moves,
"difficulty": {"min_moves": min_moves},
"difficulty": {
"min_moves": (self.config.min_moves, self.config.max_moves),
},
},
}

View file

@ -65,7 +65,7 @@ class SokobanDataset(ProceduralDataset):
# Make the Sokoban!
rng = Random(self.seed + idx)
gamestr, solution, difficulty = self._generate(
gamestr, solution, puzzle_data = self._generate(
rng=rng,
min_w=self.config.min_w,
min_h=self.config.min_h,
@ -93,7 +93,15 @@ Here is your puzzle:
"""
+ gamestr,
"answer": solution,
"metadata": {"gamestr": gamestr, "difficulty": difficulty},
"metadata": {
"gamestr": gamestr,
"width": puzzle_data["width"],
"height": puzzle_data["height"],
"difficulty": {
"width": (self.config.min_w, self.config.max_w),
"height": (self.config.min_h, self.config.max_h),
},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -216,7 +216,7 @@ class SudokuDataset(ProceduralDataset):
"solution": solved_board,
"num_empty": num_empty,
"difficulty": {
"num_empty": num_empty,
"empty": (self.config.min_empty, self.config.max_empty),
},
},
}

View file

@ -275,6 +275,9 @@ class HanoiDataset(ProceduralDataset):
"target_peg": target_peg,
"auxiliary_pegs": auxiliary_pegs,
"solution_length": len(solution),
"difficulty": {
"num_disks": (self.min_disks, self.max_disks),
},
},
}

View file

@ -270,7 +270,13 @@ class TsumegoDataset(ProceduralDataset):
"Specify your move in coordinates (e.g. 'C4' for column C, row 4)"
),
"answer": solution_str,
"metadata": {"difficulty": {"board_size": size}, "board": board},
"metadata": {
"board": board,
"board_size": size,
"difficulty": {
"board_size": (self.config.min_board_size, self.config.max_board_size),
},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -115,7 +115,9 @@ class SimpleGeometryDataset(ProceduralDataset):
"missing_angle_raw": missing_angle,
"missing_angle_rounded": missing_angle_rounded,
"total_interior_sum": total_sum,
"difficulty": {"sides": n_sides},
"difficulty": {
"sides": (self.config.min_sides, self.config.max_sides),
},
},
}

View file

@ -136,7 +136,11 @@ class CourseScheduleDataset(ProceduralDataset):
"prerequisites": prerequisites,
"solution": answer,
"solvable": solvable,
"difficulty": {"num_courses": num_courses},
"difficulty": {
"num_courses": (self.config.min_num_courses, self.config.max_num_courses),
"num_prerequisites": (self.config.min_num_prerequisites, self.config.max_num_prerequisites),
"cycle_length": (self.config.min_cycle_length, self.config.max_cycle_length),
},
},
}

View file

@ -206,7 +206,7 @@ class FamilyRelationshipsDataset(ProceduralDataset):
"relationship": relationship.value,
"family_size": len(family),
"difficulty": {
"family_size": len(family),
"family_size": (self.config.min_family_size, self.config.max_family_size),
},
},
}

View file

@ -142,9 +142,10 @@ class LargestIslandDataset(ProceduralDataset):
"grid": grid,
"solution": answer,
"difficulty": {
"rows": rows,
"cols": cols,
"num_islands": num_islands,
"rows": (self.config.min_rows, self.config.max_rows),
"cols": (self.config.min_cols, self.config.max_cols),
"num_islands": (self.config.min_num_islands, self.config.max_num_islands),
"island_size": (self.config.min_island_size, self.config.max_island_size),
},
},
}

View file

@ -56,12 +56,12 @@ Buttons:
"question": self.format_puzzle(rng.choice(self._prompt_templates), puzzle=puzzle_data),
"answer": "".join(puzzle_data["solution"]),
"metadata": {
"metadata": {"difficulty": difficulty},
"solution_path": puzzle_data["solution"],
"target_value": puzzle_data["target_value"],
"buttons": puzzle_data["buttons"],
"initial_state": puzzle_data["initial_state"],
"initial_value": puzzle_data["initial_value"],
"difficulty": {"difficulty": difficulty},
},
}

View file

@ -162,8 +162,8 @@ class ShortestPathDataset(ProceduralDataset):
"matrix": matrix,
"solution": answer,
"difficulty": {
"rows": rows,
"cols": cols,
"rows": (self.config.min_rows, self.config.max_rows),
"cols": (self.config.min_cols, self.config.max_cols),
},
},
}

View file

@ -387,7 +387,7 @@ class CircuitLogicDataset(ProceduralDataset):
"final_gate": final_gate_name,
"inputs": inputs_list,
"difficulty": {
"terms": num_terms,
"terms": (self.config.min_terms, self.config.max_terms),
"inputs": (self.config.min_inputs, self.config.max_inputs),
},
},

View file

@ -221,8 +221,8 @@ class PropositionalLogicDataset(ProceduralDataset):
"complexity": self._measure_complexity(conclusion),
"example_answer": str(conclusion),
"difficulty": {
"vars": num_vars,
"statements": num_statements,
"vars": (self.config.min_vars, self.config.max_vars),
"statements": (self.config.min_statements, self.config.max_statements),
"complexity": (self.config.min_complexity, self.config.max_complexity),
},
},

View file

@ -346,7 +346,9 @@ class SelfReferenceDataset(ProceduralDataset):
return {
"question": puzz_s,
"answer": answer,
"metadata": {"difficulty": difficulty},
"metadata": {
"difficulty": {"difficulty": difficulty},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:

View file

@ -43,8 +43,8 @@ def test_boxnet_items():
assert "initial_state" in item["metadata"]
# Verify row_num and column_num are within limits
row_num = item["metadata"]["difficulty"]["row_num"]
column_num = item["metadata"]["difficulty"]["column_num"]
row_num = item["metadata"]["row_num"]
column_num = item["metadata"]["column_num"]
assert 1 <= row_num <= 2, f"row_num {row_num} outside valid range"
assert 1 <= column_num <= 2, f"column_num {column_num} outside valid range"
@ -78,8 +78,8 @@ def test_boxnet_grid_sizes():
for i in range(len(dataset)):
item = dataset[i]
row_num = item["metadata"]["difficulty"]["row_num"]
column_num = item["metadata"]["difficulty"]["column_num"]
row_num = item["metadata"]["row_num"]
column_num = item["metadata"]["column_num"]
rows_set.add(row_num)
columns_set.add(column_num)

View file

@ -53,11 +53,15 @@ def test_coach_with_chain_sum():
# Each key should be a tuple of tuples containing difficulty parameters
for key in aggregated.scores:
assert isinstance(key, tuple)
# Each inner tuple should be (param_name, value)
# Each inner tuple should be (param_name, value) or (param_name, (min_value, max_value))
for param in key:
assert isinstance(param, tuple)
assert param[0] in ("num_terms", "num_digits")
assert isinstance(param[1], int)
assert (
isinstance(param[1], int)
or (isinstance(param[1], tuple) and len(param[1]) == 2)
and all(isinstance(v, int) for v in param[1])
)
# Test aggregation with last_n
last_3 = coach.score_board.aggregate(last_n=3)
@ -171,7 +175,7 @@ def test_coach_with_composite():
item = coach[i + 5] # Use different indices
if "chain_sum" in item["metadata"]["source_dataset"]:
metadata = item["metadata"]
assert metadata["difficulty"]["num_terms"] >= 4
assert metadata["num_terms"] >= 4
def test_grouped_scores_str():

View file

@ -38,12 +38,12 @@ def test_rearc_items():
assert "input" in meta
assert "output" in meta
assert "task_id" in meta
assert "rng" in meta["difficulty"]
assert "pso" in meta["difficulty"]
assert "rng" in meta
assert "pso" in meta
# Validate difficulty bounds
assert config.diff_lb <= meta["difficulty"]["rng"] <= config.diff_ub
assert config.diff_lb <= meta["difficulty"]["pso"] <= config.diff_ub
assert config.diff_lb <= meta["rng"] <= config.diff_ub
assert config.diff_lb <= meta["pso"] <= config.diff_ub
def test_rearc_solution_validation():

View file

@ -124,7 +124,7 @@ def test_score_answer():
# test optimal score for answers, patching each entry
for x in dataset:
assert len(x["metadata"]["board"]) == x["metadata"]["difficulty"]["board_size"]
assert len(x["metadata"]["board"]) == x["metadata"]["board_size"]
assert dataset.score_answer(x["answer"], entry=x) == 1.0