diff --git a/GALLERY.md b/GALLERY.md index b8e13c55..7164e6c4 100644 --- a/GALLERY.md +++ b/GALLERY.md @@ -45,9 +45,11 @@ This gallery shows examples from all available datasets using their default conf - [mini_sudoku](#mini_sudoku) - [n_queens](#n_queens) - [number_filtering](#number_filtering) +- [number_format](#number_format) - [number_sequence](#number_sequence) - [number_sorting](#number_sorting) - [palindrome](#palindrome) +- [palindrome_partitioning](#palindrome_partitioning) - [polynomial_equations](#polynomial_equations) - [polynomial_multiplication](#polynomial_multiplication) - [pool_matrix](#pool_matrix) @@ -808,17 +810,71 @@ size = 500 Example tasks: ```` Example 1: -Question: Convert the base-3 number 220020 to binary +Question: Your task is to convert a number between two different bases. + +If the target base is > 10, use lowercase letters a-z for digits above 9. + +Example: +- Input: Convert the base-9 number 440 to base-5 +- Output: 2420 +- Explanation + - First, we convert the base-9 number 440 to base-10: 4 * 9**2 + 4 * 9**1 + 0 * 9**0 = 324 + 36 + 0 = 360 + - Next, we convert the base-10 number 360 to base-5: + - 360 // 5 = 72 remainder 0 + - 72 // 5 = 14 remainder 2 + - 14 // 5 = 2 remainder 4 + - 2 // 5 = 0 remainder 2 + - Reading the remainders in reverse order gives us the base-5 number 2 4 2 0 + - Hence, the final answer is 2420 + +Now, convert the base-3 number 220020 to binary + Answer: 1010001110 Metadata: {'decimal_value': 654, 'source_base': 3, 'target_base': 2, 'source_repr': '220020', 'target_repr': '1010001110'} Example 2: -Question: Convert the base-6 number 103 to base-13 (use lowercase letters a-z for digits above 9) +Question: Your task is to convert a number between two different bases. + +If the target base is > 10, use lowercase letters a-z for digits above 9. 
+ +Example: +- Input: Convert the base-9 number 440 to base-5 +- Output: 2420 +- Explanation + - First, we convert the base-9 number 440 to base-10: 4 * 9**2 + 4 * 9**1 + 0 * 9**0 = 324 + 36 + 0 = 360 + - Next, we convert the base-10 number 360 to base-5: + - 360 // 5 = 72 remainder 0 + - 72 // 5 = 14 remainder 2 + - 14 // 5 = 2 remainder 4 + - 2 // 5 = 0 remainder 2 + - Reading the remainders in reverse order gives us the base-5 number 2 4 2 0 + - Hence, the final answer is 2420 + +Now, convert the base-6 number 103 to base-13 + Answer: 30 Metadata: {'decimal_value': 39, 'source_base': 6, 'target_base': 13, 'source_repr': '103', 'target_repr': '30'} Example 3: -Question: Convert the base-10 number 418 to base-13 (use lowercase letters a-z for digits above 9) +Question: Your task is to convert a number between two different bases. + +If the target base is > 10, use lowercase letters a-z for digits above 9. + +Example: +- Input: Convert the base-9 number 440 to base-5 +- Output: 2420 +- Explanation + - First, we convert the base-9 number 440 to base-10: 4 * 9**2 + 4 * 9**1 + 0 * 9**0 = 324 + 36 + 0 = 360 + - Next, we convert the base-10 number 360 to base-5: + - 360 // 5 = 72 remainder 0 + - 72 // 5 = 14 remainder 2 + - 14 // 5 = 2 remainder 4 + - 2 // 5 = 0 remainder 2 + - Reading the remainders in reverse order gives us the base-5 number 2 4 2 0 + - Hence, the final answer is 2420 + +Now, convert the base-10 number 418 to base-13 + Answer: 262 Metadata: {'decimal_value': 418, 'source_base': 10, 'target_base': 13, 'source_repr': '418', 'target_repr': '262'} @@ -845,17 +901,17 @@ whitespace = single Example tasks: ```` Example 1: -Question: Put your final answer after '=' without additional text. Calculate -5 * -6 = +Question: Calculate -5 * -6. Ensure to report the answer as an integer. Do not add commas to the integer answers reported. 
Answer: 30 Metadata: {'num_terms': 2, 'num_digits': 1, 'expression': '-5 * -6'} Example 2: -Question: Put your final answer after '=' without additional text. Calculate 965 / 5 = +Question: Calculate 965 / 5. Ensure to report the answer as an integer. Do not add commas to the integer answers reported. Answer: 193 Metadata: {'num_terms': 2, 'num_digits': 3, 'expression': '965 / 5'} Example 3: -Question: Put your final answer after '=' without additional text. Calculate 0 + -2 + -4 * 0 * 3 = +Question: Calculate 0 + -2 + -4 * 0 * 3. Ensure to report the answer as an integer. Do not add commas to the integer answers reported. Answer: -2 Metadata: {'num_terms': 5, 'num_digits': 1, 'expression': '0 + -2 + -4 * 0 * 3'} @@ -874,23 +930,29 @@ difficulty = 1 Example tasks: ```` Example 1: -Question: This is a BF (Brainf*ck) computer program. What is the output? +Question: This is a BF (Brainf*ck) computer program. What is the output? >[-]>[-]<>++++++++++[<+++++++++++>-]<+.-.+++++.--------------.+++++++++++++++.< + +Respond only with the exact output of the program. Answer: onset Metadata: {'bfit_code': '\nint main() {\n print("onset");\n}\n', 'bf_program': '>[-]>[-]<>++++++++++[<+++++++++++>-]<+.-.+++++.--------------.+++++++++++++++.<'} Example 2: -Question: This is a BF (Brainf*ck) computer program. What is the output? +Question: Consider the following BF (Brainf*ck) code. What would it output? >[-]>[-]<>++++++++[<++++++++++++++>-]<.-----------.+++++++++++++.---------------.+++++.< + +Provide only the exact output of the code. Answer: perch Metadata: {'bfit_code': '\nint main() {\n print("perch");\n}\n', 'bf_program': '>[-]>[-]<>++++++++[<++++++++++++++>-]<.-----------.+++++++++++++.---------------.+++++.<'} Example 3: -Question: This is a BF (Brainf*ck) computer program. What is the output? +Question: This is a BF (Brainf*ck) computer program. What is the output? 
>[-]>[-]<>+++++++++[<+++++++++++++>-]<.-------.----------.+.+++++++++++++.< + +Respond only with the exact output of the program. Answer: under Metadata: {'bfit_code': '\nint main() {\n print("under");\n}\n', 'bf_program': '>[-]>[-]<>+++++++++[<+++++++++++++>-]<.-------.----------.+.+++++++++++++.<'} @@ -1731,21 +1793,21 @@ Example tasks: Example 1: Question: John is married to Isabella. They have a child called Edward. Edward is married to Victoria. -What is Isabella to Edward? +What is Isabella to Edward? Respond only with the word that describes their relationship. Answer: mother Metadata: {'person1': 'Isabella', 'person2': 'Edward', 'relationship': 'mother', 'family_size': 4} Example 2: Question: Henry is married to Karen. They have a child called Sebastian. Sebastian is married to Eleanor. -What relation is Henry to Karen? +What relation is Henry to Karen? Answer with a single word. Answer: husband Metadata: {'person1': 'Henry', 'person2': 'Karen', 'relationship': 'husband', 'family_size': 4} Example 3: Question: Liam is married to Nova. They have a child called Noah. Noah is married to Charlotte. They have a child called Patricia. Joseph is married to Lisa. They have a child called Charlotte. -What is Liam to Noah? +What is Liam to Noah? Respond only with the word that describes their relationship. Answer: father Metadata: {'person1': 'Liam', 'person2': 'Noah', 'relationship': 'father', 'family_size': 7} @@ -1948,7 +2010,7 @@ size = 500 Example tasks: ```` Example 1: -Question: What will this Game of Life board look like after 1 steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]]) +Question: What will this Game of Life board look like after 1 steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. Let your answer(array of array be on a single line). 
(An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]]) [[0,1,0,1,1,0,0,0,1,0], [1,0,0,1,0,1,1,1,1,1], @@ -1964,7 +2026,7 @@ Answer: [[0,1,0,0,0,0,0,0,0,0],[1,1,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0],[0,0, Metadata: {'grid_size_x': 10, 'grid_size_y': 10, 'filled_cells': 100, 'simulation_steps': 1} Example 2: -Question: What will this Game of Life board look like after 1 steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]]) +Question: What will this Game of Life board look like after 1 steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. Let your answer(array of array be on a single line). (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]]) [[1,1,1,1,1,1,0,1,1,1], [0,0,1,1,1,1,1,1,1,1], @@ -1980,7 +2042,7 @@ Answer: [[0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0],[0,1,0,0,0,0,0,0,0,0],[0,1, Metadata: {'grid_size_x': 10, 'grid_size_y': 10, 'filled_cells': 100, 'simulation_steps': 1} Example 3: -Question: What will this Game of Life board look like after 1 steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]]) +Question: What will this Game of Life board look like after 1 steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. Let your answer(array of array be on a single line). 
(An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]]) [[0,1,0,1,1,1,1,0,0,1], [0,1,0,0,1,1,1,0,1,1], @@ -2050,7 +2112,7 @@ Vertices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] Edges: [(0, 2), (0, 3), (0, 4), (0, 8), (1, 2), (1, 3), (1, 5), (1, 6), (1, 9), (2, 5), (2, 8), (2, 9), (3, 5), (3, 6), (3, 7), (4, 9), (6, 9), (7, 8), (7, 9), (8, 9)] Possible colors: [1, 2, 3, 4] -Return your solution as a JSON map of verteces to colors. (For example: {0: 1, 1: 2, 2: 3}) +Return your solution as a JSON map of vertices to colors. (For example: {0: 1, 1: 2, 2: 3}) Answer: None Metadata: {'possible_answer': {0: 1, 1: 1, 2: 2, 3: 2, 4: 2, 5: 3, 6: 3, 7: 1, 8: 3, 9: 4}, 'puzzle': {'vertices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 'edges': [(0, 2), (0, 3), (0, 4), (0, 8), (1, 2), (1, 3), (1, 5), (1, 6), (1, 9), (2, 5), (2, 8), (2, 9), (3, 5), (3, 6), (3, 7), (4, 9), (6, 9), (7, 8), (7, 9), (8, 9)], 'num_colors': 4, 'color_options': [1, 2, 3, 4]}} @@ -2062,7 +2124,7 @@ Vertices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] Edges: [(0, 1), (0, 3), (0, 9), (1, 3), (1, 8), (2, 4), (2, 5), (3, 6), (3, 7), (3, 8), (4, 6), (4, 9), (6, 7), (7, 9)] Possible colors: [1, 2, 3, 4] -Return your solution as a JSON map of verteces to colors. (For example: {0: 1, 1: 2, 2: 3}) +Return your solution as a JSON map of vertices to colors. 
(For example: {0: 1, 1: 2, 2: 3}) Answer: None Metadata: {'possible_answer': {0: 1, 1: 2, 2: 1, 3: 3, 4: 2, 5: 2, 6: 1, 7: 2, 8: 1, 9: 3}, 'puzzle': {'vertices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 'edges': [(0, 1), (0, 3), (0, 9), (1, 3), (1, 8), (2, 4), (2, 5), (3, 6), (3, 7), (3, 8), (4, 6), (4, 9), (6, 7), (7, 9)], 'num_colors': 4, 'color_options': [1, 2, 3, 4]}} @@ -2074,7 +2136,7 @@ Vertices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] Edges: [(0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (1, 5), (1, 8), (1, 9), (2, 5), (2, 6), (2, 7), (2, 9), (3, 6), (3, 7), (4, 5), (4, 6), (4, 7), (4, 8), (5, 8), (6, 9)] Possible colors: [1, 2, 3, 4] -Return your solution as a JSON map of verteces to colors. (For example: {0: 1, 1: 2, 2: 3}) +Return your solution as a JSON map of vertices to colors. (For example: {0: 1, 1: 2, 2: 3}) Answer: None Metadata: {'possible_answer': {0: 1, 1: 1, 2: 1, 3: 1, 4: 2, 5: 3, 6: 3, 7: 3, 8: 4, 9: 2}, 'puzzle': {'vertices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 'edges': [(0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (1, 5), (1, 8), (1, 9), (2, 5), (2, 6), (2, 7), (2, 9), (3, 6), (3, 7), (4, 5), (4, 6), (4, 7), (4, 8), (5, 8), (6, 9)], 'num_colors': 4, 'color_options': [1, 2, 3, 4]}} @@ -2699,17 +2761,83 @@ size = 500 Example tasks: ```` Example 1: -Question: Unscramble these words: ew hsall eb ebla ot puodrce +Question: Your task is to unscramble words in a sentence. + +For each word in a sentence, the letter may have been randomly shuffled. Your task is to unscramble the words. + +The order of the words in the sentence is preserved. Moreover, the style of the sentence is preserved (i.e. punctuation, capitalization, new lines, etc.). + +Example: +- Input: Unscramble these words: raendgmeins yWh nya hilcd anc od hatt +- Output: meanderings Why any child can do that +- Explanation + - We unscramble each of the words independently. 
+ - raendgmeins -> meanderings + - yWh -> Why + - nya -> any + - hilcd -> child + - anc -> can + - od -> do + - hatt -> that + - The final answer is: meanderings Why any child can do that + - Notice that the order of the words is preserved, no new words / symbols (e.g. new lines) are added. + +Now, unscramble these words: ew hsall eb ebla ot puodrce + Answer: we shall be able to produce Metadata: {'num_words': 6, 'corruption_level': 0.12000860417813355, 'scrambled_words': ['ew', 'hsall', 'eb', 'ebla', 'ot', 'puodrce'], 'original_words': ['we', 'shall', 'be', 'able', 'to', 'produce']} Example 2: -Question: Unscramble these words: ni oiurnalmsj Well Cahs +Question: Your task is to unscramble words in a sentence. + +For each word in a sentence, the letter may have been randomly shuffled. Your task is to unscramble the words. + +The order of the words in the sentence is preserved. Moreover, the style of the sentence is preserved (i.e. punctuation, capitalization, new lines, etc.). + +Example: +- Input: Unscramble these words: raendgmeins yWh nya hilcd anc od hatt +- Output: meanderings Why any child can do that +- Explanation + - We unscramble each of the words independently. + - raendgmeins -> meanderings + - yWh -> Why + - nya -> any + - hilcd -> child + - anc -> can + - od -> do + - hatt -> that + - The final answer is: meanderings Why any child can do that + - Notice that the order of the words is preserved, no new words / symbols (e.g. new lines) are added. + +Now, unscramble these words: ni oiurnalmsj Well Cahs + Answer: in journalism Well Cash Metadata: {'num_words': 4, 'corruption_level': 0.3288673442377109, 'scrambled_words': ['ni', 'oiurnalmsj', 'Well', 'Cahs'], 'original_words': ['in', 'journalism', 'Well', 'Cash']} Example 3: -Question: Unscramble these words: dear rchAdbali keep no nSice yrstyedae atnhks ot oyu rheet si a gain fo sucrbbisesr rM +Question: Your task is to unscramble words in a sentence. 
+ +For each word in a sentence, the letter may have been randomly shuffled. Your task is to unscramble the words. + +The order of the words in the sentence is preserved. Moreover, the style of the sentence is preserved (i.e. punctuation, capitalization, new lines, etc.). + +Example: +- Input: Unscramble these words: raendgmeins yWh nya hilcd anc od hatt +- Output: meanderings Why any child can do that +- Explanation + - We unscramble each of the words independently. + - raendgmeins -> meanderings + - yWh -> Why + - nya -> any + - hilcd -> child + - anc -> can + - od -> do + - hatt -> that + - The final answer is: meanderings Why any child can do that + - Notice that the order of the words is preserved, no new words / symbols (e.g. new lines) are added. + +Now, unscramble these words: dear rchAdbali keep no nSice yrstyedae atnhks ot oyu rheet si a gain fo sucrbbisesr rM + Answer: dear Archibald keep on Since yesterday thanks to you there is a gain of subscribers Mr Metadata: {'num_words': 16, 'corruption_level': 0.516016391169858, 'scrambled_words': ['dear', 'rchAdbali', 'keep', 'no', 'nSice', 'yrstyedae', 'atnhks', 'ot', 'oyu', 'rheet', 'si', 'a', 'gain', 'fo', 'sucrbbisesr', 'rM'], 'original_words': ['dear', 'Archibald', 'keep', 'on', 'Since', 'yesterday', 'thanks', 'to', 'you', 'there', 'is', 'a', 'gain', 'of', 'subscribers', 'Mr']} @@ -2881,23 +3009,35 @@ size = 500 Example tasks: ```` Example 1: -Question: Solve this 4x4 Mini Sudoku puzzle: -_ _ _ _ -_ _ _ _ +Question: In 4x4 Mini Sudoku: +- Each row must contain each number from 1-4 exactly once +- Each column must contain each number 1-4 exactly once +- Each 2x2 subgrid must contain each number 1-4 exactly once +Solve this 4x4 Mini Sudoku puzzle: +4 _ _ _ +_ 3 _ _ _ 1 3 _ -_ 4 _ 1 +_ _ _ _ +Format your response as the puzzle above, with spaces separating each number within a row, and newlines separating rows. 
+ Answer: 4 2 1 3 1 3 4 2 2 1 3 4 3 4 2 1 -Metadata: {'puzzle': [[0, 0, 0, 0], [0, 0, 0, 0], [0, 1, 3, 0], [0, 4, 0, 1]], 'solution': [[4, 2, 1, 3], [1, 3, 4, 2], [2, 1, 3, 4], [3, 4, 2, 1]], 'num_empty': 12} +Metadata: {'puzzle': [[4, 0, 0, 0], [0, 3, 0, 0], [0, 1, 3, 0], [0, 0, 0, 0]], 'solution': [[4, 2, 1, 3], [1, 3, 4, 2], [2, 1, 3, 4], [3, 4, 2, 1]], 'num_empty': 12} Example 2: -Question: Solve this 4x4 Mini Sudoku puzzle: +Question: In 4x4 Mini Sudoku: +- Each row must contain each number from 1-4 exactly once +- Each column must contain each number 1-4 exactly once +- Each 2x2 subgrid must contain each number 1-4 exactly once +Solve this 4x4 Mini Sudoku puzzle: 3 _ _ _ _ _ 4 _ 4 2 _ _ _ _ _ 4 +Format your response as the puzzle above, with spaces separating each number within a row, and newlines separating rows. + Answer: 3 4 1 2 2 1 4 3 4 2 3 1 @@ -2905,16 +3045,22 @@ Answer: 3 4 1 2 Metadata: {'puzzle': [[3, 0, 0, 0], [0, 0, 4, 0], [4, 2, 0, 0], [0, 0, 0, 4]], 'solution': [[3, 4, 1, 2], [2, 1, 4, 3], [4, 2, 3, 1], [1, 3, 2, 4]], 'num_empty': 11} Example 3: -Question: Solve this 4x4 Mini Sudoku puzzle: +Question: In 4x4 Mini Sudoku: +- Each row must contain each number from 1-4 exactly once +- Each column must contain each number 1-4 exactly once +- Each 2x2 subgrid must contain each number 1-4 exactly once +Solve this 4x4 Mini Sudoku puzzle: _ _ _ _ 1 3 4 _ -3 1 2 4 -4 _ _ _ +3 _ 2 4 +4 _ _ 1 +Format your response as the puzzle above, with spaces separating each number within a row, and newlines separating rows. 
+ Answer: 2 4 1 3 1 3 4 2 3 1 2 4 4 2 3 1 -Metadata: {'puzzle': [[0, 0, 0, 0], [1, 3, 4, 0], [3, 1, 2, 4], [4, 0, 0, 0]], 'solution': [[2, 4, 1, 3], [1, 3, 4, 2], [3, 1, 2, 4], [4, 2, 3, 1]], 'num_empty': 8} +Metadata: {'puzzle': [[0, 0, 0, 0], [1, 3, 4, 0], [3, 0, 2, 4], [4, 0, 0, 1]], 'solution': [[2, 4, 1, 3], [1, 3, 4, 2], [3, 1, 2, 4], [4, 2, 3, 1]], 'num_empty': 8} ```` @@ -3097,6 +3243,68 @@ Metadata: {'original_numbers': ['4', '-64.7', '-42.1', '-77', '-79.9640', '37.76 ```` +### number_format +Generates Number Format exercises with configurable difficulty + +Default configuration: +```python +max_num_candidates = 5 +min_n = 1000 +max_n = 1000000000 +max_delta = 1000 +size = 500 +seed = 42 +``` + +Example tasks: +```` +Example 1: +Question: Your task is to pick the largest/smallest number out of several options. + +Example +- Input: Pick the largest number of the following candidates: 857575.23 8.975554e+05 887,555.62 +- Output: 8.975554e+05 +- Explanation: + - Sorting the numbers written in various notations we get: 857575.23 < 887,555.62 < 8.975554e+05 + - Therefore, the largest number is 8.975554e+05 + +Now, pick the largest number of the following candidates: 25011730.212000 25011280.271000 + +Answer: 25011730.212 +Metadata: {'candidates': [25011730.212, 25011280.271], 'solution': 25011730.212, 'formatted_candidates': ['25011730.212000', '25011280.271000'], 'size': 'largest'} + +Example 2: +Question: Your task is to pick the largest/smallest number out of several options. 
+ +Example +- Input: Pick the largest number of the following candidates: 857575.23 8.975554e+05 887,555.62 +- Output: 8.975554e+05 +- Explanation: + - Sorting the numbers written in various notations we get: 857575.23 < 887,555.62 < 8.975554e+05 + - Therefore, the largest number is 8.975554e+05 + +Now, pick the largest number of the following candidates: 286,084,894.213 286,085,419.581 + +Answer: 286085419.581 +Metadata: {'candidates': [286084894.213, 286085419.581], 'solution': 286085419.581, 'formatted_candidates': ['286,084,894.213', '286,085,419.581'], 'size': 'largest'} + +Example 3: +Question: Your task is to pick the largest/smallest number out of several options. + +Example +- Input: Pick the largest number of the following candidates: 857575.23 8.975554e+05 887,555.62 +- Output: 8.975554e+05 +- Explanation: + - Sorting the numbers written in various notations we get: 857575.23 < 887,555.62 < 8.975554e+05 + - Therefore, the largest number is 8.975554e+05 + +Now, pick the largest number of the following candidates: 520020968.942000 520021372.170000 5.200202022530000e+08 520020728.080000 520020548.078000 + +Answer: 520021372.16999996 +Metadata: {'candidates': [520020968.942, 520021372.16999996, 520020202.25299996, 520020728.08, 520020548.07799995], 'solution': 520021372.16999996, 'formatted_candidates': ['520020968.942000', '520021372.170000', '5.200202022530000e+08', '520020728.080000', '520020548.078000'], 'size': 'largest'} + +```` + ### number_sequence Generates number sequence completion tasks with dynamic pattern generation @@ -3251,6 +3459,82 @@ Metadata: {'letters': ['n', 'j', 'n', 'j', 'd', 'j', 's', 's', 'd'], 'generated_ ```` +### palindrome_partitioning +Generates Palindrome Partitioning exercises with configurable difficulty + +Default configuration: +```python +min_string_len = 5 +max_string_len = 15 +max_substring_palindome_len = 5 +size = 500 +seed = 42 +``` + +Example tasks: +```` +Example 1: +Question: Given a string, partition it such that 
every substring is a palindrome. + +A palindrome is a word that reads the same backward as forward. + +You may return all possible palindrome partitioning in any order. + +Example: +- Input: Partition the following string into palindromes: aab +- Output: [["a","a","b"],["aa","b"]] +- Explanation: + - One way to partition the string is "a" | "a" | "b", where each substring is a palindrome. + - Another way to partition the string is "aa" | "b", where again each substring is a palindrome. + - Therefore, the final result is a list of the two palindrome partitions. + +Partition the following string into palindromes: agegvckakcgnnrw + +Answer: [["a", "g", "e", "g", "v", "c", "k", "a", "k", "c", "g", "n", "n", "r", "w"], ["a", "g", "e", "g", "v", "c", "k", "a", "k", "c", "g", "nn", "r", "w"], ["a", "g", "e", "g", "v", "c", "kak", "c", "g", "n", "n", "r", "w"], ["a", "g", "e", "g", "v", "c", "kak", "c", "g", "nn", "r", "w"], ["a", "g", "e", "g", "v", "ckakc", "g", "n", "n", "r", "w"], ["a", "g", "e", "g", "v", "ckakc", "g", "nn", "r", "w"], ["a", "geg", "v", "c", "k", "a", "k", "c", "g", "n", "n", "r", "w"], ["a", "geg", "v", "c", "k", "a", "k", "c", "g", "nn", "r", "w"], ["a", "geg", "v", "c", "kak", "c", "g", "n", "n", "r", "w"], ["a", "geg", "v", "c", "kak", "c", "g", "nn", "r", "w"], ["a", "geg", "v", "ckakc", "g", "n", "n", "r", "w"], ["a", "geg", "v", "ckakc", "g", "nn", "r", "w"]] +Metadata: {'string': 'agegvckakcgnnrw', 'solution': [['a', 'g', 'e', 'g', 'v', 'c', 'k', 'a', 'k', 'c', 'g', 'n', 'n', 'r', 'w'], ['a', 'g', 'e', 'g', 'v', 'c', 'k', 'a', 'k', 'c', 'g', 'nn', 'r', 'w'], ['a', 'g', 'e', 'g', 'v', 'c', 'kak', 'c', 'g', 'n', 'n', 'r', 'w'], ['a', 'g', 'e', 'g', 'v', 'c', 'kak', 'c', 'g', 'nn', 'r', 'w'], ['a', 'g', 'e', 'g', 'v', 'ckakc', 'g', 'n', 'n', 'r', 'w'], ['a', 'g', 'e', 'g', 'v', 'ckakc', 'g', 'nn', 'r', 'w'], ['a', 'geg', 'v', 'c', 'k', 'a', 'k', 'c', 'g', 'n', 'n', 'r', 'w'], ['a', 'geg', 'v', 'c', 'k', 'a', 'k', 'c', 'g', 'nn', 'r', 'w'], ['a', 
'geg', 'v', 'c', 'kak', 'c', 'g', 'n', 'n', 'r', 'w'], ['a', 'geg', 'v', 'c', 'kak', 'c', 'g', 'nn', 'r', 'w'], ['a', 'geg', 'v', 'ckakc', 'g', 'n', 'n', 'r', 'w'], ['a', 'geg', 'v', 'ckakc', 'g', 'nn', 'r', 'w']]} + +Example 2: +Question: Given a string, partition it such that every substring is a palindrome. + +A palindrome is a word that reads the same backward as forward. + +You may return all possible palindrome partitioning in any order. + +Example: +- Input: Partition the following string into palindromes: aab +- Output: [["a","a","b"],["aa","b"]] +- Explanation: + - One way to partition the string is "a" | "a" | "b", where each substring is a palindrome. + - Another way to partition the string is "aa" | "b", where again each substring is a palindrome. + - Therefore, the final result is a list of the two palindrome partitions. + +Partition the following string into palindromes: sesjj + +Answer: [["s", "e", "s", "j", "j"], ["s", "e", "s", "jj"], ["ses", "j", "j"], ["ses", "jj"]] +Metadata: {'string': 'sesjj', 'solution': [['s', 'e', 's', 'j', 'j'], ['s', 'e', 's', 'jj'], ['ses', 'j', 'j'], ['ses', 'jj']]} + +Example 3: +Question: Given a string, partition it such that every substring is a palindrome. + +A palindrome is a word that reads the same backward as forward. + +You may return all possible palindrome partitioning in any order. + +Example: +- Input: Partition the following string into palindromes: aab +- Output: [["a","a","b"],["aa","b"]] +- Explanation: + - One way to partition the string is "a" | "a" | "b", where each substring is a palindrome. + - Another way to partition the string is "aa" | "b", where again each substring is a palindrome. + - Therefore, the final result is a list of the two palindrome partitions. 
+ +Partition the following string into palindromes: owfwofaafsd + +Answer: [["o", "w", "f", "w", "o", "f", "a", "a", "f", "s", "d"], ["o", "w", "f", "w", "o", "f", "aa", "f", "s", "d"], ["o", "w", "f", "w", "o", "faaf", "s", "d"], ["o", "wfw", "o", "f", "a", "a", "f", "s", "d"], ["o", "wfw", "o", "f", "aa", "f", "s", "d"], ["o", "wfw", "o", "faaf", "s", "d"], ["owfwo", "f", "a", "a", "f", "s", "d"], ["owfwo", "f", "aa", "f", "s", "d"], ["owfwo", "faaf", "s", "d"]] +Metadata: {'string': 'owfwofaafsd', 'solution': [['o', 'w', 'f', 'w', 'o', 'f', 'a', 'a', 'f', 's', 'd'], ['o', 'w', 'f', 'w', 'o', 'f', 'aa', 'f', 's', 'd'], ['o', 'w', 'f', 'w', 'o', 'faaf', 's', 'd'], ['o', 'wfw', 'o', 'f', 'a', 'a', 'f', 's', 'd'], ['o', 'wfw', 'o', 'f', 'aa', 'f', 's', 'd'], ['o', 'wfw', 'o', 'faaf', 's', 'd'], ['owfwo', 'f', 'a', 'a', 'f', 's', 'd'], ['owfwo', 'f', 'aa', 'f', 's', 'd'], ['owfwo', 'faaf', 's', 'd']]} + +```` + ### polynomial_equations Generates random polynomial equations of degree in [min_degree, max_degree]. - The polynomial is formed by summing random terms of the form: coeff * x^exponent. @@ -3321,11 +3605,13 @@ min_terms = 2 max_terms = 4 min_value = 1 max_value = 100 -min_degree = 1 +min_degree = 0 max_degree = 3 min_polynomials = 2 max_polynomials = 3 -single_variable = True +variables = ('x', 'y', 'z') +allow_cross_variable_product = False +allow_multivariate_polynomials = False operators = ('+', '-') seed = 42 size = 500 @@ -3334,31 +3620,31 @@ size = 500 Example tasks: ```` Example 1: -Question: Calculate the following: (65*x - 72)*(105*x - 125) +Question: Calculate the following: (-95*z**3 + 18*z)*(-12*z**2 + 78*z - 104) In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. ## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. 
-Answer: 6825*x**2 - 15685*x + 9000 -Metadata: {'polynomial_expr': '(65*x - 72)*(105*x - 125)', 'single_variable': True, 'result': '6825*x**2 - 15685*x + 9000'} +Answer: 1140*z**5 - 7410*z**4 + 9664*z**3 + 1404*z**2 - 1872*z +Metadata: {'polynomial_expr': '(-95*z**3 + 18*z)*(-12*z**2 + 78*z - 104)', 'result': '1140*z**5 - 7410*z**4 + 9664*z**3 + 1404*z**2 - 1872*z', 'variables': [z]} Example 2: -Question: Calculate the following: (-9*x**2 - 28*x)*(86*x**2 - 2*x - 13) +Question: Simplify this expression: (-49*x**3 + 77*x + 8)*(8*x**3 - 163*x**2 - 49)*(16*x**3 + 74*x + 98) In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. ## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. -Answer: -774*x**4 - 2390*x**3 + 173*x**2 + 364*x -Metadata: {'polynomial_expr': '(-9*x**2 - 28*x)*(86*x**2 - 2*x - 13)', 'single_variable': True, 'result': '-774*x**4 - 2390*x**3 + 173*x**2 + 364*x'} +Answer: -6272*x**9 + 127792*x**8 - 19152*x**7 + 391246*x**6 + 807446*x**5 - 746364*x**4 - 1091196*x**3 - 406994*x**2 - 398762*x - 38416 +Metadata: {'polynomial_expr': '(-49*x**3 + 77*x + 8)*(8*x**3 - 163*x**2 - 49)*(16*x**3 + 74*x + 98)', 'result': '-6272*x**9 + 127792*x**8 - 19152*x**7 + 391246*x**6 + 807446*x**5 - 746364*x**4 - 1091196*x**3 - 406994*x**2 - 398762*x - 38416', 'variables': [x]} Example 3: -Question: Calculate the following: (43 - 91*x)*(3*x**2 - 10*x)*(71*x**3 - 2*x - 29) +Question: Calculate the following: (29*y**2 - 49*y)*(21*y**3 + 49) In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. ## 2. 
Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. -Answer: -19383*x**6 + 73769*x**5 - 29984*x**4 + 5839*x**3 - 29271*x**2 + 12470*x -Metadata: {'polynomial_expr': '(43 - 91*x)*(3*x**2 - 10*x)*(71*x**3 - 2*x - 29)', 'single_variable': True, 'result': '-19383*x**6 + 73769*x**5 - 29984*x**4 + 5839*x**3 - 29271*x**2 + 12470*x'} +Answer: 609*y**5 - 1029*y**4 + 1421*y**2 - 2401*y +Metadata: {'polynomial_expr': '(29*y**2 - 49*y)*(21*y**3 + 49)', 'result': '609*y**5 - 1029*y**4 + 1421*y**2 - 2401*y', 'variables': [y]} ```` @@ -4070,9 +4356,40 @@ size = 500 Example tasks: ```` Example 1: -Question: How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. +Question: Your task is to count how many rectangles are present in an ASCII grid. - +Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. + +Example: +- Input: How many rectangles are in the grid below? + + #### + # # + #### + + + + + + + + + + + ######### + # █## + # █ # + ########█ # + # # + ### +- Output: 3 +- Explanation: + - The first rectangle is the 3x4 rectangle in the top right. + - The other two rectangles are overlapping in the bottom left corner. + - Therefore, the final answer is 3. + +Now, it's your turn. How many rectangles do you see in the grid below? + @@ -4152,13 +4469,46 @@ Question: How many rectangles do you see? 
Single rectangles are outlined with a + Answer: 2 +Metadata: {'puzzle': ' \n \n \n \n \n \n \n \n \n \n \n \n \n ################################################## \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n ################################################## \n \n \n \n \n ###################################### \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n ###################################### \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n', 'solution': 2} Example 2: -Question: How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. +Question: Your task is to count how many rectangles are present in an ASCII grid. - +Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. + +Example: +- Input: How many rectangles are in the grid below? + + #### + # # + #### + + + + + + + + + + + ######### + # █## + # █ # + ########█ # + # # + ### +- Output: 3 +- Explanation: + - The first rectangle is the 3x4 rectangle in the top right. + - The other two rectangles are overlapping in the bottom left corner. + - Therefore, the final answer is 3. + +Now, it's your turn. How many rectangles do you see in the grid below? + @@ -4239,12 +4589,45 @@ Question: How many rectangles do you see? Single rectangles are outlined with a + Answer: 1 +Metadata: {'puzzle': ' \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n ############ \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n ############ \n \n \n \n \n \n \n \n', 'solution': 1} Example 3: -Question: How many rectangles do you see? 
Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. +Question: Your task is to count how many rectangles are present in an ASCII grid. - +Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. + +Example: +- Input: How many rectangles are in the grid below? + + #### + # # + #### + + + + + + + + + + + ######### + # █## + # █ # + ########█ # + # # + ### +- Output: 3 +- Explanation: + - The first rectangle is the 3x4 rectangle in the top right. + - The other two rectangles are overlapping in the bottom left corner. + - Therefore, the final answer is 3. + +Now, it's your turn. How many rectangles do you see in the grid below? + @@ -4325,7 +4708,9 @@ Question: How many rectangles do you see? Single rectangles are outlined with a ####################### ########################### + Answer: 7 +Metadata: {'puzzle': ' \n \n \n \n \n \n \n \n \n ######################### \n # # \n # # \n # # \n # # \n # ############ \n # ## # \n # ## # \n # ## # \n # ## # \n # ## # \n #####█#######################██#########█# \n # # ## ## \n # # ## ## \n # # ## ## \n # # ## ## \n # # ## ## \n # # ## ## \n # # ## ## \n # # ## ## \n #####█#######################██#########█# \n # ## # \n # ## # \n # ## # \n # ## # \n # ## # \n # ## # \n # ## # \n # ########## ## # \n # # # ############ \n # # # # \n # ########## # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n # # \n ######################### \n \n \n \n \n \n \n \n \n \n \n ####################### \n # # \n # # \n # # \n # # \n # # \n # # \n # ######█### \n # # # # \n # ######█### \n # # ########################### \n # # # # \n # # # # \n ####################### ########################### \n \n', 'solution': 7} ```` @@ -5417,69 +5802,72 @@ Example tasks: ```` Example 1: Question: Solve this Sudoku puzzle: -4 _ _ _ 5 2 _ 3 _ -_ _ 3 4 6 _ _ _ _ -6 1 2 _ _ 8 4 _ _ -1 _ _ _ _ _ 7 9 5 -3 _ _ 7 1 _ _ 2 6 -7 _ _ 5 _ _ _ _ 3 -2 
_ _ _ 7 5 _ _ _ -_ 3 _ _ 4 1 _ _ _ -_ _ _ 2 8 _ _ _ 4 -Answer: 4 7 8 1 5 2 6 3 9 -5 9 3 4 6 7 2 8 1 -6 1 2 3 9 8 4 5 7 -1 2 4 8 3 6 7 9 5 -3 5 9 7 1 4 8 2 6 -7 8 6 5 2 9 1 4 3 -2 4 1 9 7 5 3 6 8 -8 3 5 6 4 1 9 7 2 -9 6 7 2 8 3 5 1 4 -Metadata: {'puzzle': [[4, 0, 0, 0, 5, 2, 0, 3, 0], [0, 0, 3, 4, 6, 0, 0, 0, 0], [6, 1, 2, 0, 0, 8, 4, 0, 0], [1, 0, 0, 0, 0, 0, 7, 9, 5], [3, 0, 0, 7, 1, 0, 0, 2, 6], [7, 0, 0, 5, 0, 0, 0, 0, 3], [2, 0, 0, 0, 7, 5, 0, 0, 0], [0, 3, 0, 0, 4, 1, 0, 0, 0], [0, 0, 0, 2, 8, 0, 0, 0, 4]], 'solution': [[4, 7, 8, 1, 5, 2, 6, 3, 9], [5, 9, 3, 4, 6, 7, 2, 8, 1], [6, 1, 2, 3, 9, 8, 4, 5, 7], [1, 2, 4, 8, 3, 6, 7, 9, 5], [3, 5, 9, 7, 1, 4, 8, 2, 6], [7, 8, 6, 5, 2, 9, 1, 4, 3], [2, 4, 1, 9, 7, 5, 3, 6, 8], [8, 3, 5, 6, 4, 1, 9, 7, 2], [9, 6, 7, 2, 8, 3, 5, 1, 4]], 'num_empty': 48} +4 _ _ _ 9 2 _ 3 _ +_ _ 3 4 6 _ _ _ 7 +6 1 2 _ _ 7 8 _ _ +2 _ _ _ _ _ 7 9 1 +8 _ _ 7 1 _ _ 5 6 +1 _ _ 5 _ _ _ _ 3 +9 _ 4 _ 7 1 _ _ _ +_ 8 _ _ _ _ _ _ _ +_ _ _ 9 8 _ _ _ 4 +Respond with only your answer, formatted as the puzzle, a 9x9 grid with numbers separated by spaces, and rows separated by newlines. 
+Answer: 4 7 8 1 9 2 6 3 5 +5 9 3 4 6 8 1 2 7 +6 1 2 3 5 7 8 4 9 +2 4 5 8 3 6 7 9 1 +8 3 9 7 1 4 2 5 6 +1 6 7 5 2 9 4 8 3 +9 5 4 2 7 1 3 6 8 +3 8 1 6 4 5 9 7 2 +7 2 6 9 8 3 5 1 4 +Metadata: {'puzzle': [[4, 0, 0, 0, 9, 2, 0, 3, 0], [0, 0, 3, 4, 6, 0, 0, 0, 7], [6, 1, 2, 0, 0, 7, 8, 0, 0], [2, 0, 0, 0, 0, 0, 7, 9, 1], [8, 0, 0, 7, 1, 0, 0, 5, 6], [1, 0, 0, 5, 0, 0, 0, 0, 3], [9, 0, 4, 0, 7, 1, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 9, 8, 0, 0, 0, 4]], 'solution': [[4, 7, 8, 1, 9, 2, 6, 3, 5], [5, 9, 3, 4, 6, 8, 1, 2, 7], [6, 1, 2, 3, 5, 7, 8, 4, 9], [2, 4, 5, 8, 3, 6, 7, 9, 1], [8, 3, 9, 7, 1, 4, 2, 5, 6], [1, 6, 7, 5, 2, 9, 4, 8, 3], [9, 5, 4, 2, 7, 1, 3, 6, 8], [3, 8, 1, 6, 4, 5, 9, 7, 2], [7, 2, 6, 9, 8, 3, 5, 1, 4]], 'num_empty': 48} Example 2: Question: Solve this Sudoku puzzle: _ _ _ 1 3 2 6 4 5 -_ 4 _ 7 _ _ _ 9 1 -_ _ 1 8 _ 9 _ _ _ -_ 8 9 _ _ _ 7 5 4 +_ 4 _ 8 5 _ _ 9 _ +_ _ 1 9 _ 7 _ _ _ +_ 8 9 6 _ _ 7 5 4 _ 3 _ 4 _ 1 9 8 _ -4 6 _ 5 9 _ 1 2 3 -5 _ 4 9 1 7 3 _ _ -9 7 6 _ 8 4 5 1 _ +4 6 _ 5 9 _ 2 3 1 +5 _ 4 7 1 9 3 _ _ +9 7 6 _ _ 4 5 1 _ 8 _ 3 _ _ _ 4 7 _ +Respond with only your answer, formatted as the puzzle, a 9x9 grid with numbers separated by spaces, and rows separated by newlines. 
Answer: 7 9 8 1 3 2 6 4 5 -3 4 2 7 5 6 8 9 1 -6 5 1 8 4 9 2 3 7 +3 4 2 8 5 6 1 9 7 +6 5 1 9 4 7 8 2 3 1 8 9 6 2 3 7 5 4 2 3 5 4 7 1 9 8 6 -4 6 7 5 9 8 1 2 3 -5 2 4 9 1 7 3 6 8 +4 6 7 5 9 8 2 3 1 +5 2 4 7 1 9 3 6 8 9 7 6 3 8 4 5 1 2 8 1 3 2 6 5 4 7 9 -Metadata: {'puzzle': [[0, 0, 0, 1, 3, 2, 6, 4, 5], [0, 4, 0, 7, 0, 0, 0, 9, 1], [0, 0, 1, 8, 0, 9, 0, 0, 0], [0, 8, 9, 0, 0, 0, 7, 5, 4], [0, 3, 0, 4, 0, 1, 9, 8, 0], [4, 6, 0, 5, 9, 0, 1, 2, 3], [5, 0, 4, 9, 1, 7, 3, 0, 0], [9, 7, 6, 0, 8, 4, 5, 1, 0], [8, 0, 3, 0, 0, 0, 4, 7, 0]], 'solution': [[7, 9, 8, 1, 3, 2, 6, 4, 5], [3, 4, 2, 7, 5, 6, 8, 9, 1], [6, 5, 1, 8, 4, 9, 2, 3, 7], [1, 8, 9, 6, 2, 3, 7, 5, 4], [2, 3, 5, 4, 7, 1, 9, 8, 6], [4, 6, 7, 5, 9, 8, 1, 2, 3], [5, 2, 4, 9, 1, 7, 3, 6, 8], [9, 7, 6, 3, 8, 4, 5, 1, 2], [8, 1, 3, 2, 6, 5, 4, 7, 9]], 'num_empty': 34} +Metadata: {'puzzle': [[0, 0, 0, 1, 3, 2, 6, 4, 5], [0, 4, 0, 8, 5, 0, 0, 9, 0], [0, 0, 1, 9, 0, 7, 0, 0, 0], [0, 8, 9, 6, 0, 0, 7, 5, 4], [0, 3, 0, 4, 0, 1, 9, 8, 0], [4, 6, 0, 5, 9, 0, 2, 3, 1], [5, 0, 4, 7, 1, 9, 3, 0, 0], [9, 7, 6, 0, 0, 4, 5, 1, 0], [8, 0, 3, 0, 0, 0, 4, 7, 0]], 'solution': [[7, 9, 8, 1, 3, 2, 6, 4, 5], [3, 4, 2, 8, 5, 6, 1, 9, 7], [6, 5, 1, 9, 4, 7, 8, 2, 3], [1, 8, 9, 6, 2, 3, 7, 5, 4], [2, 3, 5, 4, 7, 1, 9, 8, 6], [4, 6, 7, 5, 9, 8, 2, 3, 1], [5, 2, 4, 7, 1, 9, 3, 6, 8], [9, 7, 6, 3, 8, 4, 5, 1, 2], [8, 1, 3, 2, 6, 5, 4, 7, 9]], 'num_empty': 34} Example 3: Question: Solve this Sudoku puzzle: -_ _ 1 2 3 _ _ _ 9 -3 _ _ 1 8 5 6 7 2 -_ _ _ 4 9 6 1 _ _ -1 _ 5 7 _ _ 9 2 _ -_ 4 _ _ 5 9 7 1 6 -9 _ 6 _ 1 _ 4 5 3 -_ _ 3 9 7 _ 2 8 4 -_ _ 2 6 4 _ _ 9 1 -_ 1 _ 5 2 8 3 _ _ -Answer: 5 6 1 2 3 7 8 4 9 -3 9 4 1 8 5 6 7 2 -8 2 7 4 9 6 1 3 5 -1 3 5 7 6 4 9 2 8 -2 4 8 3 5 9 7 1 6 -9 7 6 8 1 2 4 5 3 -6 5 3 9 7 1 2 8 4 -7 8 2 6 4 3 5 9 1 -4 1 9 5 2 8 3 6 7 -Metadata: {'puzzle': [[0, 0, 1, 2, 3, 0, 0, 0, 9], [3, 0, 0, 1, 8, 5, 6, 7, 2], [0, 0, 0, 4, 9, 6, 1, 0, 0], [1, 0, 5, 7, 0, 0, 9, 2, 0], [0, 4, 0, 0, 5, 9, 7, 1, 6], [9, 0, 6, 0, 1, 0, 4, 5, 3], 
[0, 0, 3, 9, 7, 0, 2, 8, 4], [0, 0, 2, 6, 4, 0, 0, 9, 1], [0, 1, 0, 5, 2, 8, 3, 0, 0]], 'solution': [[5, 6, 1, 2, 3, 7, 8, 4, 9], [3, 9, 4, 1, 8, 5, 6, 7, 2], [8, 2, 7, 4, 9, 6, 1, 3, 5], [1, 3, 5, 7, 6, 4, 9, 2, 8], [2, 4, 8, 3, 5, 9, 7, 1, 6], [9, 7, 6, 8, 1, 2, 4, 5, 3], [6, 5, 3, 9, 7, 1, 2, 8, 4], [7, 8, 2, 6, 4, 3, 5, 9, 1], [4, 1, 9, 5, 2, 8, 3, 6, 7]], 'num_empty': 33} +_ _ 1 9 2 _ _ _ 3 +3 _ _ 1 7 5 8 2 6 +_ _ _ 4 3 6 1 _ _ +1 _ 5 7 _ _ 9 3 _ +_ 4 _ _ 5 9 7 1 8 +7 _ 9 _ 1 _ 6 4 5 +_ _ 3 5 9 _ 2 8 4 +_ _ 2 6 8 _ _ 9 1 +_ 5 _ 2 4 1 3 _ _ +Respond with only your answer, formatted as the puzzle, a 9x9 grid with numbers separated by spaces, and rows separated by newlines. +Answer: 5 6 1 9 2 8 4 7 3 +3 9 4 1 7 5 8 2 6 +8 2 7 4 3 6 1 5 9 +1 8 5 7 6 4 9 3 2 +2 4 6 3 5 9 7 1 8 +7 3 9 8 1 2 6 4 5 +6 1 3 5 9 7 2 8 4 +4 7 2 6 8 3 5 9 1 +9 5 8 2 4 1 3 6 7 +Metadata: {'puzzle': [[0, 0, 1, 9, 2, 0, 0, 0, 3], [3, 0, 0, 1, 7, 5, 8, 2, 6], [0, 0, 0, 4, 3, 6, 1, 0, 0], [1, 0, 5, 7, 0, 0, 9, 3, 0], [0, 4, 0, 0, 5, 9, 7, 1, 8], [7, 0, 9, 0, 1, 0, 6, 4, 5], [0, 0, 3, 5, 9, 0, 2, 8, 4], [0, 0, 2, 6, 8, 0, 0, 9, 1], [0, 5, 0, 2, 4, 1, 3, 0, 0]], 'solution': [[5, 6, 1, 9, 2, 8, 4, 7, 3], [3, 9, 4, 1, 7, 5, 8, 2, 6], [8, 2, 7, 4, 3, 6, 1, 5, 9], [1, 8, 5, 7, 6, 4, 9, 3, 2], [2, 4, 6, 3, 5, 9, 7, 1, 8], [7, 3, 9, 8, 1, 2, 6, 4, 5], [6, 1, 3, 5, 9, 7, 2, 8, 4], [4, 7, 2, 6, 8, 3, 5, 9, 1], [9, 5, 8, 2, 4, 1, 3, 6, 7]], 'num_empty': 33} ```` @@ -5561,7 +5949,7 @@ Metadata: {'task_type': 'datetime_tz', 'start_time': datetime.datetime(2964, 6, Example 2: Question: A video call started at 09:44 and ended at 12:22. How long was the call? Answer in HH:MM. 
Answer: 02:38 -Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 16, 9, 44), 'end_time': datetime.datetime(2025, 2, 16, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'} +Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 19, 9, 44), 'end_time': datetime.datetime(2025, 2, 19, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'} Example 3: Question: Calculate the time difference between Sat Dec 22 2677 and Thu Mar 21 2678. Express the result in D days. @@ -5749,22 +6137,22 @@ Example tasks: ```` Example 1: Question: Transform the word ladder 'HAND' to 'GLEE' by changing one letter at a time. - Provide your answer as a comma-separated sequence of uppercase letters without spaces. - Each step must be a valid English word. +Provide your answer as a comma-separated sequence of uppercase letters without spaces. +Each step must be a valid English word. Answer: HAND,HARD,HERD,HEED,FEED,FLED,FLEE,GLEE Metadata: {'start_word': 'HAND', 'end_word': 'GLEE', 'word_length': 4, 'chain_length': 8} Example 2: Question: Transform the word ladder 'JAZZ' to 'DORM' by changing one letter at a time. - Provide your answer as a comma-separated sequence of uppercase letters without spaces. - Each step must be a valid English word. +Provide your answer as a comma-separated sequence of uppercase letters without spaces. +Each step must be a valid English word. Answer: JAZZ,JIZZ,FIZZ,FUZZ,FUZE,FAZE,FARE,FORE,FORM,DORM Metadata: {'start_word': 'JAZZ', 'end_word': 'DORM', 'word_length': 4, 'chain_length': 10} Example 3: Question: Transform the word ladder 'SNOG' to 'SUQS' by changing one letter at a time. - Provide your answer as a comma-separated sequence of uppercase letters without spaces. - Each step must be a valid English word. +Provide your answer as a comma-separated sequence of uppercase letters without spaces. +Each step must be a valid English word. 
Answer: SNOG,SNOW,SHOW,SHEW,SHES,SUES,SUQS Metadata: {'start_word': 'SNOG', 'end_word': 'SUQS', 'word_length': 4, 'chain_length': 7} diff --git a/examples/veRL/launch_on_2gpu_server.sh b/examples/veRL/launch_on_2gpu_server.sh new file mode 100755 index 00000000..4f2efc46 --- /dev/null +++ b/examples/veRL/launch_on_2gpu_server.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +export N_GPUS=2 +export BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct +export ROLLOUT_TP_SIZE=2 +export EXPERIMENT_NAME=chain_sum_llama +export VLLM_ATTENTION_BACKEND=XFORMERS + +bash ./train_grpo_server.sh diff --git a/examples/veRL/main_ppo_custom_reward_server.py b/examples/veRL/main_ppo_custom_reward_server.py new file mode 100644 index 00000000..0f20be1d --- /dev/null +++ b/examples/veRL/main_ppo_custom_reward_server.py @@ -0,0 +1,344 @@ +# This example is an adapted version of Bytedance's code: +# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py +import os +from typing import Dict, List, Optional + +import hydra +import ray +import torch +import verl.utils.torch_functional as verl_F +from omegaconf import OmegaConf, open_dict +from torch.utils.data import DataLoader, Dataset +from transformers import PreTrainedTokenizer +from verl import DataProto +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.utils.dataset.rl_dataset import collate_fn +from verl.utils.model import compute_position_id_with_mask + +import reasoning_gym +import reasoning_gym.utils +from reasoning_gym.utils import extract_answer +from tools.server.models import AnswerItem, BatchEntry, ExperimentCreate + + +class ReasoningGymDataset(Dataset): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + dataset_name: str, + seed: int, + size: int, + developer_prompt: Optional[str] = None, + developer_role: str = "system", + max_prompt_length: int = 2048, + truncation: str = "error", ## ['left', 'right', 'error'] + return_raw_chat: bool = False, + server_url: str = 
"http://localhost:8000", + api_key: Optional[str] = None, + batch_size: int = 32, + ): + from tools.cli.rgc.client import RGClient + + self.tokenizer = tokenizer + self.dataset_name = dataset_name + self.developer_prompt = developer_prompt + self.developer_role = developer_role + self.max_prompt_length = max_prompt_length + self.truncation = truncation + self.return_raw_chat = return_raw_chat + self.size = size + self.batch_size = batch_size + + # Initialize client and create experiment if needed + self.client = RGClient(base_url=server_url, api_key=api_key) + + # Check if experiment exists, create if not + experiments = self.client.list_experiments() + if dataset_name not in experiments.experiments: + config = ExperimentCreate( + name=dataset_name, + size=size, + seed=seed, + datasets={dataset_name: {"weight": 1.0, "config": {"seed": seed, "size": size}}}, + ) + self.client.create_experiment(dataset_name, config) + + # Cache for batches + self._batch_cache: dict[int, List[BatchEntry]] = {} + + def __len__(self) -> int: + return self.size + + def _get_batch(self, batch_idx: int) -> List[BatchEntry]: + """Fetch or retrieve cached batch""" + if batch_idx not in self._batch_cache: + base_index = batch_idx * self.batch_size + response = self.client.get_batch(self.dataset_name, base_index=base_index, batch_size=self.batch_size) + self._batch_cache[batch_idx] = response.entries + + # # Basic cache management - keep only last N batches + # if len(self._batch_cache) > 10: + # oldest_batch = min(self._batch_cache.keys()) + # del self._batch_cache[oldest_batch] + + return self._batch_cache[batch_idx] + + def __getitem__(self, index): + # Get batch containing this index + batch_idx = index // self.batch_size + + batch = self._get_batch(batch_idx) + entry = batch[index % self.batch_size] + + # Format chat/prompt + chat = [] + if self.developer_prompt is not None: + chat.append({"role": self.developer_role, "content": self.developer_prompt}) + chat.append({"role": "user", 
"content": entry.question}) + + prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + + # Tokenize + input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( + prompt=prompt, + tokenizer=self.tokenizer, + max_length=self.max_prompt_length, + pad_token_id=self.tokenizer.pad_token_id, + left_pad=True, + truncation=self.truncation, + ) + + position_ids = compute_position_id_with_mask(attention_mask) + + row_dict = { + "data_source": "reasoning_gym/" + self.dataset_name, + "input_ids": input_ids[0], + "attention_mask": attention_mask[0], + "position_ids": position_ids[0], + "entry_id": entry.entry_id, + "metadata": entry.metadata, + "index": index, + } + + # Add raw chat if requested + if self.return_raw_chat: + row_dict["raw_prompt"] = chat + + return row_dict + + +class RayPPOTrainerCustom(RayPPOTrainer): + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict, + resource_pool_manager, + ray_worker_group_cls, + dataset_name: str = "chain_sum", + dataset_size: int = 10000, + ): + self.dataset_name = dataset_name + self.dataset_size = dataset_size + + developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"] + rg_api_key = os.getenv("REASONING_GYM_API_KEY", "your-secret-key") + self.train_dataset = ReasoningGymDataset( + tokenizer=tokenizer, + dataset_name=self.dataset_name, + seed=1, + size=self.dataset_size, + developer_prompt=developer_prompt, + api_key=rg_api_key, + ) + + self.val_dataset = ReasoningGymDataset( + tokenizer=tokenizer, + dataset_name=self.dataset_name, + seed=2, + size=self.dataset_size, + developer_prompt=developer_prompt, + api_key=rg_api_key, + ) + + train_reward_fn = lambda data: self._score_output(data, num_examine=0) + val_reward_fn = lambda data: self._score_output(data, num_examine=1) + + super().__init__( + config, + tokenizer, + role_worker_mapping, + resource_pool_manager, + ray_worker_group_cls, + train_reward_fn, + val_reward_fn, + ) + + def 
_score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor: + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) + + # Prepare batch of answers to score + answer_items = [] + valid_response_lengths = [] + sequences_strs = [] + + for i in range(len(data)): + data_item = data[i] + + # Get prompt and response + prompt_ids = data_item.batch["prompts"] + prompt_length = prompt_ids.shape[-1] + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + + response_ids = data_item.batch["responses"] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + valid_response_lengths.append(valid_response_length) + + # Decode full sequence + sequences = torch.cat((valid_prompt_ids, valid_response_ids)) + sequences_str = self.tokenizer.decode(sequences) + sequences_strs.append(sequences_str) + + # Extract answer and prepare scoring item + found_answer = extract_answer(sequences_str, tag_name="answer") + + index = data_item.non_tensor_batch["index"] + entry_id = self.train_dataset[index]["entry_id"] + # print( + # "found_answer", + # entry_id, + # found_answer, + # ) + + answer_items.append(AnswerItem(entry_id=entry_id, answer=found_answer)) + + # Score all answers in one request + response = self.train_dataset.client.score_outputs(self.train_dataset.dataset_name, answer_items) + # print("response", response) + + # Fill reward tensor + for i, (score, valid_response_length) in enumerate(zip(response.scores, valid_response_lengths)): + reward_tensor[i, valid_response_length - 1] = score + + if i < num_examine: + print(f"reward={score}, seq={sequences_strs[i]}") + + return reward_tensor + + def _create_dataloader(self): + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.train_batch_size, + shuffle=False, + drop_last=True, + 
collate_fn=collate_fn, + ) + + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=len(self.val_dataset), + shuffle=False, + drop_last=True, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1 + assert len(self.val_dataloader) >= 1 + + print(f"Size of train dataloader: {len(self.train_dataloader)}") + print(f"Size of val dataloader: {len(self.val_dataloader)}") + + # inject total_training_steps to actor/critic optim_config. This is hacky. + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.critic.optim.total_training_steps = total_training_steps + + +@ray.remote +def main_task(config): + # print initial config + from pprint import pprint + + from verl.utils import hf_tokenizer + from verl.utils.fs import copy_local_path_from_hdfs + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + tokenizer = hf_tokenizer(local_path) + + # define worker classes + if config.actor_rollout_ref.actor.strategy == "fsdp": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == 
config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + Role.RefPolicy: ray.remote(ActorRolloutRefWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + Role.RefPolicy: global_pool_id, + } + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayPPOTrainerCustom( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + ) + trainer.init_workers() + trainer.fit() + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) + + ray.get(main_task.remote(config)) + + +if __name__ == "__main__": + main() diff --git a/examples/veRL/train_grpo_server.sh b/examples/veRL/train_grpo_server.sh new file mode 100644 index 00000000..34b956ad --- /dev/null +++ b/examples/veRL/train_grpo_server.sh @@ -0,0 +1,39 @@ +#!/bin/bash +set -x + +python3 -u main_ppo_custom_reward_server.py \ + algorithm.adv_estimator=grpo \ + data.train_files=$DATA_DIR/train.parquet \ + data.val_files=$DATA_DIR/test.parquet \ + data.train_batch_size=32 \ + data.val_batch_size=32 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + 
actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_chain_sum_grpo' \ + trainer.experiment_name=$EXPERIMENT_NAME \ + trainer.n_gpus_per_node=$N_GPUS \ + trainer.nnodes=1 \ + trainer.save_freq=100 \ + trainer.test_freq=100 \ + trainer.total_epochs=15 $@ 2>&1 | tee verl_output.log diff --git a/pyproject.toml b/pyproject.toml index 07aa57e4..4bba76fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,22 @@ license = "Apache-2.0" license-files = ["LICENSE*"] [project.optional-dependencies] -test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"] +test = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", + "httpx>=0.27.0" +] +server = [ + "fastapi>=0.109.0", + "uvicorn>=0.27.0", + "pydantic-settings>=2.1.0", +] +cli = [ + "typer>=0.9.0", + "rich>=13.7.0", + "pyyaml>=6.0.1", + "httpx>=0.27.0", +] [project.urls] "Homepage" = "https://github.com/open-thought/reasoning-gym" @@ -40,12 +55,19 @@ test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"] [tool.hatch.build] 
-packages = ["reasoning_gym"] -include = [ - "reasoning_gym/**/*.py", - "reasoning_gym/**/*.txt", - "reasoning_gym/**/levels/*", +packages = [ + "reasoning_gym", + "tools.cli.rgc" ] +include = [ + "reasoning_gym/**/*.py", + "reasoning_gym/**/*.txt", + "reasoning_gym/**/levels/*", + "tools/cli/rgc/**/*.py" +] + +[project.scripts] +rgc = "tools.cli.rgc.main:main" [tool.black] line-length = 120 diff --git a/reasoning_gym/algorithmic/caesar_cipher.py b/reasoning_gym/algorithmic/caesar_cipher.py index 2afc6d1f..9db6beb7 100644 --- a/reasoning_gym/algorithmic/caesar_cipher.py +++ b/reasoning_gym/algorithmic/caesar_cipher.py @@ -73,7 +73,7 @@ class CaesarCipherDataset(ProceduralDataset): cipher_text = self._caesar_encrypt(sentence, rotation) return { - "question": f"Decrypt this Caesar cipher text: {cipher_text}", + "question": f"Decrypt this Caesar cipher text: {cipher_text}. Provide only the decrypted text as your final answer.", "answer": sentence, "metadata": {"rotation": rotation, "cipher_text": cipher_text, "clear_text": sentence}, } diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 495a79c5..2e8cf322 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,12 +6,14 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsDataset +from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumDataset from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingDataset 
+from .number_format import NumberFormatConfig, NumberFormatDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .products import ProductsConfig, ProductsDataset @@ -46,4 +48,6 @@ __all__ = [ "CountBitsDataset", "DiceConfig", "DiceDataset", + "NumberFormatConfig", + "NumberFormatDataset", ] diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 05c19779..c90a4abb 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -27,10 +27,6 @@ class ChainSumConfig: assert self.min_digits > 0, "min_digits must be positive" assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" - # Validate digit ranges make sense - if self.min_digits > 1: - assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range" - class ChainSumDataset(ProceduralDataset): """Generates simple arithmetic tasks using only + and - operators""" diff --git a/reasoning_gym/arithmetic/decimal_chain_sum.py b/reasoning_gym/arithmetic/decimal_chain_sum.py new file mode 100644 index 00000000..da920c9d --- /dev/null +++ b/reasoning_gym/arithmetic/decimal_chain_sum.py @@ -0,0 +1,157 @@ +import random +from dataclasses import dataclass +from decimal import Decimal +from typing import Any, Dict, Optional + +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class DecimalChainSumConfig: + """Configuration for decimal chain sum task generation""" + + min_terms: int = 2 + max_terms: int = 6 + min_digits: int = 1 + max_digits: int = 4 + min_decimal_places: int = 1 + max_decimal_places: int = 4 + allow_negation: bool = False + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.size > 0, "size 
must be positive" + assert self.min_terms > 0, "min_terms must be positive" + assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" + assert self.min_digits > 0, "min_digits must be positive" + assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" + assert self.min_decimal_places >= 0, "min_decimal_places must be non-negative" + assert self.max_decimal_places >= self.min_decimal_places, "max_decimal_places must be >= min_decimal_places" + + +class DecimalChainSumDataset(ProceduralDataset): + """Generates simple decimal arithmetic tasks using only + and - operators""" + + def __init__(self, config: DecimalChainSumConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single decimal chain sum task + + Args: + idx: Index of the item to generate + + Returns: + dict with keys: + - question: str, the formatted arithmetic expression + - answer: str, the ground truth result + - metadata: dict with generation parameters + """ + + rng = random.Random(self.seed + idx) + + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) + + # Calculate value ranges based on number of digits + min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit + max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits + + expression, result = self._generate_task(rng, num_terms, min_value, max_value) + + return { + "question": f"State the final answer to the following arithmetic problem: {expression} =", + "answer": str(result), + "metadata": { + "difficulty": { + "num_terms": num_terms, + "num_digits": num_digits, + }, + "expression": expression, + }, + } + + def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, Decimal]: + """Generate a single decimal chain sum task + + Args: + rng: Random number 
generator + num_terms: Number of terms in the expression + min_value: Minimum value for generated numbers + max_value: Maximum value for generated numbers + min_decimal_places: Minimum number of decimal places + max_decimal_places: Maximum number of decimal places + + Returns: + Tuple of (expression string, result Decimal) + """ + + # Convert constants to Decimal + constants = [ + Decimal( + str( + rng.randint(-max_value, max_value) + if self.config.allow_negation + else rng.randint(min_value, max_value) + ) + ) + for _ in range(num_terms) + ] + + # Generate decimal places for each term + decimal_places = [ + rng.randint(self.config.min_decimal_places, self.config.max_decimal_places) for _ in range(num_terms) + ] + + # Add decimal parts using Decimal for precise arithmetic + for i in range(num_terms): + min_val = 0 if decimal_places[i] == 0 else 10 ** (decimal_places[i] - 1) + max_val = (10 ** decimal_places[i]) - 1 + decimal_part = Decimal(str(rng.randint(min_val, max_val))) / Decimal(str(10 ** decimal_places[i])) + constants[i] += decimal_part + + operators = [rng.choice(["+", "-"]) for _ in range(num_terms - 1)] + + expression_parts = [] + result = constants[0] + + expression_parts.append(f"{constants[0]:.{decimal_places[0]}f}") + for i, op in enumerate(operators): + c = constants[i + 1] + expression_parts.append(op) + expression_parts.append(f"{c:.{decimal_places[i+1]}f}") + + if op == "+": + result += c + else: # op == "-" + result -= c + + expression = " ".join(expression_parts) + result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}")) + return expression, result + + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + """Score the answer by comparing decimal values instead of strings. 
+ Args: + answer: The answer to score + entry: The entry containing the oracle answer + + Returns: + 1.0 for exact numerical match, 0.01 otherwise + """ + if answer is None or len(answer.strip()) == 0: + return 0.0 + + try: + student_answer = Decimal(answer.strip()) + oracle_answer = Decimal(entry["answer"]) + + return 1.0 if student_answer == oracle_answer else 0.01 + except (ValueError, TypeError, ArithmeticError): + return 0.01 + + +register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig) diff --git a/reasoning_gym/arithmetic/number_format.py b/reasoning_gym/arithmetic/number_format.py new file mode 100644 index 00000000..e03d2bdc --- /dev/null +++ b/reasoning_gym/arithmetic/number_format.py @@ -0,0 +1,106 @@ +"""Choose largest number out of several represented in various formats.""" + +from dataclasses import dataclass +from random import Random +from typing import Dict, Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """Your task is to pick the largest/smallest number out of several options. 
+ +Example +- Input: Pick the largest number of the following candidates: 857575.23 8.975554e+05 887,555.62 +- Output: 8.975554e+05 +- Explanation: + - Sorting the numbers written in various notations we get: 857575.23 < 887,555.62 < 8.975554e+05 + - Therefore, the largest number is 8.975554e+05 + +Now, pick the {size} number of the following candidates: {numbers} +""" + + +@dataclass +class NumberFormatConfig: + """Configuration for Count Bits dataset generation""" + + max_num_candidates: int = 5 # Maximum number of candidates + min_n: float = 1_000 # Lower bound for the numbers + max_n: float = 1_000_000_000 # Upper bound for the numbers + max_delta: int = 1_000 + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 2 <= self.max_num_candidates, "max_num_candidates must be at least 2" + assert 1 <= self.min_n, "min_n must be at least 1" + assert self.min_n < self.max_n, "min_n must be less than max_n" + assert 1 <= self.max_delta, "max_delta must be at least 1" + + +class NumberFormatDataset(ProceduralDataset): + """Generates Count Bits exercises with configurable difficulty""" + + def __init__(self, config: NumberFormatConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def _get_candidates(self, rng: Random, num_candidates: int) -> list: + """Generate a list of candidates""" + base = round(rng.uniform(self.config.min_n, self.config.max_n), 3) + candidates = [base] + for _ in range(num_candidates - 1): + delta = round(rng.uniform(-self.config.max_delta, self.config.max_delta), 3) + candidates.append(base + delta) + return candidates + + def _transform_candidates(self, rng: Random, candidates: list[float]) -> list[str]: + """Randomly apply different number formats to the candidates""" + output = [] + for candidate in candidates: + format_type = rng.choice(["standard", "english", "scientific"]) + if format_type == "standard": + 
output.append(f"{candidate:f}") + elif format_type == "english": + output.append(f"{candidate:,}") + elif format_type == "scientific": + output.append(f"{candidate:.15e}") + return output + + def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: + """Overwrite this method in derived classes if a single oracle answer is not available.""" + oracle_answer = entry["metadata"]["solution"] + if answer is not None and len(answer) > 0: + try: + answer = float(answer.strip().replace(",", "")) + if abs(answer - oracle_answer) < 1e-2: + return 1.0 + return 0.01 + except: + return 0.0 + return 0.0 + + def __getitem__(self, idx: int) -> dict: + """Generate a single Count Bits question""" + rng = Random(self.seed + idx) + + num_candidates = rng.randint(2, self.config.max_num_candidates) + candidates = self._get_candidates(rng, num_candidates) + formatted_candidates = self._transform_candidates(rng, candidates) + + size = rng.choice(["largest", "smallest"]) + answer = max(candidates) if size == "largest" else min(candidates) + + return { + "question": QUESTION_TEMPLATE.format(numbers=" ".join(formatted_candidates), size=size), + "answer": str(answer), + "metadata": { + "candidates": candidates, + "solution": answer, + "formatted_candidates": formatted_candidates, + "size": size, + }, + } + + +register_dataset("number_format", NumberFormatDataset, NumberFormatConfig) diff --git a/reasoning_gym/coaching/experiment.py b/reasoning_gym/coaching/experiment.py new file mode 100644 index 00000000..d3a9e00f --- /dev/null +++ b/reasoning_gym/coaching/experiment.py @@ -0,0 +1,36 @@ +"""Experiment class combining dataset, scoreboard and curriculum.""" + +from dataclasses import dataclass +from typing import Optional + +from ..composite import CompositeConfig, CompositeDataset +from ..version_manager import DatasetVersionManager +from .coach import ScoreBoard + + +@dataclass +class Experiment: + """ + An experiment combines a dataset with scoring and curriculum 
management. + + Attributes: + name: Unique identifier for the experiment + dataset: The composite dataset for generating examples + score_board: Tracks performance metrics + config: The configuration used to create the dataset + version_manager: Manages dataset versions for scoring + """ + + name: str + dataset: CompositeDataset + score_board: ScoreBoard + config: CompositeConfig + version_manager: DatasetVersionManager + + @classmethod + def create(cls, name: str, config: CompositeConfig) -> "Experiment": + """Create a new experiment from a configuration.""" + version_manager = DatasetVersionManager() + dataset = CompositeDataset(config, version_manager=version_manager) + score_board = ScoreBoard() + return cls(name=name, dataset=dataset, score_board=score_board, config=config, version_manager=version_manager) diff --git a/reasoning_gym/coaching/registry.py b/reasoning_gym/coaching/registry.py new file mode 100644 index 00000000..5d7fdd1a --- /dev/null +++ b/reasoning_gym/coaching/registry.py @@ -0,0 +1,34 @@ +"""Registry for managing active experiments.""" + +from typing import Dict, List, Optional + +from ..composite import CompositeConfig +from .experiment import Experiment + + +class ExperimentRegistry: + """Singleton registry for managing active experiments.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._experiments = {} + return cls._instance + + def register_experiment(self, name: str, config: CompositeConfig) -> None: + """Register a new experiment with the given name and configuration.""" + self._experiments[name] = Experiment.create(name, config) + + def get_experiment(self, name: str) -> Optional[Experiment]: + """Get an experiment by name.""" + return self._experiments.get(name) + + def list_experiments(self) -> List[str]: + """List all registered experiment names.""" + return list(self._experiments.keys()) + + def remove_experiment(self, name: str) -> bool: + """Remove 
an experiment by name. Returns True if removed, False if not found.""" + return bool(self._experiments.pop(name, None)) diff --git a/reasoning_gym/cognition/color_cube_rotation.py b/reasoning_gym/cognition/color_cube_rotation.py index 42069423..4b8dc30a 100644 --- a/reasoning_gym/cognition/color_cube_rotation.py +++ b/reasoning_gym/cognition/color_cube_rotation.py @@ -185,8 +185,28 @@ class ColorCubeRotationDataset(ProceduralDataset): # Ask question story_parts.append(f"\nWhat is now the color of the {target_side.value} side of the cube?") + story_parts.append(f"Provide only the color as your final answer.") return "\n".join(story_parts) + def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: + reward = 0.0 + metadata = entry["metadata"] + if answer is not None: + try: + answer_formatted = answer.lower() + solved = answer_formatted == metadata["answer"] + if solved: + reward = 1.0 + elif metadata["answer"] in answer_formatted: + reward = 0.25 + elif len(answer.strip()) > 0: + reward = 0.05 + else: + reward = 0.01 + except: + reward = 0.01 + return reward + register_dataset("color_cube_rotation", ColorCubeRotationDataset, ColorCubeRotationConfig) diff --git a/reasoning_gym/composite.py b/reasoning_gym/composite.py index 2050ddd1..b30151fb 100644 --- a/reasoning_gym/composite.py +++ b/reasoning_gym/composite.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, replace from random import Random from typing import Any, Dict, List, Optional @@ -6,6 +6,7 @@ import yaml from .dataset import ProceduralDataset from .factory import create_dataset, register_dataset +from .version_manager import DatasetVersionManager @dataclass @@ -37,6 +38,11 @@ class CompositeConfig: assert self.datasets, "Must specify at least one dataset" assert len(self.datasets) > 0, "Must specify at least one dataset" + # Check for duplicate dataset names + dataset_names = [ds.name for ds in self.datasets] + if len(dataset_names) != 
len(set(dataset_names)): + raise ValueError("Duplicate dataset names are not allowed in CompositeDataset") + # Validate each dataset spec for ds in self.datasets: ds.validate() @@ -57,13 +63,14 @@ class CompositeConfig: class CompositeDataset(ProceduralDataset): """A dataset that combines multiple datasets with weighted sampling""" - def __init__(self, config: CompositeConfig): + def __init__(self, config: CompositeConfig, version_manager: Optional[DatasetVersionManager] = None): super().__init__(config=config, seed=config.seed, size=config.size) + self.version_manager = version_manager + self.dataset_versions = {} # dataset_name -> version_id # Initialize sub-datasets with incremented seeds self.datasets = {} self.weights = [] - total_weight = 0.0 for i, ds_spec in enumerate(config.datasets): # Create dataset with derived seed @@ -73,12 +80,18 @@ class CompositeDataset(ProceduralDataset): if "size" not in ds_config: ds_config["size"] = self.size - self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config) - total_weight += ds_spec.weight - self.weights.append(ds_spec.weight) + if ds_spec.weight < 0: + raise ValueError(f"Dataset '{ds_spec.name}' has invalid weight {ds_spec.weight}, must be non-negative") - # Normalize weights - self.weights = [w / total_weight for w in self.weights] + dataset = create_dataset(ds_spec.name, **ds_config) + self.datasets[ds_spec.name] = dataset + + # Register version if tracking enabled + if version_manager is not None: + version_id = version_manager.register_dataset(ds_spec.name, dataset) + self.dataset_versions[ds_spec.name] = version_id + + self.weights.append(ds_spec.weight) # Store unnormalized weights directly self.dataset_names = [ds.name for ds in config.datasets] def __getitem__(self, idx: int) -> dict: @@ -98,6 +111,13 @@ class CompositeDataset(ProceduralDataset): item["metadata"]["source_dataset"] = dataset_name item["metadata"]["source_index"] = idx + # Add version info if tracking enabled + if 
self.version_manager is not None: + version_id = self.dataset_versions[dataset_name] + item["metadata"]["version_id"] = version_id + # Add entry_id combining version and index + item["metadata"]["entry_id"] = f"{version_id}.{idx}" + return item def update_dataset_config(self, dataset_name: str, config_updates: Dict[str, Any]) -> None: @@ -116,23 +136,151 @@ class CompositeDataset(ProceduralDataset): dataset = self.datasets[dataset_name] - # Create new config with updates - new_config = dataset.config.__class__(**vars(dataset.config)) - for key, value in config_updates.items(): - setattr(new_config, key, value) + # Update the current config + new_config = replace(dataset.config, **config_updates) # Validate new config new_config.validate() # Create new dataset instance with updated config dataset_cls = dataset.__class__ - self.datasets[dataset_name] = dataset_cls(new_config) + new_dataset = dataset_cls(new_config) + self.datasets[dataset_name] = new_dataset + + # Register new version if tracking enabled + if self.version_manager is not None: + version_id = self.version_manager.register_dataset(dataset_name, new_dataset) + self.dataset_versions[dataset_name] = version_id + + def update_dataset_weight(self, dataset_name: str, weight: float) -> None: + """Update weight for a specific dataset in the configuration + + Args: + dataset_name: Name of the dataset to update + weight: New weight value + + Raises: + KeyError: If dataset_name not found + ValueError: If weight is negative + """ + if dataset_name not in self.datasets: + raise KeyError(f"Dataset '{dataset_name}' not found") + if weight < 0: + raise ValueError(f"Weight must be non-negative, got {weight}") + + # Update weight in both config and weights list + for i, ds_spec in enumerate(self.config.datasets): + if ds_spec.name == dataset_name: + ds_spec.weight = weight + self.weights[i] = weight + break def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: """Forward scoring to appropriate 
dataset""" dataset_name = entry["metadata"]["source_dataset"] return self.datasets[dataset_name].score_answer(answer, entry) + def add_dataset(self, dataset_spec: DatasetSpec) -> None: + """Add a new dataset to the composite + + Args: + dataset_spec: Specification for the dataset to add + + Raises: + ValueError: If dataset name already exists + """ + # Validate spec + dataset_spec.validate() + + # Check for duplicate name + if dataset_spec.name in self.datasets: + raise ValueError(f"Dataset '{dataset_spec.name}' already exists in composite") + + # Create dataset with derived seed + ds_config = dataset_spec.config.copy() + if "seed" not in ds_config: + ds_config["seed"] = self.seed + len(self.datasets) + 1 + if "size" not in ds_config: + ds_config["size"] = self.size + + # Create and add dataset + dataset = create_dataset(dataset_spec.name, **ds_config) + self.datasets[dataset_spec.name] = dataset + + # Register version if tracking enabled + if self.version_manager is not None: + version_id = self.version_manager.register_dataset(dataset_spec.name, dataset) + self.dataset_versions[dataset_spec.name] = version_id + + # Add to config and update internal state + self.config.datasets.append(dataset_spec) + self.dataset_names.append(dataset_spec.name) + self.weights.append(dataset_spec.weight) # Use weight directly from spec + + def remove_dataset(self, dataset_name: str) -> None: + """Remove a dataset from the composite + + Args: + dataset_name: Name of the dataset to remove + + Raises: + KeyError: If dataset not found + ValueError: If trying to remove last dataset + """ + if dataset_name not in self.datasets: + raise KeyError(f"Dataset '{dataset_name}' not found") + + if len(self.datasets) <= 1: + raise ValueError("Cannot remove last dataset from composite") + + # Remove from all internal structures + del self.datasets[dataset_name] + if self.version_manager is not None: + del self.dataset_versions[dataset_name] + + # Remove from config + self.config.datasets = [ds for 
ds in self.config.datasets if ds.name != dataset_name] + + # Update internal state + idx = self.dataset_names.index(dataset_name) + self.dataset_names.pop(idx) + self.weights.pop(idx) + + def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float: + """Score an answer using an entry_id to lookup the original entry + + Args: + answer: The answer to score + entry_id: String in format "version_id.index" + + Returns: + Score between 0 and 1 + + Raises: + ValueError: If entry_id format is invalid + KeyError: If version not found in version manager + """ + if self.version_manager is None: + raise RuntimeError("Version manager required for scoring with entry_id") + + try: + version_id, index = map(int, entry_id.split(".")) + except ValueError: + raise ValueError(f"Invalid entry_id format: {entry_id}, expected 'version_id.index'") + + # Get dataset from version manager + dataset_info = self.version_manager.get_dataset(version_id) + if dataset_info is None: + raise KeyError(f"Version {version_id} not found in version manager") + + dataset_name, dataset = dataset_info + + # Get entry from dataset + entry = dataset[index] + + # Score answer using dataset's scoring function + return dataset.score_answer(answer, entry) + # Register the dataset register_dataset("composite", CompositeDataset, CompositeConfig) diff --git a/reasoning_gym/games/mini_sudoku.py b/reasoning_gym/games/mini_sudoku.py index 2d9f5568..3ca1277c 100644 --- a/reasoning_gym/games/mini_sudoku.py +++ b/reasoning_gym/games/mini_sudoku.py @@ -1,5 +1,6 @@ """Mini Sudoku (4x4) puzzle generator""" +import copy from dataclasses import dataclass from random import Random from typing import Any, List, Optional, Tuple @@ -11,15 +12,18 @@ from ..factory import ProceduralDataset, register_dataset class MiniSudokuConfig: """Configuration for 4x4 sudoku puzzle generation""" - min_empty: int = 8 # Minimum number of empty cells + min_empty: int = ( + 8 # Minimum number of empty cells. 
Occasionally this can be violated, if removing more cells would break the puzzle's uniqueness. + ) max_empty: int = 12 # Maximum number of empty cells seed: Optional[int] = None size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" - assert 0 <= self.min_empty <= 16, "min_empty must be between 0 and 16" - assert self.min_empty <= self.max_empty <= 16, "max_empty must be between min_empty and 16" + # More than 12 empty cells is incompatible with a unique solution + assert 0 <= self.min_empty <= 12, "min_empty must be between 0 and 12" + assert self.min_empty <= self.max_empty <= 12, "max_empty must be between min_empty and 12" class MiniSudokuDataset(ProceduralDataset): @@ -111,14 +115,45 @@ class MiniSudokuDataset(ProceduralDataset): raise RuntimeError("Failed to generate valid mini sudoku board") + def _count_solutions(self, board: List[List[int]], limit: int = 2) -> int: + """Count the number of solutions for a given board""" + + def _count_solutions_helper(board: List[List[int]]) -> int: + empty = self._find_empty(board) + if not empty: + return 1 + + row, col = empty + count = 0 + for num in range(1, 5): + if self._is_valid(board, row, col, num): + board[row][col] = num + count += _count_solutions_helper(board) + if count >= limit: + return count + board[row][col] = 0 + return count + + return _count_solutions_helper(board) + def _create_puzzle(self, solved_board: List[List[int]], num_empty: int, rng: Random) -> List[List[int]]: """Create puzzle by removing numbers from solved board""" puzzle = [row[:] for row in solved_board] cells = [(i, j) for i in range(4) for j in range(4)] rng.shuffle(cells) + num_removed = 0 - for i, j in cells[:num_empty]: + for i, j in cells: + saved = puzzle[i][j] puzzle[i][j] = 0 + puzzle_copy = copy.deepcopy(puzzle) + # Check if removing this clue breaks uniqueness + if self._count_solutions(puzzle_copy) > 1: + puzzle[i][j] = saved + else: + num_removed += 1 + if num_removed == num_empty: 
+ break return puzzle @@ -137,6 +172,9 @@ class MiniSudokuDataset(ProceduralDataset): num_empty = rng.randint(self.config.min_empty, self.config.max_empty) puzzle = self._create_puzzle(solved_board, num_empty, rng) + # Update the num_empty to be used in the metadata if we couldn't remove as many as we wanted + num_empty = sum(1 for row in puzzle for x in row if x == 0) + # Format as strings puzzle_str = self._board_to_string(puzzle) solution_str = self._board_to_string(solved_board) diff --git a/reasoning_gym/games/sudoku.py b/reasoning_gym/games/sudoku.py index 9268546c..5efe79e7 100644 --- a/reasoning_gym/games/sudoku.py +++ b/reasoning_gym/games/sudoku.py @@ -1,15 +1,19 @@ """Sudoku puzzle generator""" +import copy from dataclasses import dataclass from random import Random -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Set, Tuple from ..factory import ProceduralDataset, register_dataset @dataclass class SudokuConfig: - """Configuration for sudoku puzzle generation""" + """ + Configuration for sudoku puzzle generation + Puzzle generation can be a bit slower for puzzles with a high (~60+) number of empty cells + """ min_empty: int = 30 # Minimum number of empty cells max_empty: int = 50 # Maximum number of empty cells @@ -18,8 +22,9 @@ class SudokuConfig: def validate(self): """Validate configuration parameters""" - assert 0 <= self.min_empty <= 81, "min_empty must be between 0 and 81" - assert self.min_empty <= self.max_empty <= 81, "max_empty must be between min_empty and 81" + # 81 - 64 = 17, the minimum number of clues required for 9x9 Sudoku to have a unique solution + assert 0 <= self.min_empty <= 64, "min_empty must be between 0 and 64" + assert self.min_empty <= self.max_empty <= 64, "max_empty must be between min_empty and 64" class SudokuDataset(ProceduralDataset): @@ -60,6 +65,21 @@ class SudokuDataset(ProceduralDataset): return False return True + def _get_possible_values(self, board: List[List[int]], row: int, col: 
int) -> Set[int]: + """Get all possible values for a cell.""" + row_values = set(board[row]) + col_values = set(board[i][col] for i in range(9)) + + # Get filled values in the current 3x3 subgrid + box_row, box_col = 3 * (row // 3), 3 * (col // 3) + box_values = set() + for i in range(box_row, box_row + 3): + for j in range(box_col, box_col + 3): + box_values.add(board[i][j]) + + used_values = row_values | col_values | box_values + return set(range(1, 10)) - used_values + def _solve(self, board: List[List[int]]) -> bool: """Solve sudoku using backtracking""" empty = self._find_empty(board) @@ -67,12 +87,11 @@ class SudokuDataset(ProceduralDataset): return True row, col = empty - for num in range(1, 10): - if self._is_valid(board, row, col, num): - board[row][col] = num - if self._solve(board): - return True - board[row][col] = 0 + for num in self._get_possible_values(board, row, col): + board[row][col] = num + if self._solve(board): + return True + board[row][col] = 0 return False def _find_empty(self, board: List[List[int]]) -> Optional[Tuple[int, int]]: @@ -101,14 +120,66 @@ class SudokuDataset(ProceduralDataset): self._solve(board) return board + def _count_solutions(self, board: List[List[int]], limit: int = 2) -> int: + """Count the number of solutions for a given board""" + + def _get_min_possibilities_cell(board: List[List[int]]) -> Optional[Tuple[int, int, Set[int]]]: + """ + Get the cell with the lowest number of possibilities. + Returns None if the board is already solved. 
+ """ + min_possibilities = 10 + min_cell = None + min_values = None + + for i in range(9): + for j in range(9): + if board[i][j] == 0: + possible = self._get_possible_values(board, i, j) + if len(possible) < min_possibilities: + min_possibilities = len(possible) + min_cell = (i, j) + min_values = possible + if min_possibilities == 1: + return (*min_cell, min_values) + + return (*min_cell, min_values) if min_cell else None + + def _count_solutions_helper(board: List[List[int]]) -> int: + cell_info = _get_min_possibilities_cell(board) + if not cell_info: + return 1 + + row, col, possible_values = cell_info + count = 0 + for num in possible_values: + board[row][col] = num + count += _count_solutions_helper(board) + if count >= limit: + return count + board[row][col] = 0 + return count + + return _count_solutions_helper(board) + def _create_puzzle(self, solved_board: List[List[int]], num_empty: int, rng: Random) -> List[List[int]]: """Create puzzle by removing numbers from solved board""" puzzle = [row[:] for row in solved_board] cells = [(i, j) for i in range(9) for j in range(9)] rng.shuffle(cells) + num_removed = 0 - for i, j in cells[:num_empty]: + for i, j in cells: + saved = puzzle[i][j] puzzle[i][j] = 0 + puzzle_copy = copy.deepcopy(puzzle) + # Check if removing this clue breaks uniqueness + if self._count_solutions(puzzle_copy) > 1: + puzzle[i][j] = saved + else: + num_removed += 1 + if num_removed == num_empty: + break return puzzle @@ -131,11 +202,51 @@ class SudokuDataset(ProceduralDataset): puzzle_str = self._board_to_string(puzzle) solution_str = self._board_to_string(solved_board) + question = ( + f"Solve this Sudoku puzzle:\n{puzzle_str}\n" + "Respond with only your answer, formatted as the puzzle, a 9x9 grid with numbers separated by spaces, and rows separated by newlines." 
+ ) + return { - "question": f"Solve this Sudoku puzzle:\n{puzzle_str}", + "question": question, "answer": solution_str, "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty}, } + def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: + if not answer: + return 0.0 + + oracle_answer = entry["answer"] + metadata = entry["metadata"] + solution: list[list[int]] = metadata["solution"] + board_size: int = len(solution[0]) + + # 1. match answer without trailing whitespaces + answer_stripped = "\n".join(l.rstrip() for l in answer.split("\n")) + oracle_answer_stripped = "\n".join(l.rstrip() for l in oracle_answer.split("\n")) + + if answer_stripped == oracle_answer_stripped: + reward = 1.0 + else: + # 2. accept answers with correct numeric sequence (ignoring non-numeric characters) + row = 0 + num_matching = 0 + for ln in answer.split("\n"): + numbers = [int(c) for c in ln if c.isnumeric()] + if len(numbers) != board_size: + continue # ignore lines without numbers + for a, b in zip(solution[row], numbers): + if a == b: + num_matching += 1 + row += 1 + + reward = num_matching / (board_size * board_size) + reward *= 0.9 # penalty for not using standard format + + if len(answer) > len(oracle_answer): + reward *= len(oracle_answer) / len(answer) # penalty for additional length + return reward + register_dataset("sudoku", SudokuDataset, SudokuConfig) diff --git a/reasoning_gym/geometry/advanced_geometry.py b/reasoning_gym/geometry/advanced_geometry.py index ac8797b9..6f64daf8 100644 --- a/reasoning_gym/geometry/advanced_geometry.py +++ b/reasoning_gym/geometry/advanced_geometry.py @@ -1,7 +1,9 @@ import random +import re from dataclasses import dataclass, field -from typing import List, Optional +from typing import Any, Dict, List, Optional +import numpy as np import sympy from sympy.geometry import Point, Segment, Triangle @@ -35,6 +37,18 @@ class AdvancedGeometryConfig: assert len(self.task_types) > 0, "Must specify at least 
one task type." +# Join format instructions into a single string +GEOMETRY_FORMAT_INSTRUCTIONS = "\n".join( + [ + "For all geometry problems:", + "1. Give coordinates in the form (x, y)", + "2. Round decimal answers to 3 decimal places", + "3. Use the degree symbol ° for angles", + "4. Return only th angle, coordinates, or radius as your answer.", + ] +) + + class AdvancedGeometryDataset(ProceduralDataset): """ A dataset for advanced geometry tasks using coordinate geometry. @@ -43,16 +57,16 @@ class AdvancedGeometryDataset(ProceduralDataset): def __init__(self, config: AdvancedGeometryConfig): self._prompt_templates = { "orthocenter": [ - "Given triangle ABC with coordinates A={A}, B={B}, and C={C}, find the coordinates of its orthocenter.", - "For triangle with vertices A={A}, B={B}, and C={C}, determine the orthocenter (intersection of altitudes).", + f"Given triangle ABC with coordinates A={{A}}, B={{B}}, and C={{C}}, find the coordinates of its orthocenter. {GEOMETRY_FORMAT_INSTRUCTIONS}", + f"For triangle with vertices A={{A}}, B={{B}}, and C={{C}}, determine the orthocenter (intersection of altitudes). {GEOMETRY_FORMAT_INSTRUCTIONS}", ], "incircle_radius": [ - "Consider triangle ABC with coordinates A={A}, B={B}, and C={C}. Compute the radius of its incircle.", - "Find the incircle radius of triangle ABC whose vertices are A={A}, B={B}, and C={C}.", + f"Consider triangle ABC with coordinates A={{A}}, B={{B}}, and C={{C}}. Compute the radius of its incircle. {GEOMETRY_FORMAT_INSTRUCTIONS}", + f"Find the incircle radius of triangle ABC whose vertices are A={{A}}, B={{B}}, and C={{C}}. {GEOMETRY_FORMAT_INSTRUCTIONS}", ], "angle_measure": [ - "In triangle ABC with coordinates A={A}, B={B}, and C={C}, find the measure (in degrees) of angle ABC.", - "Given a triangle with vertices A={A}, B={B}, C={C}, determine the angle at B in degrees.", + f"In triangle ABC with coordinates A={{A}}, B={{B}}, and C={{C}}, find the measure (in degrees) of angle ABC. 
{GEOMETRY_FORMAT_INSTRUCTIONS}", + f"Given a triangle with vertices A={{A}}, B={{B}}, and C={{C}}, determine the angle at B in degrees. {GEOMETRY_FORMAT_INSTRUCTIONS}", ], } super().__init__(config=config, seed=config.seed, size=config.size) @@ -77,6 +91,8 @@ class AdvancedGeometryDataset(ProceduralDataset): else: raise ValueError(f"Unknown task_type: {task_type}") + metadata["task_type"] = task_type + return { "question": question, "answer": answer, @@ -127,13 +143,14 @@ class AdvancedGeometryDataset(ProceduralDataset): y_ortho_approx = float(ortho.y.evalf()) question_template = rng.choice(self._prompt_templates["orthocenter"]) - question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y)) + question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y), a="a", b="b") answer_str = f"({x_ortho_approx:.3f}, {y_ortho_approx:.3f})" metadata = { "A": (A.x, A.y), "B": (B.x, B.y), "C": (C.x, C.y), + "ortho": (ortho.x, ortho.y), "orthocenter_exact": (str(ortho.x), str(ortho.y)), "orthocenter_approx": (x_ortho_approx, y_ortho_approx), } @@ -200,7 +217,7 @@ class AdvancedGeometryDataset(ProceduralDataset): angle_deg = float(angle_rad.evalf() * 180 / sympy.pi) question_template = rng.choice(self._prompt_templates["angle_measure"]) - question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y)) + question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y), a="a", b="b") answer_str = f"{angle_deg:.2f}°" metadata = { @@ -211,6 +228,55 @@ class AdvancedGeometryDataset(ProceduralDataset): } return question, answer_str, metadata + def score_answer(self, answer: str | None, entry: Dict[str, Any]) -> float: + reward = 0.0 + expected_answer = entry["answer"] + metadata = entry["metadata"] + task_type = metadata["task_type"] + + if answer is not None: + try: + if metadata["task_type"] == "angle_measure": + answer = answer.replace("\u00b0", "") + expected_answer = expected_answer.replace("\u00b0", "") + if 
np.round(float(answer), 2) == np.round(float(expected_answer), 2): + reward = 1.0 + elif len(answer.strip()) > 0: + reward = 0.05 + else: + reward = 0.01 + elif metadata["task_type"] == "orthocenter": + x_coord = answer.split(",")[0].replace("(", "").strip() + y_coord = answer.split(",")[1].replace(")", "").strip() + + expected_x = metadata["ortho"][0] + expected_y = metadata["ortho"][1] + + if x_coord == expected_x and y_coord == expected_y: + reward = 1.0 + elif (np.round(float(x_coord), 2) == np.round(float(expected_x), 2)) and ( + np.round(float(y_coord), 2) == np.round(float(expected_y), 2) + ): + reward = 1.0 + elif len(x_coord.strip()) > 0 and len(y_coord.strip()) > 0: + reward = 0.05 + else: + reward = 0.01 + elif metadata["task_type"] == "incircle_radius": + if answer == expected_answer: + reward = 1.0 + elif np.round(float(answer), 2) == np.round(float(metadata["incircle_radius_exact"]), 2): + reward = 0.5 + elif len(answer.strip()) > 0: + reward = 0.05 + else: + reward = 0.01 + else: + raise ValueError(f"Unknown task_type: {task_type}") + except: + reward = 0.01 + return reward + # Register the dataset register_dataset("advanced_geometry", AdvancedGeometryDataset, AdvancedGeometryConfig) diff --git a/reasoning_gym/geometry/simple_geometry.py b/reasoning_gym/geometry/simple_geometry.py index d04912d7..665a440f 100644 --- a/reasoning_gym/geometry/simple_geometry.py +++ b/reasoning_gym/geometry/simple_geometry.py @@ -46,15 +46,21 @@ class SimpleGeometryDataset(ProceduralDataset): ( "Given a convex polygon with {n_sides} sides, its first {n_minus_1} interior angles " "are: {angle_list}. What is the measure of the remaining interior angle (in degrees)?" + "Return only the angle as your answer." + "Do not give the units in your answer." ), ( "A convex polygon has {n_sides} sides. The measures of " "the first {n_minus_1} interior angles are: {angle_list}. " "Find the measure of the last interior angle." + "Return only the angle as your answer." 
+ "Do not give the units in your answer." ), ( "Consider a convex {n_sides}-gon whose first {n_minus_1} interior angles " "are: {angle_list}. Determine the measure of the remaining angle." + "Return only the angle as your answer." + "Do not give the units in your answer." ), ] super().__init__(config=config, seed=config.seed, size=config.size) diff --git a/reasoning_gym/logic/zebra_puzzles.py b/reasoning_gym/logic/zebra_puzzles.py index d672c922..38cbe051 100644 --- a/reasoning_gym/logic/zebra_puzzles.py +++ b/reasoning_gym/logic/zebra_puzzles.py @@ -43,7 +43,8 @@ class ZebraDataset(ProceduralDataset): instance, puzzle = generate_puzzle(rng, K, M) q = instance["questions"][0]["question"] answer = instance["questions"][0]["answer"] - question = str(puzzle) + "\n" + q + "\nReply only with your final answer, which should be the name of a person." + question = str(puzzle) + "\n" + q + question = question + "? Provide only the name of the person as your final answer." return { "question": question, diff --git a/reasoning_gym/version_manager.py b/reasoning_gym/version_manager.py new file mode 100644 index 00000000..dbe19a09 --- /dev/null +++ b/reasoning_gym/version_manager.py @@ -0,0 +1,76 @@ +"""Version manager for tracking dataset versions.""" + +from typing import Dict, Optional, Tuple + +from .dataset import ProceduralDataset + + +class DatasetVersionManager: + """Manages versioned ProceduralDataset instances and their configurations.""" + + def __init__(self): + """Initialize the version manager.""" + self.current_version = 0 + # version_id -> (dataset_name, dataset_instance) + self.datasets: Dict[int, Tuple[str, ProceduralDataset]] = {} + + def register_dataset(self, name: str, dataset: ProceduralDataset) -> int: + """ + Register a new dataset version. 
+ + Args: + name: Name/identifier of the dataset type + dataset: Instance of ProceduralDataset + + Returns: + version_id: Unique identifier for this dataset version + """ + self.current_version += 1 + self.datasets[self.current_version] = (name, dataset) + return self.current_version + + def get_dataset(self, version_id: int) -> Optional[Tuple[str, ProceduralDataset]]: + """ + Retrieve a dataset by its version ID. + + Args: + version_id: The version identifier + + Returns: + Tuple of (dataset_name, dataset_instance) if found, None otherwise + """ + return self.datasets.get(version_id) + + def get_entry(self, version_id: int, index: int) -> Dict[str, any]: + """ + Get a specific entry from a versioned dataset. + + Args: + version_id: The version identifier + index: Index of the entry to retrieve + + Returns: + The dataset entry + + Raises: + KeyError: If version_id is not found + """ + if version_id not in self.datasets: + raise KeyError(f"Dataset version {version_id} not found") + + _, dataset = self.datasets[version_id] + return dataset[index] + + def cleanup_old_versions(self, keep_latest: int = 10): + """ + Remove old dataset versions to free memory. 
+ + Args: + keep_latest: Number of most recent versions to keep + """ + if len(self.datasets) <= keep_latest: + return + + versions_to_remove = sorted(self.datasets.keys())[:-keep_latest] + for version in versions_to_remove: + del self.datasets[version] diff --git a/tests/test_composite.py b/tests/test_composite.py index cbfec38a..93cc6f0b 100644 --- a/tests/test_composite.py +++ b/tests/test_composite.py @@ -4,6 +4,7 @@ import pytest import yaml from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec +from reasoning_gym.version_manager import DatasetVersionManager def create_test_config(tmp_path): @@ -85,13 +86,165 @@ def test_composite_dataset_weights(): seed=42, datasets=[ DatasetSpec("chain_sum", 2.0, {"min_terms": 2}), - DatasetSpec("chain_sum", 3.0, {"min_terms": 3}), + DatasetSpec("products", 3.0, {"min_terms": 2}), ], ) dataset = CompositeDataset(config) - assert abs(dataset.weights[0] - 0.4) < 1e-6 - assert abs(dataset.weights[1] - 0.6) < 1e-6 + assert abs(dataset.weights[0] - 2.0) < 1e-6 + assert abs(dataset.weights[1] - 3.0) < 1e-6 + + # Test weight updates + dataset.update_dataset_weight("chain_sum", 1.0) + print(dataset.weights) + assert abs(dataset.weights[0] - 1.0) < 1e-6 + assert abs(dataset.weights[1] - 3.0) < 1e-6 + + # Test invalid weight + with pytest.raises(ValueError, match="Weight must be non-negative"): + dataset.update_dataset_weight("chain_sum", -1.0) + + # Test invalid dataset name + with pytest.raises(KeyError): + dataset.update_dataset_weight("invalid_dataset", 1.0) + + # Test zero total weight + dataset.update_dataset_weight("chain_sum", 0.0) + with pytest.raises(ValueError, match="Total of weights must be greater than zero"): + dataset.update_dataset_weight("products", 0.0) + _ = dataset[0] # access item with all weights 0 + + # Test duplicate dataset names + with pytest.raises(ValueError, match="Duplicate dataset names"): + CompositeConfig( + size=1000, + seed=42, + datasets=[ + DatasetSpec("chain_sum", 
1.0, {"min_terms": 2}), + DatasetSpec("chain_sum", 1.0, {"min_terms": 3}), + ], + ).validate() + + +def test_version_tracking_with_config_updates(): + """Test that version tracking works correctly when updating dataset configs""" + # Create composite dataset with version manager + version_manager = DatasetVersionManager() + config = CompositeConfig( + size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})] + ) + dataset = CompositeDataset(config, version_manager=version_manager) + + # Get an entry and its id from initial version + entry_1 = dataset[0] + entry_id_1 = entry_1["metadata"]["entry_id"] + answer_1 = entry_1["answer"] + + # Update dataset config + dataset.update_dataset_config("chain_sum", {"min_terms": 3, "max_terms": 5}) + + # Get new entry after config update + entry_2 = dataset[0] + entry_id_2 = entry_2["metadata"]["entry_id"] + answer_2 = entry_2["answer"] + + # Verify entries have different version IDs + version_1 = int(entry_id_1.split(".")[0]) + version_2 = int(entry_id_2.split(".")[0]) + assert version_1 != version_2, "New config should create new version" + + # Verify original answer still works with original version + score_1 = dataset.score_answer_with_id(answer_1, entry_id_1) + assert score_1 == 1.0, "Original answer should still work with original version" + + # Verify new answer works with new version + score_2 = dataset.score_answer_with_id(answer_2, entry_id_2) + assert score_2 == 1.0, "New answer should work with new version" + + # Verify original answer fails with new version + score_3 = dataset.score_answer_with_id(answer_1, entry_id_2) + assert score_3 < 1.0, "Original answer should not work with new version" + + +def test_score_answer_with_id(): + """Test scoring answers using entry_id""" + # Create composite dataset with version manager + version_manager = DatasetVersionManager() + config = CompositeConfig( + size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 
4})] + ) + dataset = CompositeDataset(config, version_manager=version_manager) + + # Get an entry and its id + entry = dataset[0] + entry_id = entry["metadata"]["entry_id"] + + # Test successful scoring + answer = entry["answer"] + score = dataset.score_answer_with_id(answer, entry_id) + assert score == 1.0 # Correct answer should get full score + + # Test wrong answer + wrong_answer = "wrong" + score = dataset.score_answer_with_id(wrong_answer, entry_id) + assert score < 1.0 # Wrong answer should get lower score + + # Test invalid entry_id format + with pytest.raises(ValueError, match="Invalid entry_id format"): + dataset.score_answer_with_id(answer, "invalid") + + # Test non-existent version + with pytest.raises(KeyError, match="Version .* not found"): + dataset.score_answer_with_id(answer, "999.0") + + # Test without version manager + dataset_no_vm = CompositeDataset(config) + with pytest.raises(RuntimeError, match="Version manager required"): + dataset_no_vm.score_answer_with_id(answer, entry_id) + + +def test_add_remove_dataset(): + """Test adding and removing datasets from composite""" + config = CompositeConfig( + size=1000, + seed=42, + datasets=[ + DatasetSpec("chain_sum", 1.0, {"min_terms": 2}), + ], + ) + + dataset = CompositeDataset(config) + + # Test adding new dataset + new_spec = DatasetSpec("products", 2.0, {"min_terms": 2}) + dataset.add_dataset(new_spec) + + assert len(dataset.datasets) == 2 + assert "products" in dataset.datasets + assert len(dataset.config.datasets) == 2 + + assert dataset.dataset_names[0] == "chain_sum" + assert dataset.dataset_names[1] == "products" + assert abs(dataset.weights[0] - 1.0) < 1e-6 # chain_sum weight + assert abs(dataset.weights[1] - 2.0) < 1e-6 # products weight + + # Test duplicate name + with pytest.raises(ValueError, match="already exists"): + dataset.add_dataset(new_spec) + + # Test removing dataset + dataset.remove_dataset("products") + assert len(dataset.datasets) == 1 + assert "products" not in 
dataset.datasets + assert len(dataset.config.datasets) == 1 + + # Test removing non-existent dataset + with pytest.raises(KeyError): + dataset.remove_dataset("nonexistent") + + # Test removing last dataset + with pytest.raises(ValueError, match="Cannot remove last dataset"): + dataset.remove_dataset("chain_sum") def test_yaml_loading(tmp_path): diff --git a/tests/test_decimal_chain_sum.py b/tests/test_decimal_chain_sum.py new file mode 100644 index 00000000..5114a7c7 --- /dev/null +++ b/tests/test_decimal_chain_sum.py @@ -0,0 +1,252 @@ +import pytest + +from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumDataset + + +def test_decimal_chain_sum_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = DecimalChainSumConfig(min_terms=0) + config.validate() + + with pytest.raises(AssertionError): + config = DecimalChainSumConfig(min_terms=3, max_terms=2) + config.validate() + + +def test_decimal_chain_sum_deterministic(): + """Test that dataset generates same items with same seed""" + config = DecimalChainSumConfig(seed=42, size=10) + dataset1 = DecimalChainSumDataset(config) + dataset2 = DecimalChainSumDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_decimal_chain_sum_items(): + """Test basic properties of generated items""" + config = DecimalChainSumConfig( + min_terms=2, + max_terms=4, + min_digits=1, + max_digits=2, + min_decimal_places=1, + max_decimal_places=2, + size=100, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify only + and - are used + expression = item["metadata"]["expression"] + assert all(op in ["+", "-", " ", "."] or op.isdigit() for op in expression) + + # Check for floating point errors + numbers = [n for n in 
expression.split() if any(c.isdigit() for c in n)] + for num in numbers: + # Verify no numbers have more decimal places than max_decimal_places + if "." in num: + decimal_places = len(num.split(".")[-1]) + assert decimal_places <= config.max_decimal_places, f"Number {num} has more decimal places than allowed" + + # Verify answer has correct precision + answer_str = item["answer"] + if "." in answer_str: + decimal_places = len(answer_str.split(".")[-1]) + assert ( + decimal_places <= config.max_decimal_places + ), f"Answer {answer_str} has more decimal places than allowed" + + # Verify mathematical correctness within epsilon + expected = eval(expression) + assert ( + abs(float(item["answer"]) - expected) < 1e-10 + ), f"Answer {item['answer']} doesn't match expected {expected}" + + +def test_chain_sum_number_ranges(): + """Test that generated numbers respect digit constraints""" + # Test 3-digit numbers + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, # Fix to 2 terms for easier testing + min_digits=3, + max_digits=3, + min_decimal_places=1, + max_decimal_places=4, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" + + # Test 1-digit numbers + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=1, + min_decimal_places=1, + max_decimal_places=4, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" + + +def test_decimal_chain_sum_negation(): + """Test that negation is 
properly handled""" + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=1, + min_decimal_places=1, + max_decimal_places=4, + allow_negation=True, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + has_positive = False + has_negative = False + + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [float(n) for n in expression.split() if n.replace(".", "").replace("-", "").isdigit()] + for num in numbers: + if num > 0: + has_positive = True + if num < 0: + has_negative = True + + assert has_positive and has_negative, "Expected both positive and negative numbers with allow_negation=True" + + +def test_decimal_chain_sum_iteration(): + """Test that iteration respects dataset size""" + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=1, + min_decimal_places=1, + max_decimal_places=4, + size=5, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly size items" + + items = list(dataset) + assert len(items) == config.size, "Iterator should yield exactly size items" + + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_decimal_places_generation(): + """Test that generated decimal numbers have correct number of decimal places""" + # Test fixed decimal places + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=2, + min_decimal_places=2, + max_decimal_places=2, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for item in dataset: + expression = item["metadata"]["expression"] + # Extract numbers including decimals + numbers = [n for n in expression.split() if any(c.isdigit() for c in n)] + for num in numbers: + decimal_part = 
num.split(".")[-1] + assert len(decimal_part) == 2, f"Number {num} should have exactly 2 decimal places" + + # Test varying decimal places + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=2, + min_decimal_places=1, + max_decimal_places=3, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for item in dataset: + expression = item["metadata"]["expression"] + numbers = [n for n in expression.split() if any(c.isdigit() for c in n)] + for num in numbers: + decimal_part = num.split(".")[-1] + assert 1 <= len(decimal_part) <= 3, f"Number {num} should have between 1 and 3 decimal places" + + +def test_decimal_precision_scoring(): + """Test that scoring handles decimal precision correctly""" + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=2, + min_decimal_places=2, + max_decimal_places=3, + size=1, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + item = dataset[0] + + # Test exact matches with different representations + assert dataset.score_answer("1.200", {"answer": "1.2"}) == 1.0 + assert dataset.score_answer("1.20", {"answer": "1.200"}) == 1.0 + assert dataset.score_answer("-0.5", {"answer": "-0.500"}) == 1.0 + + # Test floating point precision edge cases + assert dataset.score_answer("0.1", {"answer": "0.100"}) == 1.0 + assert dataset.score_answer("0.3", {"answer": "0.300"}) == 1.0 + + # Test incorrect answers + assert dataset.score_answer("1.200000001", {"answer": "1.200"}) == 0.01 + assert dataset.score_answer("1.199999999", {"answer": "1.200"}) == 0.01 + + # Test invalid inputs + assert dataset.score_answer(None, {"answer": "1.200"}) == 0.0 + assert dataset.score_answer("", {"answer": "1.200"}) == 0.0 + assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.01 + assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.01 diff --git a/tests/test_number_format.py b/tests/test_number_format.py new file mode 100644 index 
00000000..882f38aa --- /dev/null +++ b/tests/test_number_format.py @@ -0,0 +1,121 @@ +"""Tests for Number Format questions generation""" + +import pytest + +from reasoning_gym.arithmetic.number_format import NumberFormatConfig, NumberFormatDataset + + +def test_number_format_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = NumberFormatConfig(max_num_candidates=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = NumberFormatConfig(max_num_candidates=1) # One not allowed + config.validate() + + with pytest.raises(AssertionError): + config = NumberFormatConfig(min_n=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = NumberFormatConfig(min_n=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = NumberFormatConfig(min_n=10, max_n=5) # min > max + config.validate() + + with pytest.raises(AssertionError): + config = NumberFormatConfig(max_delta=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = NumberFormatConfig(max_delta=0) # Zero not allowed + config.validate() + + +def test_number_format_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = NumberFormatConfig(seed=42, size=10) + dataset1 = NumberFormatDataset(config) + dataset2 = NumberFormatDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_number_format_dataset_items(): + """Test basic properties of generated items""" + config = NumberFormatConfig(min_n=1_000, max_n=10_000, max_delta=1, size=10, seed=42) + dataset = NumberFormatDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "candidates" 
in item["metadata"] + assert "formatted_candidates" in item["metadata"] + assert "size" in item["metadata"] + assert "solution" in item["metadata"] + + candidates = item["metadata"]["candidates"] + formatted_candidates = item["metadata"]["formatted_candidates"] + size = item["metadata"]["size"] + solution = item["metadata"]["solution"] + + # Verify values + assert len(candidates) >= 2 + assert all(999 <= c <= 10_001 for c in candidates) # boundaries +- delta + assert len(candidates) == len(formatted_candidates) + assert size in ["largest", "smallest"] + assert solution in candidates + + +def test_number_format_dataset_iteration(): + """Test that iteration respects dataset size""" + config = NumberFormatConfig(size=5, seed=42) + dataset = NumberFormatDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_number_format_answer(): + """Verify the solution scoring""" + config = NumberFormatConfig(size=5, seed=42) + dataset = NumberFormatDataset(config) + + entry = {"metadata": {"solution": 54245.32}} + + # Correct answer (plain) + model_answer = "54245.32" + assert dataset.score_answer(model_answer, entry) == 1.0 + + # Correct answer (English) + model_answer = "54,245.32" + assert dataset.score_answer(model_answer, entry) == 1.0 + + # Correct answer (scientific) + assert dataset.score_answer("5.424532e+04", entry) == 1.0 + + # Incorrect answer (diff larger than 1e-2) + model_answer = "54245.9" + assert dataset.score_answer(model_answer, entry) == 0.01 + + # Answer is null + model_answer = None + assert dataset.score_answer(model_answer, entry) == 0.0 + + # Answer is unparsable + model_answer = "test" + assert dataset.score_answer(model_answer, entry) == 0.0 diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 00000000..f61171c1 --- /dev/null +++ b/tools/README.md @@ -0,0 +1,83 @@ +# Reasoning Gym Tools + +This directory contains 
additional tools for working with Reasoning Gym: + +## Server + +A FastAPI server that manages reasoning gym experiments, allowing runtime configuration and monitoring. + +### Starting the Server + +1. Install server dependencies: +```bash +pip install -e ".[server]" +``` + +2. Set the API key environment variable: +```bash +export REASONING_GYM_API_KEY=your-secret-key +``` + +3. Start the server: +```bash +uvicorn tools.server.server:app +``` + +The server will be available at http://localhost:8000. You can access the API documentation at http://localhost:8000/docs. + +## RGC (Reasoning Gym Client) + +A command-line interface for interacting with the Reasoning Gym server. + +### Installation + +```bash +pip install -e ".[cli]" +``` + +### Usage + +First, set the API key to match your server: +```bash +export REASONING_GYM_API_KEY=your-secret-key +``` + +Then you can use the CLI: + +```bash +# List all commands +rgc --help + +# List experiments +rgc experiments list + +# Create a new experiment interactively +rgc experiments create my-experiment + +# Create from config file +rgc experiments create my-experiment -f config.yaml + +# Show experiment details +rgc experiments show my-experiment + +# Edit dataset configuration +rgc config edit my-experiment chain_sum +``` + +### Example Configuration File + +Here's an example `config.yaml` for creating an experiment: + +```yaml +size: 500 +seed: 42 +datasets: + chain_sum: + weight: 1.0 + config: + min_terms: 2 + max_terms: 4 + min_digits: 1 + max_digits: 2 + allow_negation: false +``` diff --git a/tools/cli/__init__.py b/tools/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/cli/rgc/__init__.py b/tools/cli/rgc/__init__.py new file mode 100644 index 00000000..7286e761 --- /dev/null +++ b/tools/cli/rgc/__init__.py @@ -0,0 +1,5 @@ +"""Reasoning Gym CLI tool.""" + +from .main import main + +__all__ = ["main"] diff --git a/tools/cli/rgc/client.py b/tools/cli/rgc/client.py new file mode 100644 
index 00000000..4808a0fe --- /dev/null +++ b/tools/cli/rgc/client.py @@ -0,0 +1,125 @@ +"""HTTP client for interacting with the Reasoning Gym server.""" + +import os +from typing import List, Optional + +import httpx +from rich.console import Console + +from tools.server.models import ( + AnswerItem, + BatchResponse, + DatasetConfigUpdate, + ExperimentCreate, + ExperimentList, + ExperimentResponse, + ScoringRequest, + ScoringResponse, +) + +console = Console() + +DEFAULT_SERVER = "http://localhost:8000" +API_KEY = os.getenv("REASONING_GYM_API_KEY", "default-key") + + +class RGClient: + """Client for interacting with Reasoning Gym server.""" + + def __init__(self, base_url: str = DEFAULT_SERVER, api_key: str = API_KEY): + """Initialize client with server URL and API key.""" + self.base_url = base_url.rstrip("/") + self.headers = {"X-API-Key": api_key} + + def _url(self, path: str) -> str: + """Construct full URL for given path.""" + return f"{self.base_url}/{path.lstrip('/')}" + + def check_health(self) -> bool: + """Check server health status.""" + try: + response = httpx.get(self._url("/health"), headers=self.headers) + response.raise_for_status() + return response.json()["status"] == "healthy" + except Exception: + return False + + def list_experiments(self) -> ExperimentList: + """List all registered experiments.""" + response = httpx.get(self._url("/experiments"), headers=self.headers) + response.raise_for_status() + return ExperimentList.model_validate(response.json()) + + def create_experiment(self, name: str, config: ExperimentCreate) -> ExperimentResponse: + """Create a new experiment.""" + response = httpx.post( + self._url("/experiments"), + headers=self.headers, + json=config.model_dump(), + ) + response.raise_for_status() + return ExperimentResponse.model_validate(response.json()) + + def delete_experiment(self, name: str) -> None: + """Delete an experiment.""" + response = httpx.delete( + self._url(f"/experiments/{name}"), + headers=self.headers, + ) + 
response.raise_for_status() + + def get_experiment_config(self, name: str) -> ExperimentResponse: + """Get experiment configuration.""" + response = httpx.get( + self._url(f"/experiments/{name}/composite"), + headers=self.headers, + ) + response.raise_for_status() + return ExperimentResponse.model_validate(response.json()) + + def update_dataset_config(self, experiment: str, dataset: str, config: DatasetConfigUpdate) -> None: + """Update dataset configuration.""" + response = httpx.post( + self._url(f"/experiments/{experiment}/composite/{dataset}"), + headers=self.headers, + json=config.model_dump(), + ) + response.raise_for_status() + + def get_batch(self, experiment: str, base_index: int, batch_size: int) -> BatchResponse: + """Get a batch of entries from an experiment. + + Args: + experiment: Name of the experiment + base_index: Starting index for the batch + batch_size: Number of entries to retrieve + + Returns: + BatchResponse containing entries with questions and metadata + """ + response = httpx.get( + self._url(f"/experiments/{experiment}/batch"), + headers=self.headers, + params={"base_index": base_index, "batch_size": batch_size}, + ) + response.raise_for_status() + return BatchResponse.model_validate(response.json()) + + def score_outputs(self, experiment: str, entry_answers: List[AnswerItem]) -> ScoringResponse: + """Score a batch of answers. 
+ + Args: + experiment: Name of the experiment + entry_answers: List of AnswerItems with entry_ids and answers to score + + Returns: + ScoringResponse containing scores and entry_ids + """ + request = ScoringRequest(answers=entry_answers) + response = httpx.post( + self._url(f"/experiments/{experiment}/score"), + headers=self.headers, + json=request.model_dump(), + ) + response.raise_for_status() + return ScoringResponse.model_validate(response.json()) diff --git a/tools/cli/rgc/main.py b/tools/cli/rgc/main.py new file mode 100644 index 00000000..827c413a --- /dev/null +++ b/tools/cli/rgc/main.py @@ -0,0 +1,231 @@ +"""Main entry point for the Reasoning Gym CLI.""" + +import os +from typing import Optional + +import typer +import yaml +from rich.console import Console +from rich.prompt import Confirm, Prompt +from rich.syntax import Syntax +from rich.table import Table + +from tools.server.models import DatasetConfigUpdate, ExperimentCreate + +# Initialize Typer apps +app = typer.Typer( + name="rgc", + help="Reasoning Gym CLI - Manage and monitor reasoning gym experiments", + add_completion=True, +) +experiments_app = typer.Typer(help="Manage experiments") +config_app = typer.Typer(help="Manage configurations") + +app.add_typer(experiments_app, name="experiments") +app.add_typer(config_app, name="config") + + +@app.command("health") +def check_health(): + """Check server connection and health status.""" + try: + if client.check_health(): + console.print("[green]Server is healthy[/]") + else: + console.print("[red]Server is not responding correctly[/]") + raise typer.Exit(1) + except Exception as e: + console.print(f"[red]Error connecting to server: {e}[/]") + raise typer.Exit(1) + + +# Initialize client and console +from .client import RGClient + +client = RGClient() +console = Console() + + +@experiments_app.command("list") +def list_experiments(): + """List all registered experiments with their status.""" + table = Table(title="Registered Experiments") + 
table.add_column("Name", style="cyan") + table.add_column("Datasets", style="magenta") + table.add_column("Size", style="blue") + table.add_column("Seed", style="green") + + try: + experiments = client.list_experiments() + for exp_name in experiments.experiments: + try: + config = client.get_experiment_config(exp_name) + datasets = ", ".join(config.datasets.keys()) + table.add_row(exp_name, datasets, str(config.size), str(config.seed or "")) + except Exception as e: + console.print(f"[yellow]Warning: Could not get config for {exp_name}: {e}[/]") + table.add_row(exp_name, "?", "?", "?") + except Exception as e: + console.print(f"[red]Error listing experiments: {e}[/]") + raise typer.Exit(1) + + console.print(table) + + +@experiments_app.command("create") +def create_experiment( + name: str = typer.Argument(..., help="Name of the experiment"), + config_file: Optional[str] = typer.Option(None, "--file", "-f", help="YAML configuration file"), +): + """Create a new experiment.""" + if config_file: + try: + with open(config_file, "r") as f: + exp_config = yaml.safe_load(f) + config = ExperimentCreate(**exp_config) + response = client.create_experiment(name, config) + console.print(f"[green]Created experiment[/] [cyan]{response.name}[/]") + except Exception as e: + console.print(f"[red]Error creating experiment: {e}[/]") + raise typer.Exit(1) + else: + # Interactive creation + size = Prompt.ask("Dataset size", default="500") + seed = Prompt.ask("Random seed (optional)", default="") + + datasets = {} + while Confirm.ask("Add dataset?"): + ds_name = Prompt.ask("Dataset name") + weight = float(Prompt.ask("Weight", default="1.0")) + + # Get dataset-specific config + console.print("\nEnter dataset configuration:") + config = {} + while Confirm.ask("Add config parameter?"): + key = Prompt.ask("Parameter name") + value = Prompt.ask("Parameter value") + try: + # Try to convert to appropriate type + if value.isdigit(): + value = int(value) + elif value.lower() in ("true", 
"false"): + value = value.lower() == "true" + elif "." in value and value.replace(".", "").isdigit(): + value = float(value) + except ValueError: + pass + config[key] = value + + datasets[ds_name] = {"weight": weight, "config": config} + + # Create experiment config + exp_config = {"name": name, "size": int(size), "seed": int(seed) if seed else None, "datasets": datasets} + + # Show final config + console.print("\nFinal configuration:") + console.print(Syntax(yaml.dump(exp_config), "yaml")) + + if Confirm.ask("Create experiment with this configuration?"): + try: + config = ExperimentCreate(**exp_config) + response = client.create_experiment(name, config) + console.print(f"[green]Created experiment[/] [cyan]{response.name}[/]") + except Exception as e: + console.print(f"[red]Error creating experiment: {e}[/]") + raise typer.Exit(1) + else: + console.print("[yellow]Experiment creation cancelled[/]") + raise typer.Exit() + + +@experiments_app.command("delete") +def delete_experiment( + name: str = typer.Argument(..., help="Name of the experiment to delete"), + force: bool = typer.Option(False, "--force", "-f", help="Force deletion without confirmation"), +): + """Delete an experiment.""" + if not force and not Confirm.ask(f"Delete experiment [cyan]{name}[/]?"): + raise typer.Exit() + + try: + client.delete_experiment(name) + console.print(f"[green]Deleted experiment[/] [cyan]{name}[/]") + except Exception as e: + console.print(f"[red]Error deleting experiment: {e}[/]") + raise typer.Exit(1) + + +@experiments_app.command("show") +def show_experiment( + name: str = typer.Argument(..., help="Name of the experiment"), +): + """Show experiment details.""" + try: + config = client.get_experiment_config(name) + console.print(Syntax(yaml.dump(config.model_dump()), "yaml")) + except Exception as e: + console.print(f"[red]Error getting experiment config: {e}[/]") + raise typer.Exit(1) + + +@config_app.command("edit") +def edit_config( + experiment: str = typer.Argument(..., 
help="Name of the experiment"), + dataset: str = typer.Argument(..., help="Name of the dataset to edit"), +): + """Interactive configuration editor.""" + try: + exp_config = client.get_experiment_config(experiment) + if dataset not in exp_config.datasets: + console.print(f"[red]Dataset {dataset} not found in experiment[/]") + raise typer.Exit(1) + current_config = exp_config.datasets[dataset]["config"] + + console.print(f"\nCurrent configuration for [cyan]{dataset}[/]:") + console.print(Syntax(yaml.dump(current_config), "yaml")) + + # Interactive editing + new_config = {} + for key, value in current_config.items(): + new_value = Prompt.ask(f"{key}", default=str(value), show_default=True) + + # Try to convert to appropriate type + try: + if isinstance(value, bool): + new_value = new_value.lower() == "true" + elif isinstance(value, int): + new_value = int(new_value) + elif isinstance(value, float): + new_value = float(new_value) + except ValueError: + console.print(f"[yellow]Warning: Could not convert {new_value} to {type(value)}[/]") + + new_config[key] = new_value + + # Show changes + console.print("\nNew configuration:") + console.print(Syntax(yaml.dump(new_config), "yaml")) + + if Confirm.ask("Apply these changes?"): + try: + config_update = DatasetConfigUpdate(config=new_config) + client.update_dataset_config(experiment, dataset, config_update) + console.print("[green]Configuration updated successfully[/]") + except Exception as e: + console.print(f"[red]Error updating configuration: {e}[/]") + raise typer.Exit(1) + else: + console.print("[yellow]Update cancelled[/]") + + except Exception as e: + console.print(f"[red]Error getting experiment configuration: {e}[/]") + raise typer.Exit(1) + + +def main(): + """Entry point for the CLI.""" + app() + + +if __name__ == "__main__": + main() diff --git a/tools/server/__init__.py b/tools/server/__init__.py new file mode 100644 index 00000000..64926c4c --- /dev/null +++ b/tools/server/__init__.py @@ -0,0 +1,8 @@ +""" 
+Reasoning Gym Server - A FastAPI server for managing reasoning gym experiments. +""" + +from .config import ServerConfig +from .server import create_app + +__all__ = ["create_app", "ServerConfig"] diff --git a/tools/server/config.py b/tools/server/config.py new file mode 100644 index 00000000..5957b947 --- /dev/null +++ b/tools/server/config.py @@ -0,0 +1,17 @@ +"""Server configuration using Pydantic settings management.""" + +from pydantic import ConfigDict, Field +from pydantic_settings import BaseSettings + + +class ServerConfig(BaseSettings): + """Configuration settings for the Reasoning Gym server.""" + + host: str = Field(default="localhost", description="Server host address") + port: int = Field(default=8000, description="Server port") + api_key: str = Field( + default=..., description="API key for authentication", json_schema_extra={"env": "REASONING_GYM_API_KEY"} + ) + log_level: str = Field(default="INFO", description="Logging level") + + model_config = ConfigDict(env_prefix="REASONING_GYM_") diff --git a/tools/server/middleware.py b/tools/server/middleware.py new file mode 100644 index 00000000..24920cb6 --- /dev/null +++ b/tools/server/middleware.py @@ -0,0 +1,23 @@ +"""API key middleware for FastAPI.""" + +from fastapi import HTTPException, Request +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.status import HTTP_401_UNAUTHORIZED + + +class APIKeyMiddleware(BaseHTTPMiddleware): + """Middleware to check for valid API key in request headers.""" + + def __init__(self, app, api_key: str): + super().__init__(app) + self.api_key = api_key + + async def dispatch(self, request: Request, call_next): + if request.url.path == "/health": + return await call_next(request) + + api_key = request.headers.get("X-API-Key") + if not api_key or api_key != self.api_key: + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid or missing API key") + + return await call_next(request) diff --git a/tools/server/models.py 
b/tools/server/models.py new file mode 100644 index 00000000..6c873b08 --- /dev/null +++ b/tools/server/models.py @@ -0,0 +1,75 @@ +"""Pydantic models for API request/response data.""" + +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field + + +class ExperimentCreate(BaseModel): + """Request model for creating a new experiment.""" + + name: str = Field(..., description="Unique name for the experiment") + size: int = Field(500, description="Size of the dataset") + seed: Optional[int] = Field(None, description="Random seed for reproducibility") + datasets: Dict[str, Dict[str, Any]] = Field(..., description="Dictionary of datasets configurations") + + +class ExperimentResponse(BaseModel): + """Response model for experiment operations.""" + + name: str = Field(..., description="Name of the experiment") + size: int = Field(..., description="Size of the dataset") + seed: Optional[int] = Field(None, description="Random seed used") + datasets: Dict[str, Dict[str, Any]] = Field(..., description="Current dataset configurations") + + +class ExperimentList(BaseModel): + """Response model for listing experiments.""" + + experiments: List[str] = Field(default_factory=list, description="List of registered experiment names") + + +class DatasetConfigUpdate(BaseModel): + """Request model for updating dataset configuration.""" + + config: Dict[str, Any] = Field(..., description="Configuration parameters to update") + + +class ErrorResponse(BaseModel): + """Response model for error conditions.""" + + detail: str = Field(..., description="Error message") + + +class BatchEntry(BaseModel): + """Single entry in a batch""" + + question: str = Field(..., description="The question text") + entry_id: str = Field(..., description="Unique identifier in format '{version}.{index}'") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata about the entry") + + +class BatchResponse(BaseModel): + """Response containing a 
batch of entries""" + + entries: List[BatchEntry] = Field(..., description="List of batch entries") + + +class AnswerItem(BaseModel): + """Single score item containing entry_id and answer""" + + entry_id: str = Field(..., description="Entry identifier to score") + answer: str = Field(..., description="Answer to evaluate") + + +class ScoringRequest(BaseModel): + """Request for scoring model outputs""" + + answers: List[AnswerItem] = Field(..., description="List of entries to score") + + +class ScoringResponse(BaseModel): + """Response containing scores for answers""" + + scores: List[float] = Field(..., description="List of scores in same order as request") + entry_ids: List[str] = Field(..., description="List of entry_ids in same order as request") diff --git a/tools/server/server.py b/tools/server/server.py new file mode 100644 index 00000000..09ded0d9 --- /dev/null +++ b/tools/server/server.py @@ -0,0 +1,169 @@ +"""FastAPI server implementation for Reasoning Gym.""" + +import logging + +from fastapi import FastAPI, HTTPException + +from reasoning_gym.coaching.registry import ExperimentRegistry +from reasoning_gym.composite import CompositeConfig, DatasetSpec + +from .config import ServerConfig +from .middleware import APIKeyMiddleware +from .models import ( + BatchEntry, + BatchResponse, + DatasetConfigUpdate, + ExperimentCreate, + ExperimentList, + ExperimentResponse, + ScoringRequest, + ScoringResponse, +) + + +def create_app(config: ServerConfig) -> FastAPI: + """Create and configure the FastAPI application.""" + + # Configure logging + logging.basicConfig(level=config.log_level) + logger = logging.getLogger(__name__) + + # Create FastAPI app + app = FastAPI(title="Reasoning Gym Server") + + # Add middleware + app.add_middleware(APIKeyMiddleware, api_key=config.api_key) + + # Initialize registry + registry = ExperimentRegistry() + + @app.get("/health") + async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} + + 
@app.post("/experiments", response_model=ExperimentResponse) + async def create_experiment(experiment: ExperimentCreate): + """Create a new experiment.""" + # Convert dict format to DatasetSpec list + dataset_specs = [] + for name, spec in experiment.datasets.items(): + dataset_specs.append(DatasetSpec(name=name, weight=spec.get("weight", 1.0), config=spec.get("config", {}))) + + config = CompositeConfig(size=experiment.size, seed=experiment.seed, datasets=dataset_specs) + + try: + registry.register_experiment(experiment.name, config) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + return ExperimentResponse( + name=experiment.name, size=experiment.size, seed=experiment.seed, datasets=experiment.datasets + ) + + @app.get("/experiments", response_model=ExperimentList) + async def list_experiments(): + """List all registered experiments.""" + return ExperimentList(experiments=registry.list_experiments()) + + @app.delete("/experiments/{name}") + async def delete_experiment(name: str): + """Delete an experiment.""" + if not registry.remove_experiment(name): + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + return {"status": "deleted"} + + @app.get("/experiments/{name}/batch", response_model=BatchResponse) + async def generate_batch(name: str, base_index: int, batch_size: int): + """Generate a batch of raw entries""" + # Validate parameters + if base_index < 0: + raise HTTPException(status_code=400, detail="base_index must be non-negative") + if batch_size <= 0: + raise HTTPException(status_code=400, detail="batch_size must be positive") + + experiment = registry.get_experiment(name) + if not experiment: + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + + try: + entries = [] + for i in range(base_index, base_index + batch_size): + entry = experiment.dataset[i] + + # Create BatchEntry with minimal required data + batch_entry = BatchEntry( + question=entry["question"], + 
entry_id=f"{entry['metadata']['version_id']}.{i}", + metadata=entry["metadata"], + ) + entries.append(batch_entry) + + return BatchResponse(entries=entries) + + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + @app.post("/experiments/{name}/score", response_model=ScoringResponse) + async def score_outputs(name: str, request: ScoringRequest): + """Score extracted answers""" + experiment = registry.get_experiment(name) + if not experiment: + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + + try: + scores = [] + entry_ids = [] + for item in request.answers: + score = experiment.dataset.score_answer_with_id(item.answer, item.entry_id) + scores.append(score) + entry_ids.append(item.entry_id) + + return ScoringResponse(scores=scores, entry_ids=entry_ids) + + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + @app.get("/experiments/{name}/composite", response_model=ExperimentResponse) + async def get_composite_config(name: str): + """Get composite configuration for an experiment.""" + experiment = registry.get_experiment(name) + if not experiment: + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + + # Convert internal config to API response format + datasets = {} + for ds_spec in experiment.config.datasets: + dataset = experiment.dataset.datasets[ds_spec.name] + datasets[ds_spec.name] = { + "weight": ds_spec.weight, + "config": vars(dataset.config), # Get current config from dataset instance + } + + return ExperimentResponse( + name=name, size=experiment.config.size, seed=experiment.config.seed, datasets=datasets + ) + + @app.post("/experiments/{name}/composite/{dataset_name}") + async def update_dataset_config(name: str, dataset_name: str, config_update: DatasetConfigUpdate): + """Update configuration for a specific dataset in the composite.""" + experiment = registry.get_experiment(name) + if not experiment: + raise HTTPException(status_code=404, 
detail=f"Experiment '{name}' not found") + + try: + experiment.dataset.update_dataset_config(dataset_name, config_update.config) + return {"status": "updated"} + except KeyError: + raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment") + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + return app + + +async def app(scope, receive, send): + """ASGI application that lazily creates the FastAPI app.""" + if not hasattr(app, "server_app"): + app.server_app = create_app(ServerConfig()) + await app.server_app(scope, receive, send) diff --git a/tools/server/tests/__init__.py b/tools/server/tests/__init__.py new file mode 100644 index 00000000..f634e958 --- /dev/null +++ b/tools/server/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the Reasoning Gym server.""" diff --git a/tools/server/tests/test_config.py b/tools/server/tests/test_config.py new file mode 100644 index 00000000..2c847522 --- /dev/null +++ b/tools/server/tests/test_config.py @@ -0,0 +1,27 @@ +"""Tests for server configuration.""" + +import os + +import pytest + +from ..config import ServerConfig + + +def test_default_config(): + """Test default configuration values.""" + os.environ["REASONING_GYM_API_KEY"] = "test-key" + config = ServerConfig() + + assert config.host == "localhost" + assert config.port == 8000 + assert config.api_key == "test-key" + assert config.log_level == "INFO" + + +def test_missing_api_key(): + """Test that missing API key raises an error.""" + if "REASONING_GYM_API_KEY" in os.environ: + del os.environ["REASONING_GYM_API_KEY"] + + with pytest.raises(ValueError): + ServerConfig() diff --git a/tools/server/tests/test_endpoints.py b/tools/server/tests/test_endpoints.py new file mode 100644 index 00000000..69cfee65 --- /dev/null +++ b/tools/server/tests/test_endpoints.py @@ -0,0 +1,277 @@ +"""Tests for API endpoints.""" + +import pytest +from fastapi.testclient import TestClient + +from ..config import ServerConfig 
+from ..server import create_app + + +@pytest.fixture +def client(): + """Create a test client.""" + config = ServerConfig(host="localhost", port=8000, api_key="test-key", log_level="INFO") + app = create_app(config) + return TestClient(app) + + +def test_health_check(client): + """Test health check endpoint.""" + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +def test_experiment_endpoints(client): + """Test experiment management endpoints.""" + # Set API key + headers = {"X-API-Key": "test-key"} + + # Create experiment + create_data = { + "name": "test_exp", + "size": 10, + "seed": 42, + "datasets": { + "chain_sum": { + "weight": 1.0, + "config": { + "min_terms": 2, + "max_terms": 4, + "min_digits": 1, + "max_digits": 2, + "allow_negation": False, + "size": 10, + "seed": 42, + }, + } + }, + } + + response = client.post("/experiments", json=create_data, headers=headers) + assert response.status_code == 200 + assert response.json()["name"] == "test_exp" + + # List experiments + response = client.get("/experiments", headers=headers) + assert response.status_code == 200 + assert "test_exp" in response.json()["experiments"] + + # Delete experiment + response = client.delete("/experiments/test_exp", headers=headers) + assert response.status_code == 200 + + # Verify deletion + response = client.get("/experiments", headers=headers) + assert response.status_code == 200 + assert "test_exp" not in response.json()["experiments"] + + # Try to delete non-existent experiment + response = client.delete("/experiments/nonexistent", headers=headers) + assert response.status_code == 404 + + +def test_batch_generation_endpoint(client): + """Test batch generation endpoint.""" + headers = {"X-API-Key": "test-key"} + + # Create test experiment + create_data = { + "name": "test_exp", + "size": 10, + "seed": 42, + "datasets": { + "chain_sum": { + "weight": 1.0, + "config": { + "min_terms": 2, + "max_terms": 4, + 
"min_digits": 1, + "max_digits": 2, + "allow_negation": False, + "size": 10, + "seed": 42, + }, + } + }, + } + + response = client.post("/experiments", json=create_data, headers=headers) + assert response.status_code == 200 + + # Test batch generation + response = client.get( + "/experiments/test_exp/batch", + params={"base_index": 0, "batch_size": 2}, + headers=headers, + ) + assert response.status_code == 200 + data = response.json() + print(data) + + # Verify batch structure + assert "entries" in data + assert len(data["entries"]) == 2 + + # Verify entry structure + entry = data["entries"][0] + assert "question" in entry + assert "entry_id" in entry + assert "metadata" in entry + + # Test error cases + # Non-existent experiment + response = client.get( + "/experiments/nonexistent/batch", + params={"base_index": 0, "batch_size": 2}, + headers=headers, + ) + assert response.status_code == 404 + + # Invalid parameters + response = client.get( + "/experiments/test_exp/batch", + params={"base_index": -1, "batch_size": 2}, + headers=headers, + ) + assert response.status_code == 400 + + +def test_scoring_endpoint(client): + """Test answer scoring endpoint.""" + headers = {"X-API-Key": "test-key"} + + # Create test experiment + create_data = { + "name": "test_exp", + "size": 10, + "seed": 42, + "datasets": { + "chain_sum": { + "weight": 1.0, + "config": { + "min_terms": 2, + "max_terms": 4, + "min_digits": 1, + "max_digits": 2, + "allow_negation": False, + "size": 10, + "seed": 42, + }, + } + }, + } + + response = client.post("/experiments", json=create_data, headers=headers) + assert response.status_code == 200 + + # Get a batch to get valid entry_ids + response = client.get( + "/experiments/test_exp/batch", + params={"base_index": 0, "batch_size": 2}, + headers=headers, + ) + assert response.status_code == 200 + batch = response.json() + entry_id = batch["entries"][0]["entry_id"] + + # Test scoring with correct answer + response = client.post( + 
"/experiments/test_exp/score", + json={"answers": [{"entry_id": entry_id, "answer": "4"}]}, # Assuming 2+2=4 is the first question + headers=headers, + ) + assert response.status_code == 200 + result = response.json() + assert "scores" in result + assert "entry_ids" in result + assert len(result["scores"]) == 1 + assert len(result["entry_ids"]) == 1 + assert result["entry_ids"][0] == entry_id + assert isinstance(result["scores"][0], float) + assert 0 <= result["scores"][0] <= 1 + + # Test scoring with wrong answer + response = client.post( + "/experiments/test_exp/score", + json={"answers": [{"entry_id": entry_id, "answer": "wrong"}]}, + headers=headers, + ) + assert response.status_code == 200 + result = response.json() + assert result["scores"][0] < 1.0 + assert result["entry_ids"][0] == entry_id + + # Test error cases + # Invalid entry_id format + response = client.post( + "/experiments/test_exp/score", + json={"answers": [{"entry_id": "invalid_id", "answer": "4"}]}, + headers=headers, + ) + assert response.status_code == 400 + + # Non-existent experiment + response = client.post( + "/experiments/nonexistent/score", + json={"answers": [{"entry_id": entry_id, "answer": "4"}]}, + headers=headers, + ) + assert response.status_code == 404 + + +def test_composite_config_endpoints(client): + """Test composite configuration endpoints.""" + headers = {"X-API-Key": "test-key"} + + # Create an experiment first + create_data = { + "name": "test_exp", + "size": 10, + "seed": 42, + "datasets": { + "chain_sum": { + "weight": 1.0, + "config": { + "min_terms": 2, + "max_terms": 4, + "min_digits": 1, + "max_digits": 2, + "allow_negation": False, + "size": 10, + "seed": 42, + }, + } + }, + } + + response = client.post("/experiments", json=create_data, headers=headers) + assert response.status_code == 200 + + # Get composite config + response = client.get("/experiments/test_exp/composite", headers=headers) + assert response.status_code == 200 + config = response.json() + assert 
config["name"] == "test_exp" + assert "chain_sum" in config["datasets"] + + # Update dataset config + update_data = {"config": {"min_terms": 3, "max_terms": 5}} + response = client.post("/experiments/test_exp/composite/chain_sum", json=update_data, headers=headers) + assert response.status_code == 200 + + # Verify update + response = client.get("/experiments/test_exp/composite", headers=headers) + assert response.status_code == 200 + config = response.json() + assert config["datasets"]["chain_sum"]["config"]["min_terms"] == 3 + assert config["datasets"]["chain_sum"]["config"]["max_terms"] == 5 + + # Test error cases + # Non-existent experiment + response = client.get("/experiments/nonexistent/composite", headers=headers) + assert response.status_code == 404 + + # Non-existent dataset + response = client.post("/experiments/test_exp/composite/nonexistent", json=update_data, headers=headers) + assert response.status_code == 404 diff --git a/tools/server/tests/test_registry.py b/tools/server/tests/test_registry.py new file mode 100644 index 00000000..9e19df03 --- /dev/null +++ b/tools/server/tests/test_registry.py @@ -0,0 +1,44 @@ +"""Tests for experiment registry.""" + +import pytest + +from reasoning_gym.arithmetic.chain_sum import ChainSumConfig +from reasoning_gym.coaching.registry import ExperimentRegistry +from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec + + +def test_singleton(): + """Test that ExperimentRegistry is a singleton.""" + registry1 = ExperimentRegistry() + registry2 = ExperimentRegistry() + assert registry1 is registry2 + + +def test_experiment_management(): + """Test basic experiment management operations.""" + registry = ExperimentRegistry() + + # Clear any existing experiments + for name in registry.list_experiments(): + registry.remove_experiment(name) + + # Test registration with chain_sum dataset + chain_sum_spec = DatasetSpec(name="chain_sum", weight=1.0, config=vars(ChainSumConfig(size=10, seed=42))) + + 
config = CompositeConfig(size=10, seed=42, datasets=[chain_sum_spec]) + registry.register_experiment("test_exp", config) + + # Test listing + assert "test_exp" in registry.list_experiments() + + # Test retrieval + exp = registry.get_experiment("test_exp") + assert exp is not None + assert exp.name == "test_exp" + assert isinstance(exp.dataset, CompositeDataset) + assert exp.config == config + + # Test removal + assert registry.remove_experiment("test_exp") + assert "test_exp" not in registry.list_experiments() + assert not registry.remove_experiment("nonexistent")