Add `shuffle_example_order` configuration option for ArcAgiDataset

This commit is contained in:
Andreas Koepf 2025-02-16 12:49:11 +01:00
parent 66ca1c3ace
commit 52b1fd1cae
2 changed files with 35 additions and 2 deletions

View file

@@ -27,6 +27,7 @@ class ArcAgiConfig:
default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"]
) # empty list for no mirrors
use_color_permutation: bool = True
shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle
seed: Optional[int] = None
size: int = 500
@@ -87,8 +88,8 @@ def cmap(board: Board, colors: list[int]) -> Board:
return [[colors[c] for c in row] for row in board]
ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
# ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
# MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
class ArcAgiDataset(ProceduralDataset):
@@ -156,6 +157,9 @@ class ArcAgiDataset(ProceduralDataset):
for p in train:
augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])})
if self.config.shuffle_example_order:
rng.shuffle(augmented_train)
examples = [
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
for i, p in enumerate(augmented_train)