feat: Add configurable rotation and mirror augmentation variants

This commit is contained in:
Andreas Koepf (aider) 2025-02-09 00:14:43 +01:00 committed by Andreas Koepf
parent b73040b066
commit ec8036c099
3 changed files with 486 additions and 19 deletions

View file

@ -22,15 +22,23 @@ class ArcAgiConfig:
board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
# Augmentation options
use_rotations: bool = True
use_mirrors: bool = True
rotations: list[str] = field(default_factory=lambda: ["90", "180", "270"]) # empty list for no rotations
mirrors: list[str] = field(
default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"]
) # empty list for no mirrors
use_color_permutation: bool = True
seed: Optional[int] = None
size: int = 500
def validate(self):
assert self.size > 0, "Size of dataset must be positive."
valid_rotations = ["90", "180", "270"]
valid_mirrors = ["horizontal", "vertical", "diagonal", "counterdiagonal"]
for rot in self.rotations:
assert rot in valid_rotations, f"Invalid rotation option: {rot}"
for mirror in self.mirrors:
assert mirror in valid_mirrors, f"Invalid mirror option: {mirror}"
Board = list[list[int]]
@ -103,11 +111,17 @@ class ArcAgiDataset(ProceduralDataset):
"""Create a composite augmentation function from enabled options"""
fns = []
if self.config.use_rotations:
fns.append(rng.choice(ROTATION_AUGMENTATIONS))
# Map rotation strings to functions
rotation_map = {"90": rot90, "180": rot180, "270": rot270}
if self.config.rotations:
chosen_rot = rng.choice([identity] + [rotation_map[r] for r in self.config.rotations])
fns.append(chosen_rot)
if self.config.use_mirrors:
fns.append(rng.choice(MIRROR_AUGMENTATIONS))
# Map mirror strings to functions
mirror_map = {"horizontal": hmirror, "vertical": vmirror, "diagonal": dmirror, "counterdiagonal": cmirror}
if self.config.mirrors:
chosen_mirror = rng.choice([identity] + [mirror_map[m] for m in self.config.mirrors])
fns.append(chosen_mirror)
if self.config.use_color_permutation:
color_table = list(range(10))