mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
add configuration option for ArcAgiDataset
This commit is contained in:
parent
66ca1c3ace
commit
52b1fd1cae
2 changed files with 35 additions and 2 deletions
|
|
@ -27,6 +27,7 @@ class ArcAgiConfig:
|
|||
default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"]
|
||||
) # empty list for no mirrors
|
||||
use_color_permutation: bool = True
|
||||
shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle
|
||||
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
|
@ -87,8 +88,8 @@ def cmap(board: Board, colors: list[int]) -> Board:
|
|||
return [[colors[c] for c in row] for row in board]
|
||||
|
||||
|
||||
ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
|
||||
MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
|
||||
# ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
|
||||
# MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
|
||||
|
||||
|
||||
class ArcAgiDataset(ProceduralDataset):
|
||||
|
|
@ -156,6 +157,9 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
for p in train:
|
||||
augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])})
|
||||
|
||||
if self.config.shuffle_example_order:
|
||||
rng.shuffle(augmented_train)
|
||||
|
||||
examples = [
|
||||
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
|
||||
for i, p in enumerate(augmented_train)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue