mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
more native type hints
This commit is contained in:
parent
90a1181285
commit
eeb9fa31d5
19 changed files with 90 additions and 92 deletions
|
|
@ -1,23 +1,23 @@
|
|||
# types
|
||||
|
||||
|
||||
from typing import Any, Callable, Container, FrozenSet, Tuple, Union
|
||||
from typing import Any, Callable, Container, FrozenSet, Union
|
||||
|
||||
Boolean = bool
|
||||
Integer = int
|
||||
IntegerTuple = Tuple[Integer, Integer]
|
||||
Numerical = Union[Integer, IntegerTuple]
|
||||
Integertuple = tuple[Integer, Integer]
|
||||
Numerical = Union[Integer, Integertuple]
|
||||
IntegerSet = FrozenSet[Integer]
|
||||
Grid = Tuple[Tuple[Integer]]
|
||||
Cell = Tuple[Integer, IntegerTuple]
|
||||
Grid = tuple[tuple[Integer]]
|
||||
Cell = tuple[Integer, Integertuple]
|
||||
Object = FrozenSet[Cell]
|
||||
Objects = FrozenSet[Object]
|
||||
Indices = FrozenSet[IntegerTuple]
|
||||
Indices = FrozenSet[Integertuple]
|
||||
IndicesSet = FrozenSet[Indices]
|
||||
Patch = Union[Object, Indices]
|
||||
Element = Union[Object, Grid]
|
||||
Piece = Union[Grid, Patch]
|
||||
TupleTuple = Tuple[Tuple]
|
||||
tupletuple = tuple[tuple]
|
||||
ContainerContainer = Container[Container]
|
||||
|
||||
|
||||
|
|
@ -160,17 +160,17 @@ def difference(a: Container, b: Container) -> Container:
|
|||
return type(a)(e for e in a if e not in b)
|
||||
|
||||
|
||||
def dedupe(iterable: Tuple) -> Tuple:
|
||||
def dedupe(iterable: tuple) -> tuple:
|
||||
"""remove duplicates"""
|
||||
return tuple(e for i, e in enumerate(iterable) if iterable.index(e) == i)
|
||||
|
||||
|
||||
def order(container: Container, compfunc: Callable) -> Tuple:
|
||||
def order(container: Container, compfunc: Callable) -> tuple:
|
||||
"""order container by custom key"""
|
||||
return tuple(sorted(container, key=compfunc))
|
||||
|
||||
|
||||
def repeat(item: Any, num: Integer) -> Tuple:
|
||||
def repeat(item: Any, num: Integer) -> tuple:
|
||||
"""repetition of item within vector"""
|
||||
return tuple(item for i in range(num))
|
||||
|
||||
|
|
@ -277,12 +277,12 @@ def positive(x: Integer) -> Boolean:
|
|||
return x > 0
|
||||
|
||||
|
||||
def toivec(i: Integer) -> IntegerTuple:
|
||||
def toivec(i: Integer) -> Integertuple:
|
||||
"""vector pointing vertically"""
|
||||
return (i, 0)
|
||||
|
||||
|
||||
def tojvec(j: Integer) -> IntegerTuple:
|
||||
def tojvec(j: Integer) -> Integertuple:
|
||||
"""vector pointing horizontally"""
|
||||
return (0, j)
|
||||
|
||||
|
|
@ -302,7 +302,7 @@ def extract(container: Container, condition: Callable) -> Any:
|
|||
return next(e for e in container if condition(e))
|
||||
|
||||
|
||||
def totuple(container: FrozenSet) -> Tuple:
|
||||
def totuple(container: FrozenSet) -> tuple:
|
||||
"""conversion to tuple"""
|
||||
return tuple(container)
|
||||
|
||||
|
|
@ -332,12 +332,12 @@ def other(container: Container, value: Any) -> Any:
|
|||
return first(remove(value, container))
|
||||
|
||||
|
||||
def interval(start: Integer, stop: Integer, step: Integer) -> Tuple:
|
||||
def interval(start: Integer, stop: Integer, step: Integer) -> tuple:
|
||||
"""range"""
|
||||
return tuple(range(start, stop, step))
|
||||
|
||||
|
||||
def astuple(a: Integer, b: Integer) -> IntegerTuple:
|
||||
def astuple(a: Integer, b: Integer) -> Integertuple:
|
||||
"""constructs a tuple"""
|
||||
return (a, b)
|
||||
|
||||
|
|
@ -347,7 +347,7 @@ def product(a: Container, b: Container) -> FrozenSet:
|
|||
return frozenset((i, j) for j in b for i in a)
|
||||
|
||||
|
||||
def pair(a: Tuple, b: Tuple) -> TupleTuple:
|
||||
def pair(a: tuple, b: tuple) -> tupletuple:
|
||||
"""zipping of two tuples"""
|
||||
return tuple(zip(a, b))
|
||||
|
||||
|
|
@ -421,12 +421,12 @@ def mapply(function: Callable, container: ContainerContainer) -> FrozenSet:
|
|||
return merge(apply(function, container))
|
||||
|
||||
|
||||
def papply(function: Callable, a: Tuple, b: Tuple) -> Tuple:
|
||||
def papply(function: Callable, a: tuple, b: tuple) -> tuple:
|
||||
"""apply function on two vectors"""
|
||||
return tuple(function(i, j) for i, j in zip(a, b))
|
||||
|
||||
|
||||
def mpapply(function: Callable, a: Tuple, b: Tuple) -> Tuple:
|
||||
def mpapply(function: Callable, a: tuple, b: tuple) -> tuple:
|
||||
"""apply function on two vectors and merge"""
|
||||
return merge(papply(function, a, b))
|
||||
|
||||
|
|
@ -466,7 +466,7 @@ def width(piece: Piece) -> Integer:
|
|||
return rightmost(piece) - leftmost(piece) + 1
|
||||
|
||||
|
||||
def shape(piece: Piece) -> IntegerTuple:
|
||||
def shape(piece: Piece) -> Integertuple:
|
||||
"""height and width of grid or patch"""
|
||||
return (height(piece), width(piece))
|
||||
|
||||
|
|
@ -503,27 +503,27 @@ def ofcolor(grid: Grid, value: Integer) -> Indices:
|
|||
return frozenset((i, j) for i, r in enumerate(grid) for j, v in enumerate(r) if v == value)
|
||||
|
||||
|
||||
def ulcorner(patch: Patch) -> IntegerTuple:
|
||||
def ulcorner(patch: Patch) -> Integertuple:
|
||||
"""index of upper left corner"""
|
||||
return tuple(map(min, zip(*toindices(patch))))
|
||||
|
||||
|
||||
def urcorner(patch: Patch) -> IntegerTuple:
|
||||
def urcorner(patch: Patch) -> Integertuple:
|
||||
"""index of upper right corner"""
|
||||
return tuple(map(lambda ix: {0: min, 1: max}[ix[0]](ix[1]), enumerate(zip(*toindices(patch)))))
|
||||
|
||||
|
||||
def llcorner(patch: Patch) -> IntegerTuple:
|
||||
def llcorner(patch: Patch) -> Integertuple:
|
||||
"""index of lower left corner"""
|
||||
return tuple(map(lambda ix: {0: max, 1: min}[ix[0]](ix[1]), enumerate(zip(*toindices(patch)))))
|
||||
|
||||
|
||||
def lrcorner(patch: Patch) -> IntegerTuple:
|
||||
def lrcorner(patch: Patch) -> Integertuple:
|
||||
"""index of lower right corner"""
|
||||
return tuple(map(max, zip(*toindices(patch))))
|
||||
|
||||
|
||||
def crop(grid: Grid, start: IntegerTuple, dims: IntegerTuple) -> Grid:
|
||||
def crop(grid: Grid, start: Integertuple, dims: Integertuple) -> Grid:
|
||||
"""subgrid specified by start and dimension"""
|
||||
return tuple(r[start[1] : start[1] + dims[1]] for r in grid[start[0] : start[0] + dims[0]])
|
||||
|
||||
|
|
@ -542,7 +542,7 @@ def recolor(value: Integer, patch: Patch) -> Object:
|
|||
return frozenset((value, index) for index in toindices(patch))
|
||||
|
||||
|
||||
def shift(patch: Patch, directions: IntegerTuple) -> Patch:
|
||||
def shift(patch: Patch, directions: Integertuple) -> Patch:
|
||||
"""shift patch"""
|
||||
if len(patch) == 0:
|
||||
return patch
|
||||
|
|
@ -559,19 +559,19 @@ def normalize(patch: Patch) -> Patch:
|
|||
return shift(patch, (-uppermost(patch), -leftmost(patch)))
|
||||
|
||||
|
||||
def dneighbors(loc: IntegerTuple) -> Indices:
|
||||
def dneighbors(loc: Integertuple) -> Indices:
|
||||
"""directly adjacent indices"""
|
||||
return frozenset({(loc[0] - 1, loc[1]), (loc[0] + 1, loc[1]), (loc[0], loc[1] - 1), (loc[0], loc[1] + 1)})
|
||||
|
||||
|
||||
def ineighbors(loc: IntegerTuple) -> Indices:
|
||||
def ineighbors(loc: Integertuple) -> Indices:
|
||||
"""diagonally adjacent indices"""
|
||||
return frozenset(
|
||||
{(loc[0] - 1, loc[1] - 1), (loc[0] - 1, loc[1] + 1), (loc[0] + 1, loc[1] - 1), (loc[0] + 1, loc[1] + 1)}
|
||||
)
|
||||
|
||||
|
||||
def neighbors(loc: IntegerTuple) -> Indices:
|
||||
def neighbors(loc: Integertuple) -> Indices:
|
||||
"""adjacent indices"""
|
||||
return dneighbors(loc) | ineighbors(loc)
|
||||
|
||||
|
|
@ -690,7 +690,7 @@ def bordering(patch: Patch, grid: Grid) -> Boolean:
|
|||
)
|
||||
|
||||
|
||||
def centerofmass(patch: Patch) -> IntegerTuple:
|
||||
def centerofmass(patch: Patch) -> Integertuple:
|
||||
"""center of mass"""
|
||||
return tuple(map(lambda x: sum(x) // len(patch), zip(*toindices(patch))))
|
||||
|
||||
|
|
@ -895,14 +895,14 @@ def subgrid(patch: Patch, grid: Grid) -> Grid:
|
|||
return crop(grid, ulcorner(patch), shape(patch))
|
||||
|
||||
|
||||
def hsplit(grid: Grid, n: Integer) -> Tuple:
|
||||
def hsplit(grid: Grid, n: Integer) -> tuple:
|
||||
"""split grid horizontally"""
|
||||
h, w = len(grid), len(grid[0]) // n
|
||||
offset = len(grid[0]) % n != 0
|
||||
return tuple(crop(grid, (0, w * i + i * offset), (h, w)) for i in range(n))
|
||||
|
||||
|
||||
def vsplit(grid: Grid, n: Integer) -> Tuple:
|
||||
def vsplit(grid: Grid, n: Integer) -> tuple:
|
||||
"""split grid vertically"""
|
||||
h, w = len(grid) // n, len(grid[0])
|
||||
offset = len(grid) % n != 0
|
||||
|
|
@ -933,12 +933,12 @@ def switch(grid: Grid, a: Integer, b: Integer) -> Grid:
|
|||
return tuple(tuple(v if (v != a and v != b) else {a: b, b: a}[v] for v in r) for r in grid)
|
||||
|
||||
|
||||
def center(patch: Patch) -> IntegerTuple:
|
||||
def center(patch: Patch) -> Integertuple:
|
||||
"""center of the patch"""
|
||||
return (uppermost(patch) + height(patch) // 2, leftmost(patch) + width(patch) // 2)
|
||||
|
||||
|
||||
def position(a: Patch, b: Patch) -> IntegerTuple:
|
||||
def position(a: Patch, b: Patch) -> Integertuple:
|
||||
"""relative position between two patches"""
|
||||
ia, ja = center(toindices(a))
|
||||
ib, jb = center(toindices(b))
|
||||
|
|
@ -952,7 +952,7 @@ def position(a: Patch, b: Patch) -> IntegerTuple:
|
|||
return (-1, 1 if ja < jb else -1)
|
||||
|
||||
|
||||
def index(grid: Grid, loc: IntegerTuple) -> Integer:
|
||||
def index(grid: Grid, loc: Integertuple) -> Integer:
|
||||
"""color at location"""
|
||||
i, j = loc
|
||||
h, w = len(grid), len(grid[0])
|
||||
|
|
@ -961,7 +961,7 @@ def index(grid: Grid, loc: IntegerTuple) -> Integer:
|
|||
return grid[loc[0]][loc[1]]
|
||||
|
||||
|
||||
def canvas(value: Integer, dimensions: IntegerTuple) -> Grid:
|
||||
def canvas(value: Integer, dimensions: Integertuple) -> Grid:
|
||||
"""grid construction"""
|
||||
return tuple(tuple(value for j in range(dimensions[1])) for i in range(dimensions[0]))
|
||||
|
||||
|
|
@ -971,7 +971,7 @@ def corners(patch: Patch) -> Indices:
|
|||
return frozenset({ulcorner(patch), urcorner(patch), llcorner(patch), lrcorner(patch)})
|
||||
|
||||
|
||||
def connect(a: IntegerTuple, b: IntegerTuple) -> Indices:
|
||||
def connect(a: Integertuple, b: Integertuple) -> Indices:
|
||||
"""line between two points"""
|
||||
ai, aj = a
|
||||
bi, bj = b
|
||||
|
|
@ -1000,7 +1000,7 @@ def trim(grid: Grid) -> Grid:
|
|||
return tuple(r[1:-1] for r in grid[1:-1])
|
||||
|
||||
|
||||
def move(grid: Grid, obj: Object, offset: IntegerTuple) -> Grid:
|
||||
def move(grid: Grid, obj: Object, offset: Integertuple) -> Grid:
|
||||
"""move object on grid"""
|
||||
return paint(cover(grid, obj), shift(obj, offset))
|
||||
|
||||
|
|
@ -1025,12 +1025,12 @@ def righthalf(grid: Grid) -> Grid:
|
|||
return rot270(bottomhalf(rot90(grid)))
|
||||
|
||||
|
||||
def vfrontier(location: IntegerTuple) -> Indices:
|
||||
def vfrontier(location: Integertuple) -> Indices:
|
||||
"""vertical frontier"""
|
||||
return frozenset((i, location[1]) for i in range(30))
|
||||
|
||||
|
||||
def hfrontier(location: IntegerTuple) -> Indices:
|
||||
def hfrontier(location: Integertuple) -> Indices:
|
||||
"""horizontal frontier"""
|
||||
return frozenset((location[0], j) for j in range(30))
|
||||
|
||||
|
|
@ -1052,7 +1052,7 @@ def delta(patch: Patch) -> Indices:
|
|||
return backdrop(patch) - toindices(patch)
|
||||
|
||||
|
||||
def gravitate(source: Patch, destination: Patch) -> IntegerTuple:
|
||||
def gravitate(source: Patch, destination: Patch) -> Integertuple:
|
||||
"""direction to move source until adjacent to destination"""
|
||||
source_i, source_j = center(source)
|
||||
destination_i, destination_j = center(destination)
|
||||
|
|
@ -1108,7 +1108,7 @@ def box(patch: Patch) -> Indices:
|
|||
return frozenset(vlines | hlines)
|
||||
|
||||
|
||||
def shoot(start: IntegerTuple, direction: IntegerTuple) -> Indices:
|
||||
def shoot(start: Integertuple, direction: Integertuple) -> Indices:
|
||||
"""line from starting point and direction"""
|
||||
return connect(start, (start[0] + 42 * direction[0], start[1] + 42 * direction[1]))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue