diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py index 1e3502fb..686085a3 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py @@ -231,14 +231,18 @@ def __init__(self, tensor): self.shape = get_onnx_tensor_shape(self.tensor) self.dtype = get_onnx_tensor_dtype(self.tensor) self.nbytes = misc.volume(self.shape) * get_itemsize(self.dtype) + self._cached_values = None # Initialize the cache def load(self): """ - Load a numpy array from the underlying tensor values. + Load a numpy array from the underlying tensor values, using cache. Returns: np.array: A numpy array containing the values of the tensor. """ + if self._cached_values is not None: + return self._cached_values # Return cached data if available + import onnx import onnx.numpy_helper from onnx_graphsurgeon.importers.onnx_importer import ( @@ -254,7 +258,8 @@ def load(self): f"If this is not what you intended, please avoid accessing the values of this constant tensor." ) - return np.array(onnx.numpy_helper.to_array(self.tensor)) + self._cached_values = np.array(onnx.numpy_helper.to_array(self.tensor)) + return self._cached_values def __str__(self): return "LazyValues (shape={:}, dtype={:})".format(self.shape, self.dtype) @@ -268,13 +273,20 @@ class SparseValues(LazyValues): A special object that represents constant tensor values that is sparse """ + def __init__(self, tensor): + super().__init__(tensor) + self._cached_values = None # Initialize the cache + def load(self): """ - Load a numpy array from the sparse structure. + Load a numpy array from the sparse structure, using cache. Returns: np.array: A numpy array containing the values of the tensor. """ + if self._cached_values is not None: + return self._cached_values # Return cached data if available + import onnx import onnx.numpy_helper from onnx_graphsurgeon.importers.onnx_importer import ( @@ -316,7 +328,8 @@ def load(self): f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}" ) - return values + self._cached_values = values + return self._cached_values def __str__(self): return "SparseValues (shape={:}, dtype={:})".format(self.shape, self.dtype)