mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
formatting
This commit is contained in:
parent
8d8d85e6b2
commit
e56316ebb2
2 changed files with 13 additions and 35 deletions
|
|
@ -99,13 +99,13 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]:
|
||||
"""Create a composite augmentation function from enabled options"""
|
||||
fns = []
|
||||
|
||||
|
||||
if self.config.use_rotations:
|
||||
fns.append(rng.choice(ROTATION_AUGMENTATIONS))
|
||||
|
||||
|
||||
if self.config.use_mirrors:
|
||||
fns.append(rng.choice(MIRROR_AUGMENTATIONS))
|
||||
|
||||
|
||||
if self.config.use_color_permutation:
|
||||
color_table = list(range(10))
|
||||
rng.shuffle(color_table)
|
||||
|
|
@ -116,7 +116,7 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
for fn in fns:
|
||||
result = fn(result)
|
||||
return result
|
||||
|
||||
|
||||
return composite_fn
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
@ -137,16 +137,14 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
# Apply augmentation to all train examples
|
||||
augmented_train = []
|
||||
for p in train:
|
||||
augmented_train.append({
|
||||
"input": augment(p["input"]),
|
||||
"output": augment(p["output"])
|
||||
})
|
||||
augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])})
|
||||
|
||||
examples = [
|
||||
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
|
||||
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
|
||||
for i, p in enumerate(augmented_train)
|
||||
]
|
||||
examples = "".join(examples)
|
||||
|
||||
# Apply augmentation to test example
|
||||
augmented_test_input = augment(test["input"])
|
||||
augmented_test_output = augment(test["output"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue