mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Add gol curriculum (#354)
* Add gol curriculum * Add difficulty * Make levels of grid size of x and y be valid
This commit is contained in:
parent
fa2b04f4de
commit
adea7a255e
3 changed files with 128 additions and 4 deletions
|
|
@ -5,6 +5,7 @@ from typing import Any, Optional
|
|||
|
||||
import cellpylib as cpl
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -14,7 +15,8 @@ class GameOfLifeConfig:
|
|||
|
||||
grid_size_x: int = 10
|
||||
grid_size_y: int = 10
|
||||
filled_cells: int = 100 # actually a max
|
||||
filled_cells_weights: float = 0.1
|
||||
filled_cells: int = int(filled_cells_weights * grid_size_x * grid_size_y) # actually a max
|
||||
simulation_steps: int = 1
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
|
@ -83,6 +85,12 @@ class GameOfLifeDataset(ProceduralDataset):
|
|||
"grid_size_y": self.config.grid_size_y,
|
||||
"filled_cells": self.config.filled_cells,
|
||||
"simulation_steps": self.config.simulation_steps,
|
||||
"difficulty": {
|
||||
"grid_size_x": self.config.grid_size_x,
|
||||
"grid_size_y": self.config.grid_size_y,
|
||||
"filled_cells_weights": self.config.filled_cells_weights,
|
||||
"simulation_steps": self.config.simulation_steps,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -143,4 +151,52 @@ class GameOfLifeDataset(ProceduralDataset):
|
|||
return correct_cells / total_cells
|
||||
|
||||
|
||||
register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig)
|
||||
class GameOfLifeCurriculum(BaseCurriculum):
|
||||
"""Curriculum for Game of Life dataset"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(GameOfLifeCurriculum.__name__, GameOfLifeConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
ScalarAttributeDefinition(
|
||||
name="grid_size_x",
|
||||
field_name="grid_size_x",
|
||||
levels=[10, 100, 500, 999],
|
||||
default_level=0,
|
||||
description="Grid size in the x direction",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=10,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="grid_size_y",
|
||||
field_name="grid_size_y",
|
||||
levels=[10, 100, 500, 999],
|
||||
default_level=0,
|
||||
description="Grid size in the y direction",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=-10,
|
||||
),
|
||||
# Filled cells should be 10%, 20%, 30%, 50% of the grid_size_x * grid_size_y
|
||||
ScalarAttributeDefinition(
|
||||
name="filled_cells_weights",
|
||||
field_name="filled_cells_weights",
|
||||
levels=[0.1, 0.2, 0.5, 0.8],
|
||||
default_level=0,
|
||||
description="Percentage of filled cells in the grid",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=0.1,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="simulation_steps",
|
||||
field_name="simulation_steps",
|
||||
levels=[1, 2, 5, 10],
|
||||
default_level=0,
|
||||
description="Number of simulation steps to run",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig, GameOfLifeCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue