add difficulty where possible (#274)

This commit is contained in:
Zafir Stojanovski 2025-03-07 19:01:26 +01:00 committed by GitHub
parent fb06038e88
commit b915565c0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 46 additions and 14 deletions

View file

@@ -42,7 +42,12 @@ class CountBitsDataset(ProceduralDataset):
return {
"question": QUESTION_TEMPLATE.format(number=number),
"answer": str(answer),
"metadata": {"number": number, "solution": answer, "binary": binary},
"metadata": {
"number": number,
"solution": answer,
"binary": binary,
"difficulty": {"n": number},
},
}

View file

@@ -119,7 +119,11 @@ class MahjongPuzzleDataset(ProceduralDataset):
return {
"question": QUESTION_TEMPLATE.format(cards=cards, operations=operations),
"answer": answer,
"metadata": {"rounds": rounds, "solution": answer},
"metadata": {
"rounds": rounds,
"solution": answer,
"difficulty": {"num_rounds": num_rounds},
},
}

View file

@@ -135,6 +135,10 @@ class NQueensDataset(ProceduralDataset):
"solutions": valid_solutions,
"num_removed": num_removed,
"valid_answers": valid_solutions_str,
"difficulty": {
"n": self.config.n,
"num_removed": num_removed,
},
},
}

View file

@@ -131,7 +131,13 @@ class CourseScheduleDataset(ProceduralDataset):
prerequisites=str(prerequisites),
),
"answer": str(answer),
"metadata": {"courses": courses, "prerequisites": prerequisites, "solution": answer, "solvable": solvable},
"metadata": {
"courses": courses,
"prerequisites": prerequisites,
"solution": answer,
"solvable": solvable,
"difficulty": {"num_courses": num_courses},
},
}

View file

@@ -61,7 +61,7 @@ class LargestIslandDataset(ProceduralDataset):
def _is_valid_cell(self, r: int, c: int, rows: int, cols: int) -> bool:
return 0 <= r < rows and 0 <= c < cols
def _create_grid(self, rng: Random, rows: int, cols: int) -> list[list[int]]:
def _create_grid(self, rng: Random, rows: int, cols: int, num_islands: int) -> list[list[int]]:
"""Create a random grid of islands using a random walk algorithm"""
grid = [[0] * cols for _ in range(rows)]
directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right
@@ -78,7 +78,6 @@ class LargestIslandDataset(ProceduralDataset):
r, c = new_r, new_c
break
num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands)
for _ in range(num_islands):
create_island()
@@ -130,7 +129,8 @@ class LargestIslandDataset(ProceduralDataset):
rows = rng.randint(self.config.min_rows, self.config.max_rows)
cols = rng.randint(self.config.min_cols, self.config.max_cols)
grid = self._create_grid(rng, rows, cols)
num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands)
grid = self._create_grid(rng, rows, cols, num_islands)
grid_str = self._grid_to_string(grid)
answer = self._get_largest_island(grid)
@@ -138,7 +138,15 @@ class LargestIslandDataset(ProceduralDataset):
return {
"question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str),
"answer": str(answer),
"metadata": {"grid": grid, "solution": answer},
"metadata": {
"grid": grid,
"solution": answer,
"difficulty": {
"rows": rows,
"cols": cols,
"num_islands": num_islands,
},
},
}

View file

@@ -57,12 +57,8 @@ class ShortestPathDataset(ProceduralDataset):
def __init__(self, config: ShortestPathConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _get_grid(self, rng: Random) -> list[list[str]]:
def _get_grid(self, rng: Random, rows: int, cols: int) -> list[list[str]]:
"""Generate a random grid with open and blocked cells"""
rows, cols = rng.randint(self.config.min_rows, self.config.max_rows), rng.randint(
self.config.min_cols, self.config.max_cols
)
grid = [["X" if rng.random() < self.config.p_blocked else "O" for _ in range(cols)] for _ in range(rows)]
start_r, start_c = rng.randint(0, rows - 1), rng.randint(0, cols - 1)
@@ -152,7 +148,9 @@ class ShortestPathDataset(ProceduralDataset):
"""Generate a single Shortest Path question"""
rng = Random(self.seed + idx)
matrix = self._get_grid(rng)
rows = rng.randint(self.config.min_rows, self.config.max_rows)
cols = rng.randint(self.config.min_cols, self.config.max_cols)
matrix = self._get_grid(rng, rows, cols)
matrix_str = self._matrix_to_str(matrix)
answer = self._get_answer(matrix)
answer_str = " ".join(answer) if answer else "infeasible"
@@ -160,7 +158,14 @@ class ShortestPathDataset(ProceduralDataset):
return {
"question": QUESTION_TEMPLATE.format(grid=matrix_str),
"answer": answer_str,
"metadata": {"matrix": matrix, "solution": answer},
"metadata": {
"matrix": matrix,
"solution": answer,
"difficulty": {
"rows": rows,
"cols": cols,
},
},
}