add rotate, mirror & color-mapping augmentation functions

This commit is contained in:
Andreas Koepf 2025-02-08 23:51:38 +01:00
parent 1209e9df72
commit 492570ff5c

View file

@ -27,6 +27,56 @@ class ArcAgiConfig:
assert self.size > 0, "Size of dataset must be positive."
Board = list[list[int]]
def identity(board: Board) -> Board:
return board
def rot90(board: Board) -> Board:
"""quarter clockwise rotation"""
return [row for row in zip(*board[::-1])]
def rot180(board: Board) -> Board:
"""half rotation"""
return [row[::-1] for row in board[::-1]]
def rot270(board: Board) -> Board:
"""quarter anticlockwise rotation"""
return [row[::-1] for row in zip(*board[::-1])][::-1]
def hmirror(board: Board) -> Board:
"""mirroring along horizontal"""
return board[::-1]
def vmirror(board: Board) -> Board:
"""mirroring along vertical"""
return [row[::-1] for row in board]
def dmirror(board: Board) -> Board:
"""mirroring along diagonal"""
return list(zip(*board))
def cmirror(board: Board) -> Board:
"""mirroring along counterdiagonal"""
return list(zip(*[r[::-1] for r in board[::-1]]))
def cmap(board: Board, colors: list[int]) -> Board:
return [[colors[c] for c in row] for row in board]
ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
class ArcAgiDataset(ProceduralDataset):
def __init__(self, config: ArcAgiConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
@ -93,18 +143,3 @@ class ArcAgiDataset(ProceduralDataset):
register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig)
if __name__ == "__main__":
cfg = ArcAgiConfig(seed=99)
test = ArcAgiDataset(cfg)
x = test[1]
a = """1 6 7
6 7 6
2 2 6"""
print("q:", x["question"])
print("a:", x["answer"])
print("score:", test.score_answer(answer=a, entry=x))