diff --git a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp index 69283cf54f..589208b63b 100644 --- a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp +++ b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp @@ -145,7 +145,7 @@ template <typename T> class stack_strided_t namespace su_ns = dpctl::tensor::sycl_utils; -using nwiT = std::uint16_t; +using nwiT = std::uint32_t; template <typename inputT, typename outputT, @@ -156,7 +156,18 @@ template <typename inputT, typename TransformerT, typename ScanOpT, bool include_initial> -class inclusive_scan_iter_local_scan_krn; +class inclusive_scan_iter_local_scan_blocked_krn; + +template <typename inputT, + typename outputT, + nwiT n_wi, + typename IterIndexerT, + typename InpIndexerT, + typename OutIndexerT, + typename TransformerT, + typename ScanOpT, + bool include_initial> +class inclusive_scan_iter_local_scan_striped_krn; template <typename inputT, typename outputT, @@ -177,22 +188,22 @@ template <typename inputT, typename ScanOpT, bool include_initial = false> sycl::event -inclusive_scan_base_step(sycl::queue &exec_q, - const std::size_t wg_size, - const std::size_t iter_nelems, - const std::size_t acc_nelems, - const inputT *input, - outputT *output, - const std::size_t s0, - const std::size_t s1, - const IterIndexerT &iter_indexer, - const InpIndexerT &inp_indexer, - const OutIndexerT &out_indexer, - TransformerT transformer, - const ScanOpT &scan_op, - outputT identity, - std::size_t &acc_groups, - const std::vector<sycl::event> &depends = {}) +inclusive_scan_base_step_blocked(sycl::queue &exec_q, + const std::uint32_t wg_size, + const std::size_t iter_nelems, + const std::size_t acc_nelems, + const inputT *input, + outputT *output, + const std::size_t s0, + const std::size_t s1, + const IterIndexerT &iter_indexer, + const InpIndexerT &inp_indexer, + const OutIndexerT &out_indexer, + TransformerT transformer, + const ScanOpT &scan_op, + outputT identity, + std::size_t &acc_groups, + const std::vector<sycl::event> &depends = {}) { acc_groups = ceiling_quotient<std::size_t>(acc_nelems, n_wi * wg_size); @@ -208,7 +219,7 @@ inclusive_scan_base_step(sycl::queue &exec_q, slmT slm_iscan_tmp(lws, cgh); - using KernelName = inclusive_scan_iter_local_scan_krn< + using KernelName = inclusive_scan_iter_local_scan_blocked_krn< inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT, TransformerT, ScanOpT, include_initial>; @@ -218,6 +229,7 @@ inclusive_scan_base_step(sycl::queue &exec_q, const std::size_t gid = it.get_global_id(0); const std::size_t lid = it.get_local_id(0); + const std::uint32_t wg_size = it.get_local_range(0); const std::size_t reduce_chunks = acc_groups * wg_size; const std::size_t iter_gid = gid / reduce_chunks; const std::size_t chunk_gid = gid - (iter_gid * reduce_chunks); @@ -268,7 +280,8 @@ inclusive_scan_base_step(sycl::queue &exec_q, } else { wg_iscan_val = su_ns::custom_inclusive_scan_over_group( - it.get_group(), slm_iscan_tmp, local_iscan.back(), scan_op); + it.get_group(), it.get_sub_group(), slm_iscan_tmp, + local_iscan.back(), identity, scan_op); // ensure all finished reading from SLM, to avoid race condition // with subsequent writes into SLM it.barrier(sycl::access::fence_space::local_space); @@ -276,11 +289,11 @@ inclusive_scan_base_step(sycl::queue &exec_q, slm_iscan_tmp[(lid + 1) % wg_size] = wg_iscan_val; it.barrier(sycl::access::fence_space::local_space); - outputT addand = (lid == 0) ? 
identity : slm_iscan_tmp[lid]; + const outputT modifier = (lid == 0) ? identity : slm_iscan_tmp[lid]; #pragma unroll for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) { - local_iscan[m_wi] = scan_op(local_iscan[m_wi], addand); + local_iscan[m_wi] = scan_op(local_iscan[m_wi], modifier); } const std::size_t start = std::min(i, acc_nelems); @@ -296,6 +309,249 @@ inclusive_scan_base_step(sycl::queue &exec_q, return inc_scan_phase1_ev; } +template <typename inputT, + typename outputT, + nwiT n_wi, + typename IterIndexerT, + typename InpIndexerT, + typename OutIndexerT, + typename TransformerT, + typename ScanOpT, + bool include_initial = false> +sycl::event +inclusive_scan_base_step_striped(sycl::queue &exec_q, + const std::uint32_t wg_size, + const std::size_t iter_nelems, + const std::size_t acc_nelems, + const inputT *input, + outputT *output, + const std::size_t s0, + const std::size_t s1, + const IterIndexerT &iter_indexer, + const InpIndexerT &inp_indexer, + const OutIndexerT &out_indexer, + TransformerT transformer, + const ScanOpT &scan_op, + outputT identity, + std::size_t &acc_groups, + const std::vector<sycl::event> &depends = {}) +{ + const std::uint32_t reduce_nelems_per_wg = n_wi * wg_size; + acc_groups = + ceiling_quotient<std::size_t>(acc_nelems, reduce_nelems_per_wg); + + sycl::event inc_scan_phase1_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using slmT = sycl::local_accessor<outputT, 1>; + + const auto &gRange = sycl::range<1>{iter_nelems * acc_groups * wg_size}; + const auto &lRange = sycl::range<1>{wg_size}; + + const auto &ndRange = sycl::nd_range<1>{gRange, lRange}; + + slmT slm_iscan_tmp(reduce_nelems_per_wg, cgh); + + using KernelName = inclusive_scan_iter_local_scan_striped_krn< + inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT, + TransformerT, ScanOpT, include_initial>; + + cgh.parallel_for<KernelName>(ndRange, [=, slm_iscan_tmp = + std::move(slm_iscan_tmp)]( + sycl::nd_item<1> it) { + const std::uint32_t lid = it.get_local_linear_id(); + const std::uint32_t wg_size = it.get_local_range(0); + + const auto &sg = it.get_sub_group(); + const std::uint32_t sgSize = sg.get_max_local_range()[0]; + const std::size_t sgroup_id = sg.get_group_id()[0]; + const std::uint32_t lane_id = sg.get_local_id()[0]; + + const std::size_t flat_group_id = it.get_group(0); + const std::size_t iter_gid = flat_group_id / acc_groups; + const std::size_t acc_group_id = + flat_group_id - (iter_gid * acc_groups); + + const auto &iter_offsets = iter_indexer(iter_gid); + const auto &inp_iter_offset = iter_offsets.get_first_offset(); + const auto &out_iter_offset = iter_offsets.get_second_offset(); + + std::array<outputT, n_wi> local_iscan{}; + + const std::size_t inp_id0 = acc_group_id * n_wi * wg_size + + sgroup_id * n_wi * sgSize + lane_id; + +#pragma unroll + for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) { + const std::size_t inp_id = inp_id0 + m_wi * sgSize; + if constexpr (!include_initial) { + local_iscan[m_wi] = + (inp_id < acc_nelems) + ? transformer(input[inp_iter_offset + + inp_indexer(s0 + s1 * inp_id)]) + : identity; + } + else { + // shift input to the left by a single element relative to + // output + local_iscan[m_wi] = + (inp_id < acc_nelems && inp_id > 0) + ? 
transformer(
+                                  input[inp_iter_offset +
+                                        inp_indexer((s0 + s1 * inp_id) - 1)])
+                            : identity;
+                }
+            }
+
+            // change layout from striped to blocked
+            // (e.g., for n_wi == 2 and sgSize == 4, lanes holding striped
+            // pairs {0,4}, {1,5}, {2,6}, {3,7} leave the exchange holding
+            // blocked pairs {0,1}, {2,3}, {4,5}, {6,7}, so the serial scan
+            // below runs on contiguous elements)
+            {
+                {
+                    const std::uint32_t local_offset0 = lid * n_wi;
+#pragma unroll
+                    for (std::uint32_t i = 0; i < n_wi; ++i) {
+                        slm_iscan_tmp[local_offset0 + i] = local_iscan[i];
+                    }
+
+                    it.barrier(sycl::access::fence_space::local_space);
+                }
+
+                {
+                    const std::uint32_t block_offset =
+                        sgroup_id * sgSize * n_wi;
+                    const std::uint32_t disp0 = lane_id * n_wi;
+#pragma unroll
+                    for (nwiT i = 0; i < n_wi; ++i) {
+                        const std::uint32_t disp = disp0 + i;
+
+                        // disp == lane_id1 + i1 * sgSize;
+                        const std::uint32_t i1 = disp / sgSize;
+                        const std::uint32_t lane_id1 = disp - i1 * sgSize;
+
+                        const std::uint32_t disp_exchanged =
+                            (lane_id1 * n_wi + i1);
+
+                        local_iscan[i] =
+                            slm_iscan_tmp[block_offset + disp_exchanged];
+                    }
+
+                    it.barrier(sycl::access::fence_space::local_space);
+                }
+            }
+
+#pragma unroll
+            for (nwiT m_wi = 1; m_wi < n_wi; ++m_wi) {
+                local_iscan[m_wi] =
+                    scan_op(local_iscan[m_wi], local_iscan[m_wi - 1]);
+            }
+            // local_iscan now holds the inclusive scan of this
+            // work-item's locally stored inputs
+
+            outputT wg_iscan_val;
+            if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
+                                                            outputT>::value)
+            {
+                wg_iscan_val = sycl::inclusive_scan_over_group(
+                    it.get_group(), local_iscan.back(), scan_op, identity);
+            }
+            else {
+                wg_iscan_val = su_ns::custom_inclusive_scan_over_group(
+                    it.get_group(), sg, slm_iscan_tmp, local_iscan.back(),
+                    identity, scan_op);
+                // ensure all work-items finished reading from SLM, to avoid
+                // race condition with subsequent writes into SLM
+                it.barrier(sycl::access::fence_space::local_space);
+            }
+
+            // rotate the group-scan values by one work-item so that
+            // slm_iscan_tmp[lid] holds the exclusive prefix for work-item lid
+            slm_iscan_tmp[(lid + 1) % wg_size] = wg_iscan_val;
+            it.barrier(sycl::access::fence_space::local_space);
+            const outputT modifier = (lid == 0) ? identity : slm_iscan_tmp[lid];
+
+#pragma unroll
+            for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
+                local_iscan[m_wi] = scan_op(local_iscan[m_wi], modifier);
+            }
+
+            it.barrier(sycl::access::fence_space::local_space);
+
+            // convert back to striped layout for the store to global memory
+            {
+                {
+                    const std::uint32_t local_offset0 = lid * n_wi;
+#pragma unroll
+                    for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
+                        slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
+                    }
+
+                    it.barrier(sycl::access::fence_space::local_space);
+                }
+            }
+
+            {
+                const std::uint32_t block_offset =
+                    sgroup_id * sgSize * n_wi + lane_id;
+#pragma unroll
+                for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
+                    const std::uint32_t m_wi_scaled = m_wi * sgSize;
+                    const std::size_t out_id = inp_id0 + m_wi_scaled;
+                    if (out_id < acc_nelems) {
+                        output[out_iter_offset + out_indexer(out_id)] =
+                            slm_iscan_tmp[block_offset + m_wi_scaled];
+                    }
+                }
+            }
+        });
+    });
+
+    return inc_scan_phase1_ev;
+}
+
+template <typename inputT,
+          typename outputT,
+          nwiT n_wi,
+          typename IterIndexerT,
+          typename InpIndexerT,
+          typename OutIndexerT,
+          typename TransformerT,
+          typename ScanOpT,
+          bool include_initial = false>
+sycl::event
+inclusive_scan_base_step(sycl::queue &exec_q,
+                         const std::uint32_t wg_size,
+                         const std::size_t iter_nelems,
+                         const std::size_t acc_nelems,
+                         const inputT *input,
+                         outputT *output,
+                         const std::size_t s0,
+                         const std::size_t s1,
+                         const IterIndexerT &iter_indexer,
+                         const InpIndexerT &inp_indexer,
+                         const OutIndexerT &out_indexer,
+                         TransformerT transformer,
+                         const ScanOpT &scan_op,
+                         outputT identity,
+                         std::size_t &acc_groups,
+                         const std::vector<sycl::event> &depends = {})
+{
+    // For small strides, use striped loads and stores.
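+    // Rationale: with striped access, consecutive lanes of a sub-group
+    // touch adjacent memory locations, so loads and stores coalesce
+    // (illustrative numbers: for s1 == 1, n_wi == 4 and sgSize == 32,
+    // the first load of a sub-group reads elements 0..31 of its block as
+    // one contiguous segment). For large strides neighboring work-items
+    // touch distant addresses either way, and the simpler indexing of
+    // the blocked kernel is preferred.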
+    // Threshold value chosen experimentally.
+    if (s1 <= 16) {
+        return inclusive_scan_base_step_striped<
+            inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
+            TransformerT, ScanOpT, include_initial>(
+            exec_q, wg_size, iter_nelems, acc_nelems, input, output, s0, s1,
+            iter_indexer, inp_indexer, out_indexer, transformer, scan_op,
+            identity, acc_groups, depends);
+    }
+    else {
+        return inclusive_scan_base_step_blocked<
+            inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
+            TransformerT, ScanOpT, include_initial>(
+            exec_q, wg_size, iter_nelems, acc_nelems, input, output, s0, s1,
+            iter_indexer, inp_indexer, out_indexer, transformer, scan_op,
+            identity, acc_groups, depends);
+    }
+}
+
 template <typename inputT,
           typename outputT,
           nwiT n_wi,
@@ -318,7 +574,7 @@ template <typename inputT,
           typename ScanOpT,
           bool include_initial>
 sycl::event inclusive_scan_iter_1d(sycl::queue &exec_q,
-                                   const std::size_t wg_size,
+                                   const std::uint32_t wg_size,
                                    const std::size_t n_elems,
                                    const inputT *input,
                                    outputT *output,
@@ -512,7 +768,7 @@ accumulate_1d_contig_impl(sycl::queue &q,
     const sycl::device &dev = q.get_device();
     if (dev.has(sycl::aspect::cpu)) {
         constexpr nwiT n_wi_for_cpu = 8;
-        const std::size_t wg_size = 256;
+        const std::uint32_t wg_size = 256;
         comp_ev = inclusive_scan_iter_1d<srcT, dstT, n_wi_for_cpu,
                                          NoOpIndexerT, transformerT,
                                          AccumulateOpT, include_initial>(
@@ -521,7 +777,10 @@ accumulate_1d_contig_impl(sycl::queue &q,
     }
     else {
         constexpr nwiT n_wi_for_gpu = 4;
-        const std::size_t wg_size = 256;
+        // inclusive_scan_base_step_striped does not execute correctly
+        // on HIP devices with wg_size > 64
+        const std::uint32_t wg_size =
+            (q.get_backend() == sycl::backend::ext_oneapi_hip) ? 64 : 256;
         comp_ev = inclusive_scan_iter_1d<srcT, dstT, n_wi_for_gpu,
                                          NoOpIndexerT, transformerT,
                                          AccumulateOpT, include_initial>(
@@ -553,7 +812,7 @@ template <typename inputT,
           typename ScanOpT,
           bool include_initial>
 sycl::event inclusive_scan_iter(sycl::queue &exec_q,
-                                const std::size_t wg_size,
+                                const std::uint32_t wg_size,
                                 const std::size_t iter_nelems,
                                 const std::size_t acc_nelems,
                                 const inputT *input,
@@ -914,7 +1173,7 @@ accumulate_strided_impl(sycl::queue &q,
     sycl::event comp_ev;
     if (dev.has(sycl::aspect::cpu)) {
         constexpr nwiT n_wi_for_cpu = 8;
-        const std::size_t wg_size = 256;
+        const std::uint32_t wg_size = 256;
         comp_ev =
             inclusive_scan_iter<srcT, dstT, n_wi_for_cpu, InpIndexerT,
                                 OutIndexerT, InpIndexerT, OutIndexerT,
@@ -925,7 +1184,10 @@
     }
     else {
         constexpr nwiT n_wi_for_gpu = 4;
-        const std::size_t wg_size = 256;
+        // inclusive_scan_base_step_striped does not execute correctly
+        // on HIP devices with wg_size > 64
+        const std::uint32_t wg_size =
+            (q.get_backend() == sycl::backend::ext_oneapi_hip) ? 
64 : 256;
         comp_ev =
             inclusive_scan_iter<srcT, dstT, n_wi_for_gpu, InpIndexerT,
                                 OutIndexerT, InpIndexerT, OutIndexerT,
@@ -970,7 +1232,7 @@ std::size_t cumsum_val_contig_impl(sycl::queue &q,
     const sycl::device &dev = q.get_device();
     if (dev.has(sycl::aspect::cpu)) {
         constexpr nwiT n_wi_for_cpu = 8;
-        const std::size_t wg_size = 256;
+        const std::uint32_t wg_size = 256;
         comp_ev = inclusive_scan_iter_1d<maskT, cumsumT, n_wi_for_cpu,
                                          NoOpIndexerT, transformerT,
                                          AccumulateOpT, include_initial>(
@@ -979,7 +1241,10 @@
     }
     else {
         constexpr nwiT n_wi_for_gpu = 4;
-        const std::size_t wg_size = 256;
+        // inclusive_scan_base_step_striped does not execute correctly
+        // on HIP devices with wg_size > 64
+        const std::uint32_t wg_size =
+            (q.get_backend() == sycl::backend::ext_oneapi_hip) ? 64 : 256;
         comp_ev = inclusive_scan_iter_1d<maskT, cumsumT, n_wi_for_gpu,
                                          NoOpIndexerT, transformerT,
                                          AccumulateOpT, include_initial>(
@@ -1081,7 +1346,7 @@ cumsum_val_strided_impl(sycl::queue &q,
     sycl::event comp_ev;
     if (dev.has(sycl::aspect::cpu)) {
         constexpr nwiT n_wi_for_cpu = 8;
-        const std::size_t wg_size = 256;
+        const std::uint32_t wg_size = 256;
         comp_ev = inclusive_scan_iter_1d<maskT, cumsumT, n_wi_for_cpu,
                                          StridedIndexerT, transformerT,
                                          AccumulateOpT, include_initial>(
@@ -1090,7 +1355,10 @@
     }
     else {
         constexpr nwiT n_wi_for_gpu = 4;
-        const std::size_t wg_size = 256;
+        // inclusive_scan_base_step_striped does not execute correctly
+        // on HIP devices with wg_size > 64
+        const std::uint32_t wg_size =
+            (q.get_backend() == sycl::backend::ext_oneapi_hip) ? 64 : 256;
         comp_ev = inclusive_scan_iter_1d<maskT, cumsumT, n_wi_for_gpu,
                                          StridedIndexerT, transformerT,
                                          AccumulateOpT, include_initial>(
diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp
index 52bc50e4e1..a4ace720ce 100644
--- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp
+++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp
@@ -212,29 +212,89 @@ T custom_reduce_over_group(const GroupT &wg,
     return sycl::group_broadcast(wg, red_val_over_wg, 0);
 }
 
-template <typename T, typename GroupT, typename LocAccT, typename OpT>
-T custom_inclusive_scan_over_group(const GroupT &wg,
-                                   LocAccT local_mem_acc,
-                                   const T local_val,
-                                   const OpT &op)
+template <typename GroupT,
+          typename SubGroupT,
+          typename LocAccT,
+          typename T,
+          typename OpT>
+T custom_inclusive_scan_over_group(GroupT &&wg,
+                                   SubGroupT &&sg,
+                                   LocAccT &&local_mem_acc,
+                                   const T &local_val,
+                                   const T &identity,
+                                   OpT &&op)
 {
     const std::uint32_t local_id = wg.get_local_id(0);
     const std::uint32_t wgs = wg.get_local_range(0);
 
-    local_mem_acc[local_id] = local_val;
+    const std::uint32_t lane_id = sg.get_local_id()[0];
+    const std::uint32_t sgSize = sg.get_local_range()[0];
+
+    // inclusive scan within each sub-group: Hillis-Steele step-doubling,
+    // log2(sgSize) rounds of shuffles with no SLM traffic
+    T scan_val = local_val;
+    for (std::uint32_t step = 1; step < sgSize; step *= 2) {
+        const bool advanced_lane = (lane_id >= step);
+        const std::uint32_t src_lane_id =
+            (advanced_lane ? 
lane_id - step : lane_id);
+        const T modifier = sycl::select_from_group(sg, scan_val, src_lane_id);
+        if (advanced_lane) {
+            scan_val = op(scan_val, modifier);
+        }
+    }
+
+    local_mem_acc[local_id] = scan_val;
     sycl::group_barrier(wg, sycl::memory_scope::work_group);
 
-    if (wg.leader()) {
-        T scan_val = local_mem_acc[0];
-        for (std::uint32_t i = 1; i < wgs; ++i) {
-            scan_val = op(local_mem_acc[i], scan_val);
-            local_mem_acc[i] = scan_val;
+    const std::uint32_t max_sgSize = sg.get_max_local_range()[0];
+    const std::uint32_t sgr_id = sg.get_group_id()[0];
+
+    // now scan the per-sub-group aggregates gathered in SLM (the last-lane
+    // value of sub-group k sits at local_mem_acc[(k + 1) * max_sgSize - 1])
+    const std::uint32_t n_aggregates = 1 + ((wgs - 1) / max_sgSize);
+    const bool large_wg = (n_aggregates > max_sgSize);
+    // when there are more aggregates than lanes in a sub-group, the group
+    // leader serially folds the excess so that at most max_sgSize remain
+    if (large_wg) {
+        if (wg.leader()) {
+            T _scan_val = identity;
+            for (std::uint32_t i = 1; i <= n_aggregates - max_sgSize; ++i) {
+                _scan_val = op(local_mem_acc[i * max_sgSize - 1], _scan_val);
+                local_mem_acc[i * max_sgSize - 1] = _scan_val;
+            }
         }
+        sycl::group_barrier(wg, sycl::memory_scope::work_group);
     }
-    // ensure all work-items see the same SLM that leader updated
+    // sub-group 0 scans the remaining aggregates with shuffles
+    if (sgr_id == 0) {
+        const std::uint32_t offset =
+            (large_wg) ? n_aggregates - max_sgSize : 0u;
+        const bool in_range = (lane_id < n_aggregates);
+        const bool in_bounds = in_range && (lane_id > 0 || large_wg);
+
+        T agg_scan_val = (in_bounds)
+                             ? local_mem_acc[(offset + lane_id) * max_sgSize - 1]
+                             : identity;
+        for (std::uint32_t step = 1; step < sgSize; step *= 2) {
+            const bool advanced_lane = (lane_id >= step);
+            const std::uint32_t src_lane_id =
+                (advanced_lane ? lane_id - step : lane_id);
+            const T modifier =
+                sycl::select_from_group(sg, agg_scan_val, src_lane_id);
+            if (advanced_lane && in_range) {
+                agg_scan_val = op(agg_scan_val, modifier);
+            }
+        }
+        if (in_bounds) {
+            local_mem_acc[(offset + lane_id) * max_sgSize - 1] = agg_scan_val;
+        }
+    }
     sycl::group_barrier(wg, sycl::memory_scope::work_group);
-    return local_mem_acc[local_id];
+
+    // add the combined aggregate of all preceding sub-groups
+    if (sgr_id > 0) {
+        const T modifier = local_mem_acc[sgr_id * max_sgSize - 1];
+        scan_val = op(scan_val, modifier);
+    }
+
+    // ensure all work-items finished reading from SLM
+    sycl::group_barrier(wg, sycl::memory_scope::work_group);
+
+    return scan_val;
 }
 
 // Reduction functors