mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-02 17:45:58 +00:00
add difficulty where possible (#274)
This commit is contained in:
parent
8790c6be00
commit
a48ff14507
6 changed files with 46 additions and 14 deletions
|
|
@ -42,7 +42,12 @@ class CountBitsDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": QUESTION_TEMPLATE.format(number=number),
|
"question": QUESTION_TEMPLATE.format(number=number),
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
"metadata": {"number": number, "solution": answer, "binary": binary},
|
"metadata": {
|
||||||
|
"number": number,
|
||||||
|
"solution": answer,
|
||||||
|
"binary": binary,
|
||||||
|
"difficulty": {"n": number},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,11 @@ class MahjongPuzzleDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": QUESTION_TEMPLATE.format(cards=cards, operations=operations),
|
"question": QUESTION_TEMPLATE.format(cards=cards, operations=operations),
|
||||||
"answer": answer,
|
"answer": answer,
|
||||||
"metadata": {"rounds": rounds, "solution": answer},
|
"metadata": {
|
||||||
|
"rounds": rounds,
|
||||||
|
"solution": answer,
|
||||||
|
"difficulty": {"num_rounds": num_rounds},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,10 @@ class NQueensDataset(ProceduralDataset):
|
||||||
"solutions": valid_solutions,
|
"solutions": valid_solutions,
|
||||||
"num_removed": num_removed,
|
"num_removed": num_removed,
|
||||||
"valid_answers": valid_solutions_str,
|
"valid_answers": valid_solutions_str,
|
||||||
|
"difficulty": {
|
||||||
|
"n": self.config.n,
|
||||||
|
"num_removed": num_removed,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -131,7 +131,13 @@ class CourseScheduleDataset(ProceduralDataset):
|
||||||
prerequisites=str(prerequisites),
|
prerequisites=str(prerequisites),
|
||||||
),
|
),
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
"metadata": {"courses": courses, "prerequisites": prerequisites, "solution": answer, "solvable": solvable},
|
"metadata": {
|
||||||
|
"courses": courses,
|
||||||
|
"prerequisites": prerequisites,
|
||||||
|
"solution": answer,
|
||||||
|
"solvable": solvable,
|
||||||
|
"difficulty": {"num_courses": num_courses},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ class LargestIslandDataset(ProceduralDataset):
|
||||||
def _is_valid_cell(self, r: int, c: int, rows: int, cols: int) -> bool:
|
def _is_valid_cell(self, r: int, c: int, rows: int, cols: int) -> bool:
|
||||||
return 0 <= r < rows and 0 <= c < cols
|
return 0 <= r < rows and 0 <= c < cols
|
||||||
|
|
||||||
def _create_grid(self, rng: Random, rows: int, cols: int) -> list[list[int]]:
|
def _create_grid(self, rng: Random, rows: int, cols: int, num_islands: int) -> list[list[int]]:
|
||||||
"""Create a random grid of islands using a random walk algorithm"""
|
"""Create a random grid of islands using a random walk algorithm"""
|
||||||
grid = [[0] * cols for _ in range(rows)]
|
grid = [[0] * cols for _ in range(rows)]
|
||||||
directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right
|
directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right
|
||||||
|
|
@ -78,7 +78,6 @@ class LargestIslandDataset(ProceduralDataset):
|
||||||
r, c = new_r, new_c
|
r, c = new_r, new_c
|
||||||
break
|
break
|
||||||
|
|
||||||
num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands)
|
|
||||||
for _ in range(num_islands):
|
for _ in range(num_islands):
|
||||||
create_island()
|
create_island()
|
||||||
|
|
||||||
|
|
@ -130,7 +129,8 @@ class LargestIslandDataset(ProceduralDataset):
|
||||||
|
|
||||||
rows = rng.randint(self.config.min_rows, self.config.max_rows)
|
rows = rng.randint(self.config.min_rows, self.config.max_rows)
|
||||||
cols = rng.randint(self.config.min_cols, self.config.max_cols)
|
cols = rng.randint(self.config.min_cols, self.config.max_cols)
|
||||||
grid = self._create_grid(rng, rows, cols)
|
num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands)
|
||||||
|
grid = self._create_grid(rng, rows, cols, num_islands)
|
||||||
grid_str = self._grid_to_string(grid)
|
grid_str = self._grid_to_string(grid)
|
||||||
|
|
||||||
answer = self._get_largest_island(grid)
|
answer = self._get_largest_island(grid)
|
||||||
|
|
@ -138,7 +138,15 @@ class LargestIslandDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str),
|
"question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str),
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
"metadata": {"grid": grid, "solution": answer},
|
"metadata": {
|
||||||
|
"grid": grid,
|
||||||
|
"solution": answer,
|
||||||
|
"difficulty": {
|
||||||
|
"rows": rows,
|
||||||
|
"cols": cols,
|
||||||
|
"num_islands": num_islands,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,12 +57,8 @@ class ShortestPathDataset(ProceduralDataset):
|
||||||
def __init__(self, config: ShortestPathConfig):
|
def __init__(self, config: ShortestPathConfig):
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
def _get_grid(self, rng: Random) -> list[list[str]]:
|
def _get_grid(self, rng: Random, rows: int, cols: int) -> list[list[str]]:
|
||||||
"""Generate a random grid with open and blocked cells"""
|
"""Generate a random grid with open and blocked cells"""
|
||||||
|
|
||||||
rows, cols = rng.randint(self.config.min_rows, self.config.max_rows), rng.randint(
|
|
||||||
self.config.min_cols, self.config.max_cols
|
|
||||||
)
|
|
||||||
grid = [["X" if rng.random() < self.config.p_blocked else "O" for _ in range(cols)] for _ in range(rows)]
|
grid = [["X" if rng.random() < self.config.p_blocked else "O" for _ in range(cols)] for _ in range(rows)]
|
||||||
|
|
||||||
start_r, start_c = rng.randint(0, rows - 1), rng.randint(0, cols - 1)
|
start_r, start_c = rng.randint(0, rows - 1), rng.randint(0, cols - 1)
|
||||||
|
|
@ -152,7 +148,9 @@ class ShortestPathDataset(ProceduralDataset):
|
||||||
"""Generate a single Shortest Path question"""
|
"""Generate a single Shortest Path question"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
matrix = self._get_grid(rng)
|
rows = rng.randint(self.config.min_rows, self.config.max_rows)
|
||||||
|
cols = rng.randint(self.config.min_cols, self.config.max_cols)
|
||||||
|
matrix = self._get_grid(rng, rows, cols)
|
||||||
matrix_str = self._matrix_to_str(matrix)
|
matrix_str = self._matrix_to_str(matrix)
|
||||||
answer = self._get_answer(matrix)
|
answer = self._get_answer(matrix)
|
||||||
answer_str = " ".join(answer) if answer else "infeasible"
|
answer_str = " ".join(answer) if answer else "infeasible"
|
||||||
|
|
@ -160,7 +158,14 @@ class ShortestPathDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": QUESTION_TEMPLATE.format(grid=matrix_str),
|
"question": QUESTION_TEMPLATE.format(grid=matrix_str),
|
||||||
"answer": answer_str,
|
"answer": answer_str,
|
||||||
"metadata": {"matrix": matrix, "solution": answer},
|
"metadata": {
|
||||||
|
"matrix": matrix,
|
||||||
|
"solution": answer,
|
||||||
|
"difficulty": {
|
||||||
|
"rows": rows,
|
||||||
|
"cols": cols,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue