mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
add difficulty where possible (#274)
This commit is contained in:
parent
fb06038e88
commit
b915565c0d
6 changed files with 46 additions and 14 deletions
|
|
@ -42,7 +42,12 @@ class CountBitsDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": QUESTION_TEMPLATE.format(number=number),
|
||||
"answer": str(answer),
|
||||
"metadata": {"number": number, "solution": answer, "binary": binary},
|
||||
"metadata": {
|
||||
"number": number,
|
||||
"solution": answer,
|
||||
"binary": binary,
|
||||
"difficulty": {"n": number},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -119,7 +119,11 @@ class MahjongPuzzleDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": QUESTION_TEMPLATE.format(cards=cards, operations=operations),
|
||||
"answer": answer,
|
||||
"metadata": {"rounds": rounds, "solution": answer},
|
||||
"metadata": {
|
||||
"rounds": rounds,
|
||||
"solution": answer,
|
||||
"difficulty": {"num_rounds": num_rounds},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -135,6 +135,10 @@ class NQueensDataset(ProceduralDataset):
|
|||
"solutions": valid_solutions,
|
||||
"num_removed": num_removed,
|
||||
"valid_answers": valid_solutions_str,
|
||||
"difficulty": {
|
||||
"n": self.config.n,
|
||||
"num_removed": num_removed,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -131,7 +131,13 @@ class CourseScheduleDataset(ProceduralDataset):
|
|||
prerequisites=str(prerequisites),
|
||||
),
|
||||
"answer": str(answer),
|
||||
"metadata": {"courses": courses, "prerequisites": prerequisites, "solution": answer, "solvable": solvable},
|
||||
"metadata": {
|
||||
"courses": courses,
|
||||
"prerequisites": prerequisites,
|
||||
"solution": answer,
|
||||
"solvable": solvable,
|
||||
"difficulty": {"num_courses": num_courses},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class LargestIslandDataset(ProceduralDataset):
|
|||
def _is_valid_cell(self, r: int, c: int, rows: int, cols: int) -> bool:
|
||||
return 0 <= r < rows and 0 <= c < cols
|
||||
|
||||
def _create_grid(self, rng: Random, rows: int, cols: int) -> list[list[int]]:
|
||||
def _create_grid(self, rng: Random, rows: int, cols: int, num_islands: int) -> list[list[int]]:
|
||||
"""Create a random grid of islands using a random walk algorithm"""
|
||||
grid = [[0] * cols for _ in range(rows)]
|
||||
directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right
|
||||
|
|
@ -78,7 +78,6 @@ class LargestIslandDataset(ProceduralDataset):
|
|||
r, c = new_r, new_c
|
||||
break
|
||||
|
||||
num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands)
|
||||
for _ in range(num_islands):
|
||||
create_island()
|
||||
|
||||
|
|
@ -130,7 +129,8 @@ class LargestIslandDataset(ProceduralDataset):
|
|||
|
||||
rows = rng.randint(self.config.min_rows, self.config.max_rows)
|
||||
cols = rng.randint(self.config.min_cols, self.config.max_cols)
|
||||
grid = self._create_grid(rng, rows, cols)
|
||||
num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands)
|
||||
grid = self._create_grid(rng, rows, cols, num_islands)
|
||||
grid_str = self._grid_to_string(grid)
|
||||
|
||||
answer = self._get_largest_island(grid)
|
||||
|
|
@ -138,7 +138,15 @@ class LargestIslandDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str),
|
||||
"answer": str(answer),
|
||||
"metadata": {"grid": grid, "solution": answer},
|
||||
"metadata": {
|
||||
"grid": grid,
|
||||
"solution": answer,
|
||||
"difficulty": {
|
||||
"rows": rows,
|
||||
"cols": cols,
|
||||
"num_islands": num_islands,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -57,12 +57,8 @@ class ShortestPathDataset(ProceduralDataset):
|
|||
def __init__(self, config: ShortestPathConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _get_grid(self, rng: Random) -> list[list[str]]:
|
||||
def _get_grid(self, rng: Random, rows: int, cols: int) -> list[list[str]]:
|
||||
"""Generate a random grid with open and blocked cells"""
|
||||
|
||||
rows, cols = rng.randint(self.config.min_rows, self.config.max_rows), rng.randint(
|
||||
self.config.min_cols, self.config.max_cols
|
||||
)
|
||||
grid = [["X" if rng.random() < self.config.p_blocked else "O" for _ in range(cols)] for _ in range(rows)]
|
||||
|
||||
start_r, start_c = rng.randint(0, rows - 1), rng.randint(0, cols - 1)
|
||||
|
|
@ -152,7 +148,9 @@ class ShortestPathDataset(ProceduralDataset):
|
|||
"""Generate a single Shortest Path question"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
matrix = self._get_grid(rng)
|
||||
rows = rng.randint(self.config.min_rows, self.config.max_rows)
|
||||
cols = rng.randint(self.config.min_cols, self.config.max_cols)
|
||||
matrix = self._get_grid(rng, rows, cols)
|
||||
matrix_str = self._matrix_to_str(matrix)
|
||||
answer = self._get_answer(matrix)
|
||||
answer_str = " ".join(answer) if answer else "infeasible"
|
||||
|
|
@ -160,7 +158,14 @@ class ShortestPathDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": QUESTION_TEMPLATE.format(grid=matrix_str),
|
||||
"answer": answer_str,
|
||||
"metadata": {"matrix": matrix, "solution": answer},
|
||||
"metadata": {
|
||||
"matrix": matrix,
|
||||
"solution": answer,
|
||||
"difficulty": {
|
||||
"rows": rows,
|
||||
"cols": cols,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue