diff --git a/libs/libvtrutil/src/vtr_ndmatrix.h b/libs/libvtrutil/src/vtr_ndmatrix.h index 57571cc865..b7d6f030d5 100644 --- a/libs/libvtrutil/src/vtr_ndmatrix.h +++ b/libs/libvtrutil/src/vtr_ndmatrix.h @@ -30,13 +30,14 @@ class NdMatrixProxy { * @brief Construct a matrix proxy object * * @param dim_sizes: Array of dimension sizes - * @param idim: The dimension associated with this proxy * @param dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension) - * @param start: Pointer to the start of the sub-matrix this proxy represents + * @param offset: The offset from the start that this sub-matrix starts at. + * @param start: Pointer to the start of the base NDMatrix of this proxy */ - NdMatrixProxy(const size_t* dim_sizes, const size_t* dim_strides, T* start) + NdMatrixProxy(const size_t* dim_sizes, const size_t* dim_strides, size_t offset, const std::unique_ptr& start) : dim_sizes_(dim_sizes) , dim_strides_(dim_strides) + , offset_(offset) , start_(start) {} NdMatrixProxy& operator=(const NdMatrixProxy& other) = delete; @@ -50,7 +51,8 @@ class NdMatrixProxy { return NdMatrixProxy( dim_sizes_ + 1, // Pass the dimension information dim_strides_ + 1, // Pass the stride for the next dimension - start_ + dim_strides_[0] * index); // Advance to index in this dimension + offset_ + dim_strides_[0] * index, // Advance to index in this dimension + start_); // Pass the base pointer. } ///@brief [] operator @@ -60,9 +62,22 @@ class NdMatrixProxy { } private: + /// @brief The sizes of each dimension of this proxy. This is an array of + /// length N. const size_t* dim_sizes_; + + /// @brief The stride of each dimension of this proxy. This is an array of + /// length N. const size_t* dim_strides_; - T* start_; + + /// @brief The offset from the base NDMatrix object that this sub-matrix + /// starts at. + size_t offset_; + + /// @brief The pointer to the start of the base NDMatrix data. Since the + /// base NDMatrix object owns the memory, we hold onto a reference + /// to its unique pointer. This is safer than passing a bare pointer. + const std::unique_ptr& start_; }; ///@brief Base case: 1-dimensional array @@ -74,11 +89,13 @@ class NdMatrixProxy { * * @param dim_sizes: Array of dimension sizes * @param dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension) - * @param start: Pointer to the start of the sub-matrix this proxy represents + * @param offset: The offset from the start that this sub-matrix starts at. + * @param start: Pointer to the start of the base NDMatrix of this proxy */ - NdMatrixProxy(const size_t* dim_sizes, const size_t* dim_stride, T* start) + NdMatrixProxy(const size_t* dim_sizes, const size_t* dim_stride, size_t offset, const std::unique_ptr& start) : dim_sizes_(dim_sizes) , dim_strides_(dim_stride) + , offset_(offset) , start_(start) {} NdMatrixProxy& operator=(const NdMatrixProxy& other) = delete; @@ -89,7 +106,7 @@ class NdMatrixProxy { VTR_ASSERT_SAFE_MSG(index < dim_sizes_[0], "Index out of range (above dimension maximum)"); //Base case - return start_[index]; + return start_[offset_ + index]; } ///@brief [] operator @@ -108,7 +125,7 @@ class NdMatrixProxy { * not to clobber elements in other dimensions */ const T* data() const { - return start_; + return start_.get() + offset_; } ///@brief same as above but allow update the value @@ -118,9 +135,22 @@ class NdMatrixProxy { } private: + /// @brief The sizes of each dimension of this proxy. This is an array of + /// length N. const size_t* dim_sizes_; + + /// @brief The stride of each dimension of this proxy. This is an array of + /// length N. const size_t* dim_strides_; - T* start_; + + /// @brief The offset from the base NDMatrix object that this sub-matrix + /// starts at. + size_t offset_; + + /// @brief The pointer to the start of the base NDMatrix data. Since the + /// base NDMatrix object owns the memory, we hold onto a reference + /// to its unique pointer. This is safer than passing a bare pointer. + const std::unique_ptr& start_; }; /** @@ -359,7 +389,8 @@ class NdMatrix : public NdMatrixBase { return NdMatrixProxy( this->dim_sizes_.data() + 1, //Pass the dimension information this->dim_strides_.data() + 1, //Pass the stride for the next dimension - this->data_.get() + this->dim_strides_[0] * index); //Advance to index in this dimension + this->dim_strides_[0] * index, //Advance to index in this dimension + this->data_); //Pass the base pointer } /**