mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
add rotate, mirror & color-mapping augmentation functions
This commit is contained in:
parent
1209e9df72
commit
492570ff5c
1 changed files with 50 additions and 15 deletions
|
|
@ -27,6 +27,56 @@ class ArcAgiConfig:
|
|||
assert self.size > 0, "Size of dataset must be positive."
|
||||
|
||||
|
||||
Board = list[list[int]]
|
||||
|
||||
|
||||
def identity(board: Board) -> Board:
|
||||
return board
|
||||
|
||||
|
||||
def rot90(board: Board) -> Board:
|
||||
"""quarter clockwise rotation"""
|
||||
return [row for row in zip(*board[::-1])]
|
||||
|
||||
|
||||
def rot180(board: Board) -> Board:
|
||||
"""half rotation"""
|
||||
return [row[::-1] for row in board[::-1]]
|
||||
|
||||
|
||||
def rot270(board: Board) -> Board:
|
||||
"""quarter anticlockwise rotation"""
|
||||
return [row[::-1] for row in zip(*board[::-1])][::-1]
|
||||
|
||||
|
||||
def hmirror(board: Board) -> Board:
|
||||
"""mirroring along horizontal"""
|
||||
return board[::-1]
|
||||
|
||||
|
||||
def vmirror(board: Board) -> Board:
|
||||
"""mirroring along vertical"""
|
||||
return [row[::-1] for row in board]
|
||||
|
||||
|
||||
def dmirror(board: Board) -> Board:
|
||||
"""mirroring along diagonal"""
|
||||
return list(zip(*board))
|
||||
|
||||
|
||||
def cmirror(board: Board) -> Board:
|
||||
"""mirroring along counterdiagonal"""
|
||||
return list(zip(*[r[::-1] for r in board[::-1]]))
|
||||
|
||||
|
||||
def cmap(board: Board, colors: list[int]) -> Board:
|
||||
return [[colors[c] for c in row] for row in board]
|
||||
|
||||
|
||||
ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
|
||||
MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
|
||||
|
||||
|
||||
class ArcAgiDataset(ProceduralDataset):
|
||||
def __init__(self, config: ArcAgiConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
|
@ -93,18 +143,3 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
|
||||
|
||||
register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = ArcAgiConfig(seed=99)
|
||||
test = ArcAgiDataset(cfg)
|
||||
|
||||
x = test[1]
|
||||
|
||||
a = """1 6 7
|
||||
6 7 6
|
||||
2 2 6"""
|
||||
|
||||
print("q:", x["question"])
|
||||
print("a:", x["answer"])
|
||||
print("score:", test.score_answer(answer=a, entry=x))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue