diff --git a/include/interface/blas2_interface.h b/include/interface/blas2_interface.h
index ee5b73ec8..bf8b4b683 100644
--- a/include/interface/blas2_interface.h
+++ b/include/interface/blas2_interface.h
@@ -176,34 +176,38 @@ typename sb_handle_t::event_t _symv(
 );
 
 /*!
- @brief Generalised vector product followed by a sum with a rectangular
- non-symmetric matrix.
-
- Generalised vector product followed by a sum with a rectangular non-symmetric
- matrix, i.e. computing the mathematical operation:
-
- A = alpha*x*yT + A
-
- See the netlib blas interface documentation for more details of the high level
- interface: http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
-
+ * @brief Generalised vector product followed by a sum with a rectangular
+ * non-symmetric matrix.
+ *
+ * Generalised vector product followed by a sum with a rectangular non-symmetric
+ * matrix, i.e. computing the mathematical operation:
+ *
+ * A = alpha*x*yT + A
+ *
+ * See the netlib blas interface documentation for more details of the high
+ * level interface:
+ * http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
+ *
+ * @param sb_handle SB_handle
+ * @param _M Number of rows in matrix A
+ * @param _N Number of columns in matrix A
+ * @param _alpha Scalar alpha
+ * @param _vx Input vector having (1 + (_M-1)*abs(_incx)) elements
+ * @param _incx Increment for vector X
+ * @param _vy Input vector having (1 + (_N-1)*abs(_incy)) elements
+ * @param _incy Increment for vector Y
+ * @param _mA Input/output matrix A(_lda, _N)
+ * @param _lda Leading dimension of A
+ * @param _dependencies Vector of events
 */
 template <typename sb_handle_t, typename index_t, typename element_t,
           typename container_0_t, typename increment_t, typename container_1_t,
           typename container_2_t>
 typename sb_handle_t::event_t _ger(
-    sb_handle_t& sb_handle,  // sb_handle_t (sycl, parallel, serial, etc)
-    index_t _M,              // The rows in matrix A
-    index_t _N,              // The cols of matrix A
-    element_t _alpha,        // Scalar alpha
-    container_0_t _vx,       // >(1 + (_M-1)*abs(_incx)), input vector X
-    increment_t _incx,       // Increment for vector X
-    container_1_t _vy,       // >(1 + (_N-1)*abs(_incy)), input vector Y
-    increment_t _incy,       // Increment for vector Y
-    container_2_t _mA,       // (_lda, n) array containing A, the output
-    index_t _lda,            // >max(1, m), Leading dimension of A
-    const typename sb_handle_t::event_t& _dependencies  // Vector of events
-);
+    sb_handle_t& sb_handle, index_t _M, index_t _N, element_t _alpha,
+    container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy,
+    container_2_t _mA, index_t _lda,
+    const typename sb_handle_t::event_t& _dependencies);
 
 /*!
 @brief Generalised vector squaring followed by a sum with a symmetric matrix.
@@ -746,35 +750,39 @@ typename sb_handle_t::event_t inline _symv(
 }
 
 /*!
- @brief Generalised vector product followed by a sum with a rectangular
- non-symmetric matrix.
-
- Generalised vector product followed by a sum with a rectangular non-symmetric
- matrix, i.e.
- computing the mathematical operation:
-
- A = alpha*x*yT + A
-
- See the netlib blas interface documentation for more details of the high level
- interface: http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
-
+ * @brief Generalised vector product followed by a sum with a rectangular
+ * non-symmetric matrix.
+ *
+ * Generalised vector product followed by a sum with a rectangular non-symmetric
+ * matrix, i.e. computing the mathematical operation:
+ *
+ * A = alpha*x*yT + A
+ *
+ * See the netlib blas interface documentation for more details of the high
+ * level interface:
+ * http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
+ *
+ * @param sb_handle SB_handle
+ * @param _M Number of rows in matrix A
+ * @param _N Number of columns in matrix A
+ * @param _alpha Scalar alpha
+ * @param _vx Input vector having (1 + (_M-1)*abs(_incx)) elements
+ * @param _incx Increment for vector X
+ * @param _vy Input vector having (1 + (_N-1)*abs(_incy)) elements
+ * @param _incy Increment for vector Y
+ * @param _mA Input/output matrix A(_lda, _N)
+ * @param _lda Leading dimension of A
+ * @param _dependencies Vector of events
 */
 template <typename sb_handle_t, typename index_t, typename element_t,
           typename container_0_t, typename increment_t, typename container_1_t,
           typename container_2_t>
 typename sb_handle_t::event_t inline _ger(
-    sb_handle_t& sb_handle,  // sb_handle_t (sycl, parallel, serial, etc)
-    index_t _M,              // The rows in matrix M
-    index_t _N,              // The rows of matrix N
-    element_t _alpha,        // Scalar alpha
-    container_0_t _vx,       // >(1 + (_M-1)*abs(_incx)), input vector X
-    increment_t _incx,       // Increment for vector X
-    container_1_t _vy,       // >(1 + (_N-1)*abs(_incy)), input vector Y
-    increment_t _incy,       // Increment for vector Y
-    container_2_t _mA,       // (_lda, n) array containing A, the output
-    index_t _lda,            // >max(1, m), Leading dimension of A
-    const typename sb_handle_t::event_t& _dependencies = {}  // Vector of events
-) {
+    sb_handle_t& sb_handle, index_t _M, index_t _N, element_t _alpha,
+    container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy,
+    container_2_t _mA, index_t _lda,
+    const typename sb_handle_t::event_t& _dependencies = {}) {
   return internal::_ger(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA,
                         _lda, _dependencies);
 }
diff --git a/include/operations/blas2_trees.h b/include/operations/blas2_trees.h
index 34937283e..9dbbedebb 100644
--- a/include/operations/blas2_trees.h
+++ b/include/operations/blas2_trees.h
@@ -502,6 +502,64 @@ make_trsv(vector_t &lhs_, matrix_t &matrix_, sync_t &sync_) {
               subgroups, is_upper, is_transposed, is_unit>(lhs_, matrix_, k_, sync_);
 }
 
+/**
+ * @struct Ger
+ * @brief Tree node representing the sum of scalar-vector-vector product with a
+ * matrix, i.e., it computes lhs_ such that
+ *
+ * lhs_ = scalar_ * ( rhs_1_ * rhs_2_^t ) + lhs_
+ *
+ * @param lhs_ input/output matrix
+ * @param scalar_ value for scaling vector product
+ * @param rhs_1_ first input vector
+ * @param rhs_2_ second input vector
+ * @param nRowsWG_ rows of the workgroup tile
+ * @param nColsWG_ cols of the workgroup tile
+ * @param nWG_row_ number of tiles per global size row
+ * @param nWG_col_ number of tiles per global size column
+ *
+ */
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+struct Ger {
+  using value_t = typename rhs_2_t::value_t;
+  using index_t = typename rhs_2_t::index_t;
+
+  lhs_t lhs_;
+  value_t scalar_;
+  rhs_1_t rhs_1_;
+  rhs_2_t rhs_2_;
+  index_t nRowsWG_;
+  index_t nColsWG_;
+  index_t nWG_row_;
+  index_t nWG_col_;
+
+  Ger(lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
+      index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col);
+
+  index_t get_size() const;
+  bool valid_thread(cl::sycl::nd_item<1> ndItem) const;
+  value_t eval(index_t i);
+  value_t eval(cl::sycl::nd_item<1> ndItem);
+  template <typename sharedT>
+  value_t eval(sharedT shrMem, cl::sycl::nd_item<1> ndItem);
+  void bind(cl::sycl::handler &h);
+  void adjust_access_displacement();
+};
+
+/*!
+ @brief Generator/factory for GER trees.
+ */ +template +Ger make_ger(lhs_t &lhs_, + typename lhs_t::value_t scalar_, + rhs_1_t &rhs_1_, rhs_2_t &rhs_2_, + typename rhs_2_t::index_t nRowsWG_, + typename rhs_2_t::index_t nColsWG_, + typename rhs_2_t::index_t nWG_row_, + typename rhs_2_t::index_t nWG_col_) { + return Ger(lhs_, scalar_, rhs_1_, rhs_2_, nRowsWG_, + nColsWG_, nWG_row_, nWG_col_); +} /**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/ // template diff --git a/src/interface/blas2_interface.hpp b/src/interface/blas2_interface.hpp index 71dbee066..6f43d5300 100644 --- a/src/interface/blas2_interface.hpp +++ b/src/interface/blas2_interface.hpp @@ -878,7 +878,7 @@ typename sb_handle_t::event_t _ger_impl( container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy, container_t2 _mA, index_t _lda, const typename sb_handle_t::event_t& _dependencies, index_t _localSize = 0, - index_t _scratchPadSize = 0, index_t _nRowsWG = 0, index_t _nColsWG = 0) { + bool _useLocalMem = true, index_t _nRowsWG = 0, index_t _nColsWG = 0) { index_t M = _M; index_t N = _N; auto mA = make_matrix_view(_mA, M, N, _lda); @@ -887,24 +887,39 @@ typename sb_handle_t::event_t _ger_impl( typename VectorViewType::type vy = make_vector_view(_vy, _incy, N); - const index_t localSize = - (_localSize == 0) ? sb_handle.get_work_group_size() : _localSize; - const index_t nRowsWG = (_nRowsWG == 0) ? localSize : std::min(M, _nRowsWG); + _localSize = (_localSize == 0) ? sb_handle.get_work_group_size() : _localSize; + _nRowsWG = (_nRowsWG == 0) ? _localSize : _nRowsWG; + _nColsWG = (_nColsWG == 0) ? _localSize : _nColsWG; - const index_t nColsWG = (_nColsWG == 0) ? localSize : std::min(N, _nColsWG); + assert(_localSize % _nRowsWG == 0); + assert((_nRowsWG * _nColsWG) % _localSize == 0); + assert(_nColsWG % (_localSize / _nRowsWG) == 0); - const index_t scratchPadSize = - (_localSize == 0) ? localSize : _scratchPadSize; + if (_useLocalMem) { + assert((_nRowsWG <= _localSize) && (_nColsWG <= _localSize)); + } else { + std::vector subgroup_sizes = + sb_handle.get_queue() + .get_device() + .template get_info(); + size_t min_subgroup_size = *subgroup_sizes.begin(); + size_t max_subgroup_size = *subgroup_sizes.rbegin(); + assert(((_nRowsWG * _nColsWG) / _localSize) <= min_subgroup_size); + assert(_nRowsWG % max_subgroup_size == 0); + } - const index_t nWGPerCol = (N - 1) / nColsWG + 1; - const index_t nWGPerRow = (M - 1) / nRowsWG + 1; - const index_t globalSize = localSize * nWGPerRow * nWGPerCol; + const index_t nWGPerCol = (N - 1) / _nColsWG + 1; + const index_t nWGPerRow = (M - 1) / _nRowsWG + 1; + const index_t globalSize = _localSize * nWGPerRow * nWGPerCol; typename sb_handle_t::event_t ret; auto assignOp = - make_ger_col(mA, _alpha, vx, vy, nWGPerRow, nWGPerCol, scratchPadSize); - return sb_handle.execute(assignOp, localSize, globalSize, scratchPadSize, - _dependencies); + make_ger(mA, _alpha, vx, vy, _nRowsWG, _nColsWG, nWGPerRow, nWGPerCol); + + return _useLocalMem ? sb_handle.execute(assignOp, _localSize, globalSize, + _nRowsWG + _nColsWG, _dependencies) + : sb_handle.execute(assignOp, _localSize, globalSize, + _dependencies); } /*! _SYR. 
@@ -1280,10 +1295,30 @@ typename sb_handle_t::event_t inline _ger(
     container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
     container_t2 _mA, index_t _lda,
     const typename sb_handle_t::event_t& _dependencies) {
-  // TODO: Here we can use some heuristics to select localn global, local, and
-  // scratch size per device
+  index_t localSize = 0;
+  bool useLocalMem = true;
+  index_t nRowsWG = 0;
+  index_t nColsWG = 0;
+
+#if defined(INTEL_GPU)
+  localSize = 32;
+  useLocalMem = false;
+  nRowsWG = 32;
+  nColsWG = 8;
+#elif defined(NVIDIA_GPU)
+  localSize = 256;
+  useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
+  nRowsWG = 32;
+  nColsWG = 32;
+#elif defined(AMD_GPU)
+  localSize = (_N < 8192 && _M < 8192) ? 512 : 256;
+  useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
+  nRowsWG = (_N < 8192 && _M < 8192) ? 64 : 128;
+  nColsWG = (_N < 8192 && _M < 8192) ? 64 : 256;
+#endif
+
   return _ger_impl(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA, _lda,
-                   _dependencies);
+                   _dependencies, localSize, useLocalMem, nRowsWG, nColsWG);
 }
diff --git a/src/operations/blas2_trees.hpp b/src/operations/blas2_trees.hpp
--- a/src/operations/blas2_trees.hpp
+++ b/src/operations/blas2_trees.hpp
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE Ger<lhs_t, rhs_1_t, rhs_2_t>::Ger(
+    lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
+    index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col)
+    : lhs_(_l),
+      scalar_(_scl),
+      rhs_1_(_r1),
+      rhs_2_(_r2),
+      nRowsWG_(_nRowsWG),
+      nColsWG_(_nColsWG),
+      nWG_row_(_nWG_row),
+      nWG_col_(_nWG_col) {}
+
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t
+Ger<lhs_t, rhs_1_t, rhs_2_t>::get_size() const {
+  return rhs_1_.get_size();
+}
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE bool Ger<lhs_t, rhs_1_t, rhs_2_t>::valid_thread(
+    cl::sycl::nd_item<1> ndItem) const {
+  return true;
+}
+
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::value_t
+Ger<lhs_t, rhs_1_t, rhs_2_t>::eval(cl::sycl::nd_item<1> ndItem) {
+  using index_t = typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t;
+
+  const index_t subgroup_size = ndItem.get_sub_group().get_local_range().get(0);
+  const index_t subgroups_per_col = nRowsWG_ / subgroup_size;
+  const index_t subgroups_per_group =
+      ndItem.get_sub_group().get_group_range().get(0);
+
+  const index_t group_size = ndItem.get_local_range(0);
+
+  // col_per_workitem <= subgroup_size
+  const index_t col_per_workitem = nColsWG_ * nRowsWG_ / group_size;
+
+  const index_t group_id = ndItem.get_group(0);
+  const index_t idWFR = group_id % nWG_row_;
+  const index_t idWFC = group_id / nWG_row_;
+
+  const index_t subgroup_id = ndItem.get_sub_group().get_group_id().get(0);
+  const index_t subgroup_local_id =
+      ndItem.get_sub_group().get_local_id().get(0);
+
+  const index_t id_row0 = idWFR * nRowsWG_ +
+                          subgroup_size * (subgroup_id % subgroups_per_col) +
+                          subgroup_local_id;
+  const index_t id_col0 =
+      idWFC * nColsWG_ + col_per_workitem * (subgroup_id / subgroups_per_col);
+
+  const index_t dimR = lhs_.get_size_row();
+  const index_t dimC = lhs_.get_size_col();
+  const bool id_row_active = id_row0 < dimR;
+
+#ifndef __ADAPTIVECPP__
+  const value_t rhs_2 = (subgroup_local_id < col_per_workitem &&
+                         id_col0 + subgroup_local_id < dimC)
+                            ? rhs_2_.eval(id_col0 + subgroup_local_id)
+                            : 0;
+#endif
+
+  const value_t scal_rhs_1 = id_row_active ? scalar_ * rhs_1_.eval(id_row0) : 0;
+
+  value_t prefetch_lhs_ =
+      (id_row_active && id_col0 < dimC) ? lhs_.eval(id_row0, id_col0) : 0;
+
+  for (index_t sub_id_col = 0; sub_id_col < col_per_workitem; sub_id_col++) {
+    const value_t rhs_2_sub_id_col =
+#ifndef __ADAPTIVECPP__
+        cl::sycl::group_broadcast(ndItem.get_sub_group(), rhs_2, sub_id_col);
+#else
+        rhs_2_.eval(id_col0 + sub_id_col);
+#endif
+    if (id_row_active && id_col0 + sub_id_col < dimC) {
+      lhs_.eval(id_row0, id_col0 + sub_id_col) =
+          prefetch_lhs_ + scal_rhs_1 * rhs_2_sub_id_col;
+      prefetch_lhs_ = (id_col0 + sub_id_col + 1 < dimC)
+                          ? lhs_.eval(id_row0, id_col0 + sub_id_col + 1)
+                          : 0;
+    }
+  }
+
+  return 0;
+}
+
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+template <typename sharedT>
+PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::value_t
+Ger<lhs_t, rhs_1_t, rhs_2_t>::eval(sharedT shrMem,
+                                   cl::sycl::nd_item<1> ndItem) {
+  using index_t = typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t;
+
+  const index_t group_id = ndItem.get_group(0);
+  const index_t idWFR = group_id % nWG_row_;
+  const index_t idWFC = group_id / nWG_row_;
+  const index_t frs_row = idWFR * nRowsWG_;
+  const index_t group_local_id = ndItem.get_local_id(0);
+
+  // group_size % nRowsWG_ == 0
+  const index_t id_row0 = group_local_id % nRowsWG_;
+  const index_t id_row1 = frs_row + id_row0;
+
+  index_t frs_col = idWFC * nColsWG_;
+
+  const index_t dimR = lhs_.get_size_row();
+  const index_t dimC = lhs_.get_size_col();
+
+  value_t *l_rhs_1 = shrMem.localAcc.get_pointer();
+  value_t *l_rhs_2 = shrMem.localAcc.get_pointer() + nRowsWG_;
+
+  // nRowsWG_ <= group_size
+  if (group_local_id < nRowsWG_)
+    l_rhs_1[group_local_id] =
+        (frs_row + group_local_id < dimR)
+            ? scalar_ * rhs_1_.eval(frs_row + group_local_id)
+            : 0;
+
+  // nColsWG_ <= group_size
+  if (group_local_id < nColsWG_)
+    l_rhs_2[group_local_id] = (frs_col + group_local_id < dimC)
+                                  ? rhs_2_.eval(frs_col + group_local_id)
+                                  : 0;
+
+  const index_t group_size = ndItem.get_local_range(0);
+
+  // nRowsWG_ * nColsWG_ % group_size == 0
+  const index_t col_per_workitem = nRowsWG_ * nColsWG_ / group_size;
+  const index_t subgroup_col_id = group_local_id / nRowsWG_;
+
+  const index_t id_col0 = subgroup_col_id * col_per_workitem;
+  const index_t id_col1 = frs_col + id_col0;
+
+  value_t prefetch_lhs_ =
+      (id_row1 < dimR && id_col1 < dimC) ? lhs_.eval(id_row1, id_col1) : 0;
+
+  ndItem.barrier(cl::sycl::access::fence_space::local_space);
+
+  for (index_t id_col = 0; id_col < col_per_workitem; id_col++) {
+    const value_t val = l_rhs_1[id_row0] * l_rhs_2[id_col0 + id_col];
+    if (id_row1 < dimR && id_col1 + id_col < dimC) {
+      lhs_.eval(id_row1, id_col1 + id_col) = prefetch_lhs_ + val;
+      prefetch_lhs_ = (id_col1 + id_col + 1 < dimC)
+                          ? lhs_.eval(id_row1, id_col1 + id_col + 1)
+                          : 0;
+    }
+  }
+
+  return 0;
+}
+
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE void Ger<lhs_t, rhs_1_t, rhs_2_t>::bind(cl::sycl::handler &h) {
+  lhs_.bind(h);
+  rhs_1_.bind(h);
+  rhs_2_.bind(h);
+}
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE void
+Ger<lhs_t, rhs_1_t, rhs_2_t>::adjust_access_displacement() {
+  lhs_.adjust_access_displacement();
+  rhs_1_.adjust_access_displacement();
+  rhs_2_.adjust_access_displacement();
+}
+
 /**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/
 // template <class LHS, class RHS1, class RHS2>
 template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
diff --git a/test/unittest/blas2/blas2_ger_test.cpp b/test/unittest/blas2/blas2_ger_test.cpp
--- a/test/unittest/blas2/blas2_ger_test.cpp
+++ b/test/unittest/blas2/blas2_ger_test.cpp
 const auto combi =
     ::testing::Combine(::testing::Values("usm", "buf"),  // allocation type
-                       ::testing::Values(11, 1023),      // m
-                       ::testing::Values(14, 1010),      // n
-                       ::testing::Values(0.0, 1.5),      // alpha
+                       ::testing::Values(11, 1023, 8888, 10968),  // m
+                       ::testing::Values(14, 1010, 9999),         // n
+                       ::testing::Values(0, 1.5),                 // alpha
                        ::testing::Values(2),             // incX
                        ::testing::Values(3),             // incY
                        ::testing::Values(2)              // lda_mul
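
Reviewer note: the tiling constraints that the new asserts in `_ger_impl` encode can be sanity-checked on the host without a device. The sketch below is a minimal, illustrative C++ program; the helper `check_ger_config` is hypothetical and not part of this patch. It mirrors the divisibility asserts and the `globalSize` computation from the hunk above, fed with the NVIDIA_GPU heuristic values; the device-dependent sub-group asserts of the non-local-memory path are omitted since they need `sub_group_sizes` from a live device.

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical helper (not part of this patch): re-checks a GER launch
// configuration the way the new asserts in _ger_impl do, then derives the
// work-group grid exactly as the kernel launch does.
void check_ger_config(int M, int N, int localSize, int nRowsWG, int nColsWG,
                      bool useLocalMem) {
  // Each work-group column covers nRowsWG rows of the output tile.
  assert(localSize % nRowsWG == 0);
  // Every work-item owns (nRowsWG * nColsWG) / localSize output elements.
  assert((nRowsWG * nColsWG) % localSize == 0);
  assert(nColsWG % (localSize / nRowsWG) == 0);
  if (useLocalMem) {
    // Local-memory path stages one row slice and one column slice in
    // scratch, which is why execute() above sizes it as nRowsWG + nColsWG.
    assert(nRowsWG <= localSize && nColsWG <= localSize);
  }
  const int nWGPerCol = (N - 1) / nColsWG + 1;
  const int nWGPerRow = (M - 1) / nRowsWG + 1;
  const int globalSize = localSize * nWGPerRow * nWGPerCol;
  std::printf("%d x %d work-groups, globalSize = %d\n", nWGPerRow, nWGPerCol,
              globalSize);
}

int main() {
  // NVIDIA_GPU heuristic from the patch: 256 work-items, 32 x 32 tile,
  // no local memory for M, N below 8192.
  check_ger_config(1023, 1010, 256, 32, 32, /*useLocalMem=*/false);
  return 0;
}
```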