diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py index f9578242..1e3d41da 100644 --- a/reasoning_gym/arc/arc_agi.py +++ b/reasoning_gym/arc/arc_agi.py @@ -27,6 +27,56 @@ class ArcAgiConfig: assert self.size > 0, "Size of dataset must be positive." +Board = list[list[int]] + + +def identity(board: Board) -> Board: + return board + + +def rot90(board: Board) -> Board: + """quarter clockwise rotation""" + return [row for row in zip(*board[::-1])] + + +def rot180(board: Board) -> Board: + """half rotation""" + return [row[::-1] for row in board[::-1]] + + +def rot270(board: Board) -> Board: + """quarter anticlockwise rotation""" + return [row[::-1] for row in zip(*board[::-1])][::-1] + + +def hmirror(board: Board) -> Board: + """mirroring along horizontal""" + return board[::-1] + + +def vmirror(board: Board) -> Board: + """mirroring along vertical""" + return [row[::-1] for row in board] + + +def dmirror(board: Board) -> Board: + """mirroring along diagonal""" + return list(zip(*board)) + + +def cmirror(board: Board) -> Board: + """mirroring along counterdiagonal""" + return list(zip(*[r[::-1] for r in board[::-1]])) + + +def cmap(board: Board, colors: list[int]) -> Board: + return [[colors[c] for c in row] for row in board] + + +ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270] +MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror] + + class ArcAgiDataset(ProceduralDataset): def __init__(self, config: ArcAgiConfig): super().__init__(config=config, seed=config.seed, size=config.size) @@ -93,18 +143,3 @@ class ArcAgiDataset(ProceduralDataset): register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig) - - -if __name__ == "__main__": - cfg = ArcAgiConfig(seed=99) - test = ArcAgiDataset(cfg) - - x = test[1] - - a = """1 6 7 -6 7 6 -2 2 6""" - - print("q:", x["question"]) - print("a:", x["answer"]) - print("score:", test.score_answer(answer=a, entry=x))