From 57a6c9bad2bdd92932c781b2621449a9b970d256 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Wed, 23 Oct 2024 13:07:12 -0700 Subject: [PATCH 1/2] add Get/Set for vectors and use them to implement Concat* operators --- hwy/ops/rvv-inl.h | 176 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 164 insertions(+), 12 deletions(-) diff --git a/hwy/ops/rvv-inl.h b/hwy/ops/rvv-inl.h index f65153294f..62ac160f06 100644 --- a/hwy/ops/rvv-inl.h +++ b/hwy/ops/rvv-inl.h @@ -127,6 +127,26 @@ namespace detail { // for code folding X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) +#define HWY_RVV_FOREACH_08_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + // LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH. #define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ @@ -275,6 +295,35 @@ namespace detail { // for code folding HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) +// GET/SET + VIRT +#define HWY_RVV_FOREACH_08_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// For the smallest LMUL for each SEW, similar to the LowerHalf operator, we +// provide the Get and Set operator that returns the same vector type. 
+#define HWY_RVV_FOREACH_08_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) + // EXT + VIRT #define HWY_RVV_FOREACH_08_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ @@ -3123,6 +3172,91 @@ HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL) #undef HWY_RVV_SLIDE_UP #undef HWY_RVV_SLIDE_DOWN +#define HWY_RVV_GET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH( \ + v, kIndex); /* no AVL */ \ + } +#define HWY_RVV_GET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + if constexpr (kIndex == 0) { \ + return Trunc(v); \ + } else { \ + static_assert(kIndex == 1); \ + return Trunc(SlideDown( \ + v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \ + SHIFT - 1){}))); \ + } \ + } +#define HWY_RVV_GET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + if constexpr (kIndex == 0) { \ + return v; \ + } else { \ + static_assert(kIndex == 1); \ + return SlideDown( \ + v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \ + SHIFT){}) / \ + 2); \ + } \ + } +HWY_RVV_FOREACH(HWY_RVV_GET, Get, get, _GET_SET) +HWY_RVV_FOREACH(HWY_RVV_GET_VIRT, Get, get, _GET_SET_VIRT) +HWY_RVV_FOREACH(HWY_RVV_GET_SMALLEST, Get, get, _GET_SET_SMALLEST) +#undef HWY_RVV_GET +#undef HWY_RVV_GET_VIRT +#undef HWY_RVV_GET_SMALLEST + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMULH##_##CHAR##SEW##LMUL( \ + dest, kIndex, v); /* no AVL */ \ + } +#define HWY_RVV_SET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) { \ + auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \ + auto df2 = \ + HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT - 1){}; \ + if constexpr (kIndex == 0) { \ + return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, Ext(d, v), \ + Lanes(df2)); \ + } else { \ + static_assert(kIndex == 1); \ + return SlideUp(dest, Ext(d, v), Lanes(df2)); \ + } \ + } +#define HWY_RVV_SET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \ + if constexpr (kIndex 
== 0) { \ + return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, v, Lanes(d) / 2); \ + } else { \ + static_assert(kIndex == 1); \ + return SlideUp(dest, v, Lanes(d) / 2); \ + } \ + } +HWY_RVV_FOREACH(HWY_RVV_SET, Set, set, _GET_SET) +HWY_RVV_FOREACH(HWY_RVV_SET_VIRT, Set, set, _GET_SET_VIRT) +HWY_RVV_FOREACH(HWY_RVV_SET_SMALLEST, Set, set, _GET_SET_SMALLEST) +#undef HWY_RVV_SET +#undef HWY_RVV_SET_VIRT +#undef HWY_RVV_SET_SMALLEST + } // namespace detail // ------------------------------ SlideUpLanes @@ -3144,29 +3278,47 @@ HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { // ------------------------------ ConcatUpperLower template -HWY_API V ConcatUpperLower(D d, const V hi, const V lo) { - const size_t half = Lanes(d) / 2; - const V hi_down = detail::SlideDown(hi, half); - return detail::SlideUp(lo, hi_down, half); +HWY_API V ConcatUpperLower(D, const V hi, const V lo) { + const auto lo_lower = detail::Get<0>(lo); + return detail::Set<0>(hi, lo_lower); } // ------------------------------ ConcatLowerLower template -HWY_API V ConcatLowerLower(D d, const V hi, const V lo) { - return detail::SlideUp(lo, hi, Lanes(d) / 2); +HWY_API V ConcatLowerLower(D, const V hi, const V lo) { + const auto hi_lower = detail::Get<0>(hi); + return detail::Set<1>(lo, hi_lower); } // ------------------------------ ConcatUpperUpper template -HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) { - const size_t half = Lanes(d) / 2; - const V hi_down = detail::SlideDown(hi, half); - const V lo_down = detail::SlideDown(lo, half); - return detail::SlideUp(lo_down, hi_down, half); +HWY_API V ConcatUpperUpper(D, const V hi, const V lo) { + const auto lo_upper = detail::Get<1>(lo); + return detail::Set<0>(hi, lo_upper); } // ------------------------------ ConcatLowerUpper -template +namespace detail { + +// Only getting a full register is a no-op. 
+template +constexpr bool IsGetNoOp(D d) { + return d.Pow2() >= 0; +} + +} // namespace detail + +template ())>* = nullptr> +HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { + const auto lo_upper = detail::Get<1>(lo); + const auto hi_lower = detail::Get<0>(hi); + const auto undef = Undefined(d); + return detail::Set<1>(detail::Set<0>(undef, lo_upper), hi_lower); +} + +template ())>* = nullptr> HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { const size_t half = Lanes(d) / 2; const V lo_down = detail::SlideDown(lo, half); From 914cb69c58c814e70d757ddc4d158e7c27ee43e8 Mon Sep 17 00:00:00 2001 From: John Platts Date: Mon, 2 Dec 2024 13:20:09 -0600 Subject: [PATCH 2/2] Made changes to RVV Concat, Combine, ZeroExtendVector, and UpperHalf ops --- hwy/ops/rvv-inl.h | 116 ++++++++++++++++++++++++++-------------------- 1 file changed, 65 insertions(+), 51 deletions(-) diff --git a/hwy/ops/rvv-inl.h b/hwy/ops/rvv-inl.h index 62ac160f06..3df8536fde 100644 --- a/hwy/ops/rvv-inl.h +++ b/hwy/ops/rvv-inl.h @@ -3183,10 +3183,9 @@ HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL) SHIFT, MLEN, NAME, OP) \ template \ HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - if constexpr (kIndex == 0) { \ - return Trunc(v); \ - } else { \ - static_assert(kIndex == 1); \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + HWY_IF_CONSTEXPR(kIndex == 0) { return Trunc(v); } \ + else { \ return Trunc(SlideDown( \ v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \ SHIFT - 1){}))); \ @@ -3196,10 +3195,9 @@ HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL) SHIFT, MLEN, NAME, OP) \ template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - if constexpr (kIndex == 0) { \ - return v; \ - } else { \ - static_assert(kIndex == 1); \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + HWY_IF_CONSTEXPR(kIndex == 0) { return v; } \ + else { \ return SlideDown( \ v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \ SHIFT){}) / \ @@ -3213,6 +3211,23 @@ HWY_RVV_FOREACH(HWY_RVV_GET_SMALLEST, Get, get, _GET_SET_SMALLEST) #undef HWY_RVV_GET_VIRT #undef HWY_RVV_GET_SMALLEST +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD>> +Get(D d, VFromD v) { + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); + + const AdjustSimdTagToMinVecPow2> dh; + HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) { + (void)dh; + return Get(v); + } + else { + const size_t slide_down_amt = + (dh.Pow2() < DFromV().Pow2()) ? 
Lanes(dh) : (Lanes(d) / 2); + return ResizeBitCast(dh, SlideDown(v, slide_down_amt)); + } +} + #define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ MLEN, NAME, OP) \ template \ @@ -3226,14 +3241,15 @@ HWY_RVV_FOREACH(HWY_RVV_GET_SMALLEST, Get, get, _GET_SET_SMALLEST) template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \ auto df2 = \ HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT - 1){}; \ - if constexpr (kIndex == 0) { \ + HWY_IF_CONSTEXPR(kIndex == 0) { \ return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, Ext(d, v), \ Lanes(df2)); \ - } else { \ - static_assert(kIndex == 1); \ + } \ + else { \ return SlideUp(dest, Ext(d, v), Lanes(df2)); \ } \ } @@ -3242,11 +3258,12 @@ HWY_RVV_FOREACH(HWY_RVV_GET_SMALLEST, Get, get, _GET_SET_SMALLEST) template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \ - if constexpr (kIndex == 0) { \ + HWY_IF_CONSTEXPR(kIndex == 0) { \ return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, v, Lanes(d) / 2); \ - } else { \ - static_assert(kIndex == 1); \ + } \ + else { \ return SlideUp(dest, v, Lanes(d) / 2); \ } \ } @@ -3257,6 +3274,23 @@ HWY_RVV_FOREACH(HWY_RVV_SET_SMALLEST, Set, set, _GET_SET_SMALLEST) #undef HWY_RVV_SET_VIRT #undef HWY_RVV_SET_SMALLEST +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD Set( + D d, VFromD dest, VFromD>> v) { + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); + + const AdjustSimdTagToMinVecPow2> dh; + HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) { + (void)dh; + return Set(dest, v); + } + else { + const size_t slide_up_amt = + (dh.Pow2() < DFromV().Pow2()) ? Lanes(dh) : (Lanes(d) / 2); + return SlideUp(dest, ResizeBitCast(d, v), slide_up_amt); + } +} + } // namespace detail // ------------------------------ SlideUpLanes @@ -3278,58 +3312,37 @@ HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { // ------------------------------ ConcatUpperLower template -HWY_API V ConcatUpperLower(D, const V hi, const V lo) { - const auto lo_lower = detail::Get<0>(lo); - return detail::Set<0>(hi, lo_lower); +HWY_API V ConcatUpperLower(D d, const V hi, const V lo) { + const auto lo_lower = detail::Get<0>(d, lo); + return detail::Set<0>(d, hi, lo_lower); } // ------------------------------ ConcatLowerLower template -HWY_API V ConcatLowerLower(D, const V hi, const V lo) { - const auto hi_lower = detail::Get<0>(hi); - return detail::Set<1>(lo, hi_lower); +HWY_API V ConcatLowerLower(D d, const V hi, const V lo) { + const auto hi_lower = detail::Get<0>(d, hi); + return detail::Set<1>(d, lo, hi_lower); } // ------------------------------ ConcatUpperUpper template -HWY_API V ConcatUpperUpper(D, const V hi, const V lo) { - const auto lo_upper = detail::Get<1>(lo); - return detail::Set<0>(hi, lo_upper); +HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) { + const auto lo_upper = detail::Get<1>(d, lo); + return detail::Set<0>(d, hi, lo_upper); } // ------------------------------ ConcatLowerUpper -namespace detail { - -// Only getting a full register is a no-op. 
-template -constexpr bool IsGetNoOp(D d) { - return d.Pow2() >= 0; -} - -} // namespace detail - -template ())>* = nullptr> -HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { - const auto lo_upper = detail::Get<1>(lo); - const auto hi_lower = detail::Get<0>(hi); - const auto undef = Undefined(d); - return detail::Set<1>(detail::Set<0>(undef, lo_upper), hi_lower); -} - -template ())>* = nullptr> +template HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { - const size_t half = Lanes(d) / 2; - const V lo_down = detail::SlideDown(lo, half); - return detail::SlideUp(lo_down, hi, half); + const auto lo_upper = detail::Get<1>(d, lo); + const auto hi_lower = detail::Get<0>(d, hi); + return detail::Set<1>(d, ResizeBitCast(d, lo_upper), hi_lower); } // ------------------------------ Combine template HWY_API VFromD Combine(D2 d2, const V hi, const V lo) { - return detail::SlideUp(detail::Ext(d2, lo), detail::Ext(d2, hi), - Lanes(d2) / 2); + return detail::Set<1>(d2, ResizeBitCast(d2, lo), hi); } // ------------------------------ ZeroExtendVector @@ -3376,8 +3389,9 @@ HWY_API VFromD>> LowerHalf(const V v) { } template -HWY_API VFromD UpperHalf(const DH d2, const VFromD> v) { - return LowerHalf(d2, detail::SlideDown(v, Lanes(d2))); +HWY_API VFromD UpperHalf(const DH /*d2*/, const VFromD> v) { + const Twice d; + return detail::Get<1>(d, v); } // ================================================== SWIZZLE