diff --git a/reasoning_gym/arithmetic/count_bits.py b/reasoning_gym/arithmetic/count_bits.py index cd55bdc7..9a386af6 100644 --- a/reasoning_gym/arithmetic/count_bits.py +++ b/reasoning_gym/arithmetic/count_bits.py @@ -42,7 +42,12 @@ class CountBitsDataset(ProceduralDataset): return { "question": QUESTION_TEMPLATE.format(number=number), "answer": str(answer), - "metadata": {"number": number, "solution": answer, "binary": binary}, + "metadata": { + "number": number, + "solution": answer, + "binary": binary, + "difficulty": {"n": number}, + }, } diff --git a/reasoning_gym/games/mahjong.py b/reasoning_gym/games/mahjong.py index cbd653db..79576e36 100644 --- a/reasoning_gym/games/mahjong.py +++ b/reasoning_gym/games/mahjong.py @@ -119,7 +119,11 @@ class MahjongPuzzleDataset(ProceduralDataset): return { "question": QUESTION_TEMPLATE.format(cards=cards, operations=operations), "answer": answer, - "metadata": {"rounds": rounds, "solution": answer}, + "metadata": { + "rounds": rounds, + "solution": answer, + "difficulty": {"num_rounds": num_rounds}, + }, } diff --git a/reasoning_gym/games/n_queens.py b/reasoning_gym/games/n_queens.py index fe43583e..a918b136 100644 --- a/reasoning_gym/games/n_queens.py +++ b/reasoning_gym/games/n_queens.py @@ -135,6 +135,10 @@ class NQueensDataset(ProceduralDataset): "solutions": valid_solutions, "num_removed": num_removed, "valid_answers": valid_solutions_str, + "difficulty": { + "n": self.config.n, + "num_removed": num_removed, + }, }, } diff --git a/reasoning_gym/graphs/course_schedule.py b/reasoning_gym/graphs/course_schedule.py index cf25a786..4a555e32 100644 --- a/reasoning_gym/graphs/course_schedule.py +++ b/reasoning_gym/graphs/course_schedule.py @@ -131,7 +131,13 @@ class CourseScheduleDataset(ProceduralDataset): prerequisites=str(prerequisites), ), "answer": str(answer), - "metadata": {"courses": courses, "prerequisites": prerequisites, "solution": answer, "solvable": solvable}, + "metadata": { + "courses": courses, + "prerequisites": prerequisites, + "solution": answer, + "solvable": solvable, + "difficulty": {"num_courses": num_courses}, + }, } diff --git a/reasoning_gym/graphs/largest_island.py b/reasoning_gym/graphs/largest_island.py index 17826a67..7090497d 100644 --- a/reasoning_gym/graphs/largest_island.py +++ b/reasoning_gym/graphs/largest_island.py @@ -61,7 +61,7 @@ class LargestIslandDataset(ProceduralDataset): def _is_valid_cell(self, r: int, c: int, rows: int, cols: int) -> bool: return 0 <= r < rows and 0 <= c < cols - def _create_grid(self, rng: Random, rows: int, cols: int) -> list[list[int]]: + def _create_grid(self, rng: Random, rows: int, cols: int, num_islands: int) -> list[list[int]]: """Create a random grid of islands using a random walk algorithm""" grid = [[0] * cols for _ in range(rows)] directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right @@ -78,7 +78,6 @@ class LargestIslandDataset(ProceduralDataset): r, c = new_r, new_c break - num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands) for _ in range(num_islands): create_island() @@ -130,7 +129,8 @@ class LargestIslandDataset(ProceduralDataset): rows = rng.randint(self.config.min_rows, self.config.max_rows) cols = rng.randint(self.config.min_cols, self.config.max_cols) - grid = self._create_grid(rng, rows, cols) + num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands) + grid = self._create_grid(rng, rows, cols, num_islands) grid_str = self._grid_to_string(grid) answer = self._get_largest_island(grid) @@ -138,7 +138,15 @@ class LargestIslandDataset(ProceduralDataset): return { "question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str), "answer": str(answer), - "metadata": {"grid": grid, "solution": answer}, + "metadata": { + "grid": grid, + "solution": answer, + "difficulty": { + "rows": rows, + "cols": cols, + "num_islands": num_islands, + }, + }, } diff --git a/reasoning_gym/graphs/shortest_path.py b/reasoning_gym/graphs/shortest_path.py index bcf40a2b..d083d1a7 100644 --- a/reasoning_gym/graphs/shortest_path.py +++ b/reasoning_gym/graphs/shortest_path.py @@ -57,12 +57,8 @@ class ShortestPathDataset(ProceduralDataset): def __init__(self, config: ShortestPathConfig): super().__init__(config=config, seed=config.seed, size=config.size) - def _get_grid(self, rng: Random) -> list[list[str]]: + def _get_grid(self, rng: Random, rows: int, cols: int) -> list[list[str]]: """Generate a random grid with open and blocked cells""" - - rows, cols = rng.randint(self.config.min_rows, self.config.max_rows), rng.randint( - self.config.min_cols, self.config.max_cols - ) grid = [["X" if rng.random() < self.config.p_blocked else "O" for _ in range(cols)] for _ in range(rows)] start_r, start_c = rng.randint(0, rows - 1), rng.randint(0, cols - 1) @@ -152,7 +148,9 @@ class ShortestPathDataset(ProceduralDataset): """Generate a single Shortest Path question""" rng = Random(self.seed + idx) - matrix = self._get_grid(rng) + rows = rng.randint(self.config.min_rows, self.config.max_rows) + cols = rng.randint(self.config.min_cols, self.config.max_cols) + matrix = self._get_grid(rng, rows, cols) matrix_str = self._matrix_to_str(matrix) answer = self._get_answer(matrix) answer_str = " ".join(answer) if answer else "infeasible" @@ -160,7 +158,14 @@ class ShortestPathDataset(ProceduralDataset): return { "question": QUESTION_TEMPLATE.format(grid=matrix_str), "answer": answer_str, - "metadata": {"matrix": matrix, "solution": answer}, + "metadata": { + "matrix": matrix, + "solution": answer, + "difficulty": { + "rows": rows, + "cols": cols, + }, + }, }