formatting

This commit is contained in:
Andreas Koepf 2025-02-09 00:04:42 +01:00
parent 8d8d85e6b2
commit e56316ebb2
2 changed files with 13 additions and 35 deletions

View file

@@ -99,13 +99,13 @@ class ArcAgiDataset(ProceduralDataset):
def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]: def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]:
"""Create a composite augmentation function from enabled options""" """Create a composite augmentation function from enabled options"""
fns = [] fns = []
if self.config.use_rotations: if self.config.use_rotations:
fns.append(rng.choice(ROTATION_AUGMENTATIONS)) fns.append(rng.choice(ROTATION_AUGMENTATIONS))
if self.config.use_mirrors: if self.config.use_mirrors:
fns.append(rng.choice(MIRROR_AUGMENTATIONS)) fns.append(rng.choice(MIRROR_AUGMENTATIONS))
if self.config.use_color_permutation: if self.config.use_color_permutation:
color_table = list(range(10)) color_table = list(range(10))
rng.shuffle(color_table) rng.shuffle(color_table)
@@ -116,7 +116,7 @@ class ArcAgiDataset(ProceduralDataset):
for fn in fns: for fn in fns:
result = fn(result) result = fn(result)
return result return result
return composite_fn return composite_fn
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
@@ -137,16 +137,14 @@ class ArcAgiDataset(ProceduralDataset):
# Apply augmentation to all train examples # Apply augmentation to all train examples
augmented_train = [] augmented_train = []
for p in train: for p in train:
augmented_train.append({ augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])})
"input": augment(p["input"]),
"output": augment(p["output"])
})
examples = [ examples = [
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
for i, p in enumerate(augmented_train) for i, p in enumerate(augmented_train)
] ]
examples = "".join(examples) examples = "".join(examples)
# Apply augmentation to test example # Apply augmentation to test example
augmented_test_input = augment(test["input"]) augmented_test_input = augment(test["input"])
augmented_test_output = augment(test["output"]) augmented_test_output = augment(test["output"])

View file

@@ -52,48 +52,28 @@ def test_arc_agi_items():
def test_arc_agi_augmentations(): def test_arc_agi_augmentations():
"""Test that augmentations can be selectively enabled/disabled""" """Test that augmentations can be selectively enabled/disabled"""
# Test with all augmentations disabled # Test with all augmentations disabled
config = ArcAgiConfig( config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=False)
seed=42,
size=10,
use_rotations=False,
use_mirrors=False,
use_color_permutation=False
)
base_dataset = ArcAgiDataset(config) base_dataset = ArcAgiDataset(config)
base_items = list(base_dataset) base_items = list(base_dataset)
# Test with rotations only # Test with rotations only
rot_config = ArcAgiConfig( rot_config = ArcAgiConfig(seed=42, size=10, use_rotations=True, use_mirrors=False, use_color_permutation=False)
seed=42,
size=10,
use_rotations=True,
use_mirrors=False,
use_color_permutation=False
)
rot_dataset = ArcAgiDataset(rot_config) rot_dataset = ArcAgiDataset(rot_config)
rot_items = list(rot_dataset) rot_items = list(rot_dataset)
# Items should differ when rotations are enabled # Items should differ when rotations are enabled
assert any( assert any(
base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"] base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"] for i in range(len(base_items))
for i in range(len(base_items))
), "Rotation augmentation had no effect" ), "Rotation augmentation had no effect"
# Test with color permutation only # Test with color permutation only
color_config = ArcAgiConfig( color_config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=True)
seed=42,
size=10,
use_rotations=False,
use_mirrors=False,
use_color_permutation=True
)
color_dataset = ArcAgiDataset(color_config) color_dataset = ArcAgiDataset(color_config)
color_items = list(color_dataset) color_items = list(color_dataset)
# Items should differ when color permutation is enabled # Items should differ when color permutation is enabled
assert any( assert any(
base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"] base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"] for i in range(len(base_items))
for i in range(len(base_items))
), "Color permutation had no effect" ), "Color permutation had no effect"
@@ -113,7 +93,7 @@ def test_arc_agi_scoring():
assert dataset.score_answer(None, entry=item) == 0.0 assert dataset.score_answer(None, entry=item) == 0.0
# Test wrong but valid grid format # Test wrong but valid grid format
wrong_answer = "0 0\n0 0" wrong_answer = "1 0 0 0\n0 0 0 1"
assert dataset.score_answer(wrong_answer, entry=item) == 0.05 assert dataset.score_answer(wrong_answer, entry=item) == 0.05