Skip to content

Commit

Permalink
Add at_min_neighbor(), at_max_neighbor()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 20, 2023
1 parent 010cd31 commit a75acb2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
2 changes: 1 addition & 1 deletion phiml/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
normalize_to,
l1_loss, l2_loss, frequency_loss,
spatial_gradient, laplace,
neighbor_reduce, neighbor_mean, neighbor_sum, neighbor_max, neighbor_min,
neighbor_reduce, neighbor_mean, neighbor_sum, neighbor_max, neighbor_min, at_min_neighbor, at_max_neighbor,
fourier_laplace, fourier_poisson, abs_square,
downsample2x, upsample2x, sample_subgrid,
masked_fill, finite_fill
Expand Down
43 changes: 39 additions & 4 deletions phiml/math/_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,15 @@ def shift(x: Tensor,

def neighbor_reduce(reduce_fun: Callable, grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor:
"""
Computes the mean of two neighboring values along each dimension in `dim`.
The result tensor has one entry less than `grid` in each averaged dimension unless `padding` is specified..
Computes the sum/mean/min/max/prod/etc. of two neighboring values along each dimension in `dim`.
The result tensor has one entry less than `grid` in each averaged dimension unless `padding` is specified.
With two `dims`, computes the mean of 4 values, in 3D, the mean of 8 values.
Args:
reduce_fun: Reduction function, such as `sum`, `mean`, `max`, `min`, `prod`.
grid: Values to average.
dims: Dimensions along which neighbors should be averaged.
grid: Values to reduce.
dims: Dimensions along which neighbors should be reduced.
padding: Padding at the upper edges of `grid` along `dims'. If not `None`, the result tensor will have the same shape as `grid`.
Returns:
Expand Down Expand Up @@ -545,6 +545,41 @@ def neighbor_min(grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapo
return neighbor_reduce(math.min_, grid, dims, padding)


def at_neighbor_where(reduce_fun: Callable, grid: Tensor, dims: DimFilter = spatial, *other_tensors: Tensor, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor:
"""
Computes the mean of two neighboring values along each dimension in `dim`.
The result tensor has one entry less than `grid` in each averaged dimension unless `padding` is specified..
With two `dims`, computes the mean of 4 values, in 3D, the mean of 8 values.
Args:
reduce_fun: Reduction function, such as `at_max`, `at_min`.
grid: Values to average.
dims: Dimensions along which neighbors should be averaged.
padding: Padding at the upper edges of `grid` along `dims'. If not `None`, the result tensor will have the same shape as `grid`.
Returns:
`Tensor`
"""
result = grid
dims = grid.shape.only(dims)
for dim in dims:
lr = stack(shift(result, (0, 1), dim, padding, None), batch('_reduce'))
other_tensors = [stack(shift(t, (0, 1), dim, padding, None), batch('_reduce')) for t in other_tensors]
result, *other_tensors = reduce_fun(lr, '_reduce', lr, *other_tensors)
return other_tensors[0] if len(other_tensors) == 1 else other_tensors


def at_max_neighbor(grid: Tensor, dims: DimFilter = spatial, *other_tensors: Tensor, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor:
"""`at_neighbor_where` with `reduce_fun` set to `phiml.math.at_max`."""
return at_neighbor_where(math.at_max, grid, dims, *other_tensors, padding=padding)


def at_min_neighbor(grid: Tensor, dims: DimFilter = spatial, *other_tensors: Tensor, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor:
"""`at_neighbor_where` with `reduce_fun` set to `phiml.math.at_min`."""
return at_neighbor_where(math.at_min, grid, dims, *other_tensors, padding=padding)



def masked_fill(values: Tensor, valid: Tensor, distance: int = 1) -> Tuple[Tensor, Tensor]:
"""
Expand Down
12 changes: 12 additions & 0 deletions tests/commit/math/test__nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,15 @@ def test_neighbor_reduce(self):
math.assert_close([0, 2, 0], math.neighbor_min(grid1))
grid2 = wrap([[0, 1], [2, 3]], spatial('y,x'))
math.assert_close([[6]], math.neighbor_sum(grid2))

def test_at_min_neighbor(self):
x = math.range(spatial(x=4))
x_min, neg_x_min = math.at_min_neighbor(x, 'x', x, -x)
math.assert_close([0, 1, 2], x_min)
math.assert_close([0, -1, -2], neg_x_min)

def test_at_max_neighbor(self):
x = math.range(spatial(x=4))
x_min, neg_x_min = math.at_max_neighbor(x, 'x', x, -x)
math.assert_close([1, 2, 3], x_min)
math.assert_close([-1, -2, -3], neg_x_min)

0 comments on commit a75acb2

Please sign in to comment.