mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
formatting
This commit is contained in:
parent
8d8d85e6b2
commit
e56316ebb2
2 changed files with 13 additions and 35 deletions
|
|
@ -99,13 +99,13 @@ class ArcAgiDataset(ProceduralDataset):
|
||||||
def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]:
|
def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]:
|
||||||
"""Create a composite augmentation function from enabled options"""
|
"""Create a composite augmentation function from enabled options"""
|
||||||
fns = []
|
fns = []
|
||||||
|
|
||||||
if self.config.use_rotations:
|
if self.config.use_rotations:
|
||||||
fns.append(rng.choice(ROTATION_AUGMENTATIONS))
|
fns.append(rng.choice(ROTATION_AUGMENTATIONS))
|
||||||
|
|
||||||
if self.config.use_mirrors:
|
if self.config.use_mirrors:
|
||||||
fns.append(rng.choice(MIRROR_AUGMENTATIONS))
|
fns.append(rng.choice(MIRROR_AUGMENTATIONS))
|
||||||
|
|
||||||
if self.config.use_color_permutation:
|
if self.config.use_color_permutation:
|
||||||
color_table = list(range(10))
|
color_table = list(range(10))
|
||||||
rng.shuffle(color_table)
|
rng.shuffle(color_table)
|
||||||
|
|
@ -116,7 +116,7 @@ class ArcAgiDataset(ProceduralDataset):
|
||||||
for fn in fns:
|
for fn in fns:
|
||||||
result = fn(result)
|
result = fn(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return composite_fn
|
return composite_fn
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
|
@ -137,16 +137,14 @@ class ArcAgiDataset(ProceduralDataset):
|
||||||
# Apply augmentation to all train examples
|
# Apply augmentation to all train examples
|
||||||
augmented_train = []
|
augmented_train = []
|
||||||
for p in train:
|
for p in train:
|
||||||
augmented_train.append({
|
augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])})
|
||||||
"input": augment(p["input"]),
|
|
||||||
"output": augment(p["output"])
|
|
||||||
})
|
|
||||||
|
|
||||||
examples = [
|
examples = [
|
||||||
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
|
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
|
||||||
for i, p in enumerate(augmented_train)
|
for i, p in enumerate(augmented_train)
|
||||||
]
|
]
|
||||||
examples = "".join(examples)
|
examples = "".join(examples)
|
||||||
|
|
||||||
# Apply augmentation to test example
|
# Apply augmentation to test example
|
||||||
augmented_test_input = augment(test["input"])
|
augmented_test_input = augment(test["input"])
|
||||||
augmented_test_output = augment(test["output"])
|
augmented_test_output = augment(test["output"])
|
||||||
|
|
|
||||||
|
|
@ -52,48 +52,28 @@ def test_arc_agi_items():
|
||||||
def test_arc_agi_augmentations():
|
def test_arc_agi_augmentations():
|
||||||
"""Test that augmentations can be selectively enabled/disabled"""
|
"""Test that augmentations can be selectively enabled/disabled"""
|
||||||
# Test with all augmentations disabled
|
# Test with all augmentations disabled
|
||||||
config = ArcAgiConfig(
|
config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=False)
|
||||||
seed=42,
|
|
||||||
size=10,
|
|
||||||
use_rotations=False,
|
|
||||||
use_mirrors=False,
|
|
||||||
use_color_permutation=False
|
|
||||||
)
|
|
||||||
base_dataset = ArcAgiDataset(config)
|
base_dataset = ArcAgiDataset(config)
|
||||||
base_items = list(base_dataset)
|
base_items = list(base_dataset)
|
||||||
|
|
||||||
# Test with rotations only
|
# Test with rotations only
|
||||||
rot_config = ArcAgiConfig(
|
rot_config = ArcAgiConfig(seed=42, size=10, use_rotations=True, use_mirrors=False, use_color_permutation=False)
|
||||||
seed=42,
|
|
||||||
size=10,
|
|
||||||
use_rotations=True,
|
|
||||||
use_mirrors=False,
|
|
||||||
use_color_permutation=False
|
|
||||||
)
|
|
||||||
rot_dataset = ArcAgiDataset(rot_config)
|
rot_dataset = ArcAgiDataset(rot_config)
|
||||||
rot_items = list(rot_dataset)
|
rot_items = list(rot_dataset)
|
||||||
|
|
||||||
# Items should differ when rotations are enabled
|
# Items should differ when rotations are enabled
|
||||||
assert any(
|
assert any(
|
||||||
base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"]
|
base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"] for i in range(len(base_items))
|
||||||
for i in range(len(base_items))
|
|
||||||
), "Rotation augmentation had no effect"
|
), "Rotation augmentation had no effect"
|
||||||
|
|
||||||
# Test with color permutation only
|
# Test with color permutation only
|
||||||
color_config = ArcAgiConfig(
|
color_config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=True)
|
||||||
seed=42,
|
|
||||||
size=10,
|
|
||||||
use_rotations=False,
|
|
||||||
use_mirrors=False,
|
|
||||||
use_color_permutation=True
|
|
||||||
)
|
|
||||||
color_dataset = ArcAgiDataset(color_config)
|
color_dataset = ArcAgiDataset(color_config)
|
||||||
color_items = list(color_dataset)
|
color_items = list(color_dataset)
|
||||||
|
|
||||||
# Items should differ when color permutation is enabled
|
# Items should differ when color permutation is enabled
|
||||||
assert any(
|
assert any(
|
||||||
base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"]
|
base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"] for i in range(len(base_items))
|
||||||
for i in range(len(base_items))
|
|
||||||
), "Color permutation had no effect"
|
), "Color permutation had no effect"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -113,7 +93,7 @@ def test_arc_agi_scoring():
|
||||||
assert dataset.score_answer(None, entry=item) == 0.0
|
assert dataset.score_answer(None, entry=item) == 0.0
|
||||||
|
|
||||||
# Test wrong but valid grid format
|
# Test wrong but valid grid format
|
||||||
wrong_answer = "0 0\n0 0"
|
wrong_answer = "1 0 0 0\n0 0 0 1"
|
||||||
assert dataset.score_answer(wrong_answer, entry=item) == 0.05
|
assert dataset.score_answer(wrong_answer, entry=item) == 0.05
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue