Skip to content

Commit

Permalink
Sparse padding for dense tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 26, 2023
1 parent 9d87a03 commit 95921d4
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions phiml/math/extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,12 +616,15 @@ def _upper_mask(self, shape, widths, bound_dim, bound_lo, bound_hi, i):
return mask

def sparse_pad_values(self, value: Tensor, connectivity: Tensor, dim: str, **kwargs) -> Tensor:
from ._sparse import stored_indices
from ._ops import arange
from ._sparse import stored_indices, is_sparse
from ._ops import arange, nonzero
dual_dim = dual(value).name
# --- Gather the edge values ---
indices = stored_indices(connectivity, invalid='discard')
primal_dim = [n for n in indices.index.item_names if not n.startswith('~')][0]
if is_sparse(connectivity):
indices = stored_indices(connectivity, invalid='discard')
else:
indices = nonzero(connectivity)
primal_dim = [n for n in channel(indices).item_names[0] if not n.startswith('~')][0]
assert primal_dim not in value.shape, f"sparse_pad_values only implemented for vectors, not matrices"
gathered = value[{dual_dim: indices[primal_dim]}]
# --- Scatter, but knowing there is only one entry per row & col, we can simply permute ---
Expand Down

0 comments on commit 95921d4

Please sign in to comment.