Commit

Merge branch 'kmp5/debug/fix_flatten' into kmp5/feature/cp-bcd
kmp5VT committed Jan 10, 2024
2 parents 8bcee70 + 51eaab8 commit d74188f
Showing 3 changed files with 51 additions and 75 deletions.
29 changes: 28 additions & 1 deletion btas/generic/cp_als.h
@@ -774,9 +774,36 @@ namespace btas {
       swap_to_first(tensor_ref, n, true);
 
 #else  // BTAS_HAS_CBLAS
+      // Computes the Khatri-Rao product intermediate
+      auto KhatriRao = this->generate_KRP(n, rank, true);
+
+      // index labels for contracting mode n of the reference tensor with the KRP
+      std::vector<ind_t> tref_indices, KRP_dims, An_indices;
+
+      // resize the Khatri-Rao product to the proper dimensions
+      for (size_t i = 0; i < ndim; i++) {
+        tref_indices.push_back(i);
+        if (i == n)
+          continue;
+        KRP_dims.push_back(tensor_ref.extent(i));
+      }
+      KRP_dims.push_back(rank);
+      KhatriRao.resize(KRP_dims);
+      KRP_dims.clear();
+
+      // reuse KRP_dims to hold the contraction labels of the reshaped KRP
+      An_indices.push_back(n);
+      An_indices.push_back(ndim);
+      for (size_t i = 0; i < ndim; i++) {
+        if (i == n)
+          continue;
+        KRP_dims.push_back(i);
+      }
+      KRP_dims.push_back(ndim);
+      contract(this->one, tensor_ref, tref_indices, KhatriRao, KRP_dims, this->zero, temp, An_indices);
 
       // without MKL the program cannot perform the swapping algorithm and must
       // compute the flattened intermediate
-      gemm(blas::Op::NoTrans, blas::Op::NoTrans, this->one, flatten(tensor_ref, n), this->generate_KRP(n, rank, true), this->zero, temp);
+      // gemm(blas::Op::NoTrans, blas::Op::NoTrans, this->one, new_flatten(tensor_ref, n), this->generate_KRP(n, rank, true), this->zero, temp);
 #endif
 
       if(lambda != 0){
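The contract call above replaces the old flatten-then-gemm path with an annotated-index contraction. As a minimal sketch of the label bookkeeping (plain C++, no BTAS dependency; ndim = 3 and n = 1 are made-up values), the loops assign labels 0..ndim-1 to the reference tensor, all labels except n plus the shared label ndim to the reshaped Khatri-Rao product, and {n, ndim} to the output, so the contraction sums over every mode except n:

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical standalone illustration: for ndim = 3 and n = 1 the loops above
// label the reference tensor T as (0,1,2), the reshaped Khatri-Rao product K as
// (0,2,3), and the output A as (1,3), so contract() sums over every mode of T
// except n and leaves an (extent(n) x rank) matrix for the factor update.
int main() {
  const size_t ndim = 3, n = 1;
  std::vector<size_t> tref_indices, krp_indices, An_indices{n, ndim};
  for (size_t i = 0; i < ndim; ++i) {
    tref_indices.push_back(i);
    if (i == n) continue;
    krp_indices.push_back(i);
  }
  krp_indices.push_back(ndim);  // the shared "rank" label

  auto print = [](const char *name, const std::vector<size_t> &v) {
    std::cout << name << ": ";
    for (auto x : v) std::cout << x << ' ';
    std::cout << '\n';
  };
  print("T", tref_indices);   // 0 1 2
  print("KRP", krp_indices);  // 0 2 3
  print("A(n)", An_indices);  // 1 3
}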
96 changes: 22 additions & 74 deletions btas/generic/flatten.h
@@ -9,86 +9,34 @@ namespace btas {
 /// \f[ A(I_1, I_2, I_3, ..., I_{mode}, ..., I_N) -> A(I_{mode}, J)\f]
 /// where \f$J = I_1 * I_2 * ... * I_{mode-1} * I_{mode+1} * ... * I_N.\f$
 /// \return Matrix with dimension \f$(I_{mode}, J)\f$
-template<typename Tensor>
-Tensor flatten(const Tensor &A, size_t mode) {
-  using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
-
-  if (mode >= A.rank()) BTAS_EXCEPTION("Cannot flatten along mode outside of A.rank()");
-
-  // make X the correct size
-  Tensor X(A.extent(mode), A.range().area() / A.extent(mode));
-
-  ord_t indexi = 0, indexj = 0;
-  size_t ndim = A.rank();
-  // J is the new step size found by removing the mode of interest
-  std::vector<ord_t> J(ndim, 1);
-  for (size_t i = 0; i < ndim; ++i)
-    if (i != mode)
-      for (size_t m = 0; m < i; ++m)
-        if (m != mode)
-          J[i] *= A.extent(m);
-
-  auto tensor_itr = A.begin();
-
-  // Fill X with the correct values
-  fill(A, 0, X, mode, indexi, indexj, J, tensor_itr);
-
-  // return the flattened matrix
-  return X;
-}
-
-/// following the formula for flattening laid out by Kolda and Bader
-/// <a href=http://epubs.siam.org/doi/pdf/10.1137/07070111X> See reference. </a>
-/// Recursive method utilized by flatten.\n **Important** if you want to flatten a tensor
-/// call flatten, not fill.
-
-/// \param[in] A The reference tensor to be flattened
-/// \param[in] depth The recursion depth; should not exceed A.rank()
-/// \param[in,out] X In: an empty matrix to be filled with the elements of \c A
-/// flattened along the \c mode fiber; should have size \f$(I_{mode}, J)\f$.
-/// Out: the flattened \c A matrix along the \c mode fiber.
-/// \param[in] mode The mode along which A is to be flattened
-/// \param[in] indexi The row index of matrix X
-/// \param[in] indexj The column index of matrix X
-/// \param[in] J The step size for the row dimension of X
-/// \param[in] tensor_itr An iterator of \c A. The value of the iterator is
-/// placed in the correct position of X using recursive calls of fill().
-
-template<typename Tensor, typename iterator, typename ord_t>
-void fill(const Tensor &A, size_t depth, Tensor &X, size_t mode,
-          ord_t indexi, ord_t indexj, const std::vector<ord_t> &J, iterator &tensor_itr) {
-  size_t ndim = A.rank();
-  if (depth < ndim) {
-    // Creates a for loop based on the number of modes A has
-    for (ind_t i = 0; i < A.extent(depth); ++i) {
-      // use the for loop to find the column dimension index
-      if (depth != mode) {
-        indexj += i * J[depth];  // column matrix index
-      }
-      // if this depth is the mode being flattened use the for loop to find the
-      // row dimension
-      else {
-        indexi = i;  // row matrix index
-      }
-
-      fill(A, depth + 1, X, mode, indexi, indexj, J, tensor_itr);
-
-      // remove the indexing from earlier in this loop.
-      if (depth != mode)
-        indexj -= i * J[depth];
-    }
-  }
-  // When depth steps out of the number of dimensions, set X to be the correct
-  // value from the iterator, then increment the iterator.
-  else {
-    X(indexi, indexj) = *tensor_itr;
-    tensor_itr++;
-  }
-}
+template<typename Tensor>
+Tensor flatten(Tensor A, size_t mode) {
+  using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
+  using ind_t = typename Tensor::range_type::index_type::value_type;
+
+  // First reshape the order-N tensor into an order-3 tensor whose middle mode
+  // is `mode`: (modes before `mode`, `mode`, modes after `mode`)
+  auto dim_mode = A.extent(mode);
+  Tensor flat(dim_mode, A.range().area() / dim_mode);
+  size_t ndim = A.rank();
+  ord_t dim1 = 1, dim3 = 1;
+  for (ind_t i = 0; i < ndim; ++i) {
+    if (i < mode)
+      dim1 *= A.extent(i);
+    else if (i > mode)
+      dim3 *= A.extent(i);
+  }
+
+  A.resize(Range{Range1{dim1}, Range1{dim_mode}, Range1{dim3}});
+
+  // then walk the order-3 tensor, mapping element (i, j, k) to flat(j, i * dim3 + k)
+  for (ord_t i = 0; i < dim1; ++i) {
+    for (ind_t j = 0; j < dim_mode; ++j) {
+      for (ord_t k = 0; k < dim3; ++k) {
+        flat(j, i * dim3 + k) = A(i, j, k);
+      }
+    }
+  }
+  return flat;
+}
 
 } // namespace btas
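To see what the rewritten flatten computes, here is a self-contained sketch (plain C++ with a row-major std::vector standing in for a BTAS tensor; the 2x3x4 extents are chosen only for illustration) of the same mapping: collapse the modes before `mode` into dim1 and those after into dim3, then copy element (i, j, k) to row j, column i * dim3 + k:

#include <cassert>
#include <cstddef>
#include <vector>

// Mode-1 matricization of a 2x3x4 tensor stored row-major: the result is a
// 3 x 8 matrix whose column index is i * dim3 + k, matching the loop nest in
// the new flatten() above.
int main() {
  const size_t dim1 = 2, dim_mode = 3, dim3 = 4;  // extents (made up)
  std::vector<double> A(dim1 * dim_mode * dim3);
  for (size_t x = 0; x < A.size(); ++x) A[x] = double(x);

  std::vector<double> flat(dim_mode * dim1 * dim3);  // dim_mode rows, dim1*dim3 cols
  for (size_t i = 0; i < dim1; ++i)
    for (size_t j = 0; j < dim_mode; ++j)
      for (size_t k = 0; k < dim3; ++k)
        flat[j * (dim1 * dim3) + i * dim3 + k] = A[(i * dim_mode + j) * dim3 + k];

  // Element A(1,2,3) must land at row 2, column 1*4 + 3 = 7.
  assert(flat[2 * 8 + 7] == double((1 * 3 + 2) * 4 + 3));
  return 0;
}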
1 change: 1 addition & 0 deletions btas/generic/linear_algebra.h
@@ -260,6 +260,7 @@ Tensor pseudoInverse(Tensor & A, bool & fast_pI) {
   // Compute the matrix A^-1 from the inverted singular values and the U and
   // V^T provided by the SVD
   gemm(blas::Op::NoTrans, blas::Op::NoTrans, 1.0, U, s_inv, 0.0, s_);
+  U = Tensor(Range{Range1{row}, Range1{col}});
   gemm(blas::Op::NoTrans, blas::Op::NoTrans, 1.0, s_, Vt, 0.0, U);
 
   return U;
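The inserted line appears to re-create U at the final (row x col) shape so the second gemm writes into a correctly sized, zero-initialized output. As a hand-worked check of the composition U * s_inv * Vt (plain C++; the 2x2 diagonal matrix is illustrative, not taken from the library), take A = diag(2, 0): its SVD has U = Vt = I and singular values {2, 0}, and inverting only the nonzero singular value yields the Moore-Penrose pseudoinverse diag(0.5, 0):

#include <cassert>

// Hand-worked instance of the composition used above: s_ = U * s_inv,
// then pinv = s_ * Vt. Inverting only the nonzero singular value gives
// s_inv = diag(0.5, 0), so the product is diag(0.5, 0).
int main() {
  double U[2][2] = {{1, 0}, {0, 1}};
  double s_inv[2][2] = {{0.5, 0}, {0, 0}};  // 1/sigma for sigma > 0, else 0
  double Vt[2][2] = {{1, 0}, {0, 1}};
  double s_[2][2] = {}, pinv[2][2] = {};
  for (int i = 0; i < 2; ++i)  // s_ = U * s_inv
    for (int k = 0; k < 2; ++k)
      for (int j = 0; j < 2; ++j)
        s_[i][j] += U[i][k] * s_inv[k][j];
  for (int i = 0; i < 2; ++i)  // pinv = s_ * Vt
    for (int k = 0; k < 2; ++k)
      for (int j = 0; j < 2; ++j)
        pinv[i][j] += s_[i][k] * Vt[k][j];
  assert(pinv[0][0] == 0.5 && pinv[1][1] == 0.0);
  return 0;
}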
