mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
feat: Add configurable rotation and mirror augmentation variants
This commit is contained in:
parent
b73040b066
commit
ec8036c099
3 changed files with 486 additions and 19 deletions
|
|
@ -22,15 +22,23 @@ class ArcAgiConfig:
|
|||
board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
|
||||
|
||||
# Augmentation options
|
||||
use_rotations: bool = True
|
||||
use_mirrors: bool = True
|
||||
rotations: list[str] = field(default_factory=lambda: ["90", "180", "270"]) # empty list for no rotations
|
||||
mirrors: list[str] = field(
|
||||
default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"]
|
||||
) # empty list for no mirrors
|
||||
use_color_permutation: bool = True
|
||||
|
||||
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
assert self.size > 0, "Size of dataset must be positive."
|
||||
valid_rotations = ["90", "180", "270"]
|
||||
valid_mirrors = ["horizontal", "vertical", "diagonal", "counterdiagonal"]
|
||||
for rot in self.rotations:
|
||||
assert rot in valid_rotations, f"Invalid rotation option: {rot}"
|
||||
for mirror in self.mirrors:
|
||||
assert mirror in valid_mirrors, f"Invalid mirror option: {mirror}"
|
||||
|
||||
|
||||
Board = list[list[int]]
|
||||
|
|
@ -103,11 +111,17 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
"""Create a composite augmentation function from enabled options"""
|
||||
fns = []
|
||||
|
||||
if self.config.use_rotations:
|
||||
fns.append(rng.choice(ROTATION_AUGMENTATIONS))
|
||||
# Map rotation strings to functions
|
||||
rotation_map = {"90": rot90, "180": rot180, "270": rot270}
|
||||
if self.config.rotations:
|
||||
chosen_rot = rng.choice([identity] + [rotation_map[r] for r in self.config.rotations])
|
||||
fns.append(chosen_rot)
|
||||
|
||||
if self.config.use_mirrors:
|
||||
fns.append(rng.choice(MIRROR_AUGMENTATIONS))
|
||||
# Map mirror strings to functions
|
||||
mirror_map = {"horizontal": hmirror, "vertical": vmirror, "diagonal": dmirror, "counterdiagonal": cmirror}
|
||||
if self.config.mirrors:
|
||||
chosen_mirror = rng.choice([identity] + [mirror_map[m] for m in self.config.mirrors])
|
||||
fns.append(chosen_mirror)
|
||||
|
||||
if self.config.use_color_permutation:
|
||||
color_table = list(range(10))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue