Skip to content

Commit

Permalink
Remove Tensor._expand()
Browse files Browse the repository at this point in the history
* Now fully replaced by cached(Tensor)
  • Loading branch information
holl- committed Jan 14, 2024
1 parent d25605a commit 0457f30
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 31 deletions.
5 changes: 2 additions & 3 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ._sparse import CompressedSparseMatrix, dense, SparseCoordinateTensor, get_format, to_format, stored_indices, tensor_like, sparse_dims, same_sparsity_pattern, is_sparse, sparse_dot, sparse_sum, sparse_gather, sparse_max, sparse_min
from ._tensors import (Tensor, wrap, tensor, broadcastable_native_tensors, NativeTensor, TensorStack,
custom_op2, compatible_tensor, variable_attributes, disassemble_tree, assemble_tree,
is_scalar, Layout, expand_tensor, TensorOrTree)
is_scalar, Layout, expand_tensor, TensorOrTree, cached)
from ..backend import default_backend, choose_backend, Backend, get_precision, convert as b_convert, BACKENDS, NoBackendFound, ComputeDevice, NUMPY
from ..backend._dtype import DType, combine_types
from .magic import PhiTreeNode
Expand Down Expand Up @@ -2941,8 +2941,7 @@ def native_function(*natives):
def wrapper(*values: Tensor):
INPUT_TENSORS.clear()
INPUT_TENSORS.extend(values)
for v in values:
v._expand()
values = [cached(v) for v in values]
natives = sum([v._natives() for v in values], ())
results_native = list(traced(*natives))
results = [t._with_natives_replaced(results_native) for t in OUTPUT_TENSORS]
Expand Down
9 changes: 0 additions & 9 deletions phiml/math/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,6 @@ def _from_spec_and_natives(cls, spec: dict, natives: list):
indices = spec['indices']['type']._from_spec_and_natives(spec['indices'], natives)
return SparseCoordinateTensor(indices, values, spec['dense_shape'], spec['can_contain_double_entries'], spec['indices_sorted'], spec['default'])

def _expand(self):
self._values._expand()
self._indices._expand()

def _native_coo_components(self, col_dims: DimFilter, matrix=False):
col_dims = self._shape.only(col_dims)
row_dims = self._dense_shape.without(col_dims)
Expand Down Expand Up @@ -521,11 +517,6 @@ def _from_spec_and_natives(cls, spec: dict, natives: list):
pointers = spec['pointers']['type']._from_spec_and_natives(spec['pointers'], natives)
return CompressedSparseMatrix(indices, pointers, values, spec['uncompressed_dims'], spec['compressed_dims'], spec['default'], spec['uncompressed_offset'], spec['uncompressed_indices'], spec['uncompressed_indices_perm'])

def _expand(self):
self._values._expand()
self._indices._expand()
self._pointers._expand()

def _getitem(self, selection: dict) -> 'Tensor':
batch_selection = {dim: selection[dim] for dim in self._shape.only(tuple(selection)).names}
indices = self._indices[batch_selection]
Expand Down
27 changes: 8 additions & 19 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import traceback
import warnings
from contextlib import contextmanager
from typing import Union, TypeVar
from typing import Union, TypeVar, Sequence

from dataclasses import dataclass
from typing import Tuple, Callable, List
Expand Down Expand Up @@ -811,11 +811,6 @@ def _spec_dict(self) -> dict:
def _from_spec_and_natives(cls, spec: dict, natives: list):
raise NotImplementedError(cls)

def _expand(self):
""" Expands all compressed tensors to their defined size as if they were being used in `Tensor.native()`. """
warnings.warn("Tensor._expand() is deprecated, use cached(Tensor) instead.", DeprecationWarning)
raise NotImplementedError(self.__class__)

def _simplify(self):
""" Does not cache this value but if it is already cached, returns the cached version. """
return self
Expand Down Expand Up @@ -1351,9 +1346,6 @@ def _spec_dict(self) -> dict:
def _from_spec_and_natives(cls, spec: dict, natives: list):
return NativeTensor(natives.pop(0), spec['native_shape'], spec['shape'])

def _expand(self):
self._cache()


class TensorStack(Tensor):
"""
Expand Down Expand Up @@ -1538,12 +1530,6 @@ def _with_natives_replaced(self, natives: list):
tensors = [t._with_natives_replaced(natives) for t in self._tensors]
return TensorStack(tensors, self._stack_dim)

def _expand(self):
if self.requires_broadcast:
for t in self._tensors:
t._expand()
self._cache()

@property
def is_cached(self):
return self._cached is not None
Expand Down Expand Up @@ -1845,7 +1831,7 @@ def custom_op2(x: Union[Tensor, float], y: Union[Tensor, float], l_operator, l_n
return result


def disassemble_tensors(tensors: Union[Tuple[Tensor, ...], List[Tensor]], expand: bool) -> Tuple[tuple, Tuple[Shape], tuple]:
def disassemble_tensors(tensors: Sequence[Tensor], expand: bool) -> Tuple[tuple, Tuple[Shape], tuple]:
"""
Args:
tensors: Tuple or list of Tensors.
Expand All @@ -1856,9 +1842,7 @@ def disassemble_tensors(tensors: Union[Tuple[Tensor, ...], List[Tensor]], expand
specs: Identification primitives from which the tensor can be reconstructed given the natives.
One per tensor.
"""
for t in tensors:
if isinstance(t, TensorStack) or expand:
t._expand()
tensors = [cached(t) if isinstance(t, TensorStack) or expand else t for t in tensors]
natives = sum([t._natives() for t in tensors], ())
shapes = tuple([t.shape for t in tensors])
specs = tuple([t._spec_dict() for t in tensors])
Expand Down Expand Up @@ -1962,6 +1946,7 @@ def assemble_tree(obj: PhiTreeNodeType, values: List[Tensor]) -> PhiTreeNodeType


def cached(t: TensorOrTree) -> TensorOrTree:
from ._sparse import SparseCoordinateTensor, CompressedSparseMatrix
assert isinstance(t, (Tensor, PhiTreeNode)), f"All arguments must be Tensors but got {type(t)}"
if isinstance(t, NativeTensor):
return t._cached()
Expand All @@ -1975,6 +1960,10 @@ def cached(t: TensorOrTree) -> TensorOrTree:
natives = [t.native(order=t.shape.names) for t in inners]
native = choose_backend(*natives).stack(natives, axis=t.shape.index(t._stack_dim.name))
return NativeTensor(native, t.shape)
elif isinstance(t, SparseCoordinateTensor):
return SparseCoordinateTensor(cached(t._indices), cached(t._values), t._dense_shape, t._can_contain_double_entries, t._indices_sorted, t._default)
elif isinstance(t, CompressedSparseMatrix):
return CompressedSparseMatrix(cached(t._indices), cached(t._pointers), cached(t._values), t._uncompressed_dims, t._compressed_dims, t._default, t._uncompressed_offset, t._uncompressed_indices, t._uncompressed_indices_perm)
elif isinstance(t, Layout):
return t
elif isinstance(t, PhiTreeNode):
Expand Down

0 comments on commit 0457f30

Please sign in to comment.