diff --git a/hwy/ops/rvv-inl.h b/hwy/ops/rvv-inl.h
index e65602c664..bcd850d60f 100644
--- a/hwy/ops/rvv-inl.h
+++ b/hwy/ops/rvv-inl.h
@@ -127,6 +127,26 @@ namespace detail {  // for code folding
   X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \
   X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP)
 
+#define HWY_RVV_FOREACH_08_GET_SET(X_MACRO, BASE, CHAR, NAME, OP)     \
+  X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \
+  X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \
+  X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP)
+
+#define HWY_RVV_FOREACH_16_GET_SET(X_MACRO, BASE, CHAR, NAME, OP)     \
+  X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \
+  X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \
+  X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP)
+
+#define HWY_RVV_FOREACH_32_GET_SET(X_MACRO, BASE, CHAR, NAME, OP)       \
+  X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \
+  X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP)  \
+  X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP)
+
+#define HWY_RVV_FOREACH_64_GET_SET(X_MACRO, BASE, CHAR, NAME, OP)       \
+  X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \
+  X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \
+  X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP)
+
 // LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH.
 #define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP)           \
   X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \
@@ -275,6 +295,35 @@ namespace detail {  // for code folding
   HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \
   HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP)
 
+// GET/SET + VIRT
+#define HWY_RVV_FOREACH_08_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP)     \
+  X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \
+  X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP)  \
+  X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP)
+
+#define HWY_RVV_FOREACH_16_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP)    \
+  X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \
+  X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP)
+
+#define HWY_RVV_FOREACH_32_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \
+  X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP)
+
+#define HWY_RVV_FOREACH_64_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP)
+
+// For the smallest LMUL of each SEW, similar to the LowerHalf operator, we
+// provide Get and Set operators that return the same vector type.
+#define HWY_RVV_FOREACH_08_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \
+  X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP)
+
+#define HWY_RVV_FOREACH_16_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \
+  X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP)
+
+#define HWY_RVV_FOREACH_32_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \
+  X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP)
+
+#define HWY_RVV_FOREACH_64_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \
+  X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP)
+
 // EXT + VIRT
 #define HWY_RVV_FOREACH_08_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \
   HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP)            \
@@ -3123,6 +3172,125 @@ HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL)
 #undef HWY_RVV_SLIDE_UP
 #undef HWY_RVV_SLIDE_DOWN
 
+#define HWY_RVV_GET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
+                    MLEN, NAME, OP)                                         \
+  template <size_t kIndex>                                                  \
+  HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) {  \
+    return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH(        \
+        v, kIndex); /* no AVL */                                            \
+  }
+#define HWY_RVV_GET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH,  \
+                         SHIFT, MLEN, NAME, OP)                            \
+  template <size_t kIndex>                                                 \
+  HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \
+    static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");    \
+    HWY_IF_CONSTEXPR(kIndex == 0) { return Trunc(v); }                     \
+    else {                                                                 \
+      return Trunc(SlideDown(                                              \
+          v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)),   \
+                             SHIFT - 1){})));                              \
+    }                                                                      \
+  }
+#define HWY_RVV_GET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
+                             SHIFT, MLEN, NAME, OP)                           \
+  template <size_t kIndex>                                                    \
+  HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) {     \
+    static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");       \
+    HWY_IF_CONSTEXPR(kIndex == 0) { return v; }                               \
+    else {                                                                    \
+      return SlideDown(                                                       \
+          v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)),      \
+                             SHIFT){}) /                                      \
+                 2);                                                          \
+    }                                                                         \
+  }
+HWY_RVV_FOREACH(HWY_RVV_GET, Get, get, _GET_SET)
+HWY_RVV_FOREACH(HWY_RVV_GET_VIRT, Get, get, _GET_SET_VIRT)
+HWY_RVV_FOREACH(HWY_RVV_GET_SMALLEST, Get, get, _GET_SET_SMALLEST)
+#undef HWY_RVV_GET
+#undef HWY_RVV_GET_VIRT
+#undef HWY_RVV_GET_SMALLEST
+
+template <size_t kIndex, class D>
+static HWY_INLINE HWY_MAYBE_UNUSED VFromD<Half<D>> Get(D d, VFromD<D> v) {
+  static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");
+
+  const AdjustSimdTagToMinVecPow2<Half<D>> dh;
+  HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) {
+    (void)dh;
+    return Get<kIndex>(v);
+  }
+  else {
+    const size_t slide_down_amt =
+        (dh.Pow2() < DFromV<decltype(v)>().Pow2()) ?
+            Lanes(dh) : (Lanes(d) / 2);
+    return ResizeBitCast(dh, SlideDown(v, slide_down_amt));
+  }
+}
+
+#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
+                    MLEN, NAME, OP)                                         \
+  template <size_t kIndex>                                                  \
+  HWY_API HWY_RVV_V(BASE, SEW, LMUL)                                        \
+  NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) {    \
+    return __riscv_v##OP##_v_##CHAR##SEW##LMULH##_##CHAR##SEW##LMUL(        \
+        dest, kIndex, v); /* no AVL */                                      \
+  }
+#define HWY_RVV_SET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH,    \
+                         SHIFT, MLEN, NAME, OP)                              \
+  template <size_t kIndex>                                                   \
+  HWY_API HWY_RVV_V(BASE, SEW, LMUL)                                         \
+  NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) {     \
+    static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");      \
+    auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \
+    auto df2 =                                                               \
+        HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT - 1){};  \
+    HWY_IF_CONSTEXPR(kIndex == 0) {                                          \
+      return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, Ext(d, v),         \
+                                                    Lanes(df2));             \
+    }                                                                        \
+    else {                                                                   \
+      return SlideUp(dest, Ext(d, v), Lanes(df2));                           \
+    }                                                                        \
+  }
+#define HWY_RVV_SET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
+                             SHIFT, MLEN, NAME, OP)                           \
+  template <size_t kIndex>                                                    \
+  HWY_API HWY_RVV_V(BASE, SEW, LMUL)                                          \
+  NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v) {       \
+    static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");       \
+    auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){};  \
+    HWY_IF_CONSTEXPR(kIndex == 0) {                                           \
+      return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, v, Lanes(d) / 2);   \
+    }                                                                         \
+    else {                                                                    \
+      return SlideUp(dest, v, Lanes(d) / 2);                                  \
+    }                                                                         \
+  }
+HWY_RVV_FOREACH(HWY_RVV_SET, Set, set, _GET_SET)
+HWY_RVV_FOREACH(HWY_RVV_SET_VIRT, Set, set, _GET_SET_VIRT)
+HWY_RVV_FOREACH(HWY_RVV_SET_SMALLEST, Set, set, _GET_SET_SMALLEST)
+#undef HWY_RVV_SET
+#undef HWY_RVV_SET_VIRT
+#undef HWY_RVV_SET_SMALLEST
+
+template <size_t kIndex, class D>
+static HWY_INLINE HWY_MAYBE_UNUSED VFromD<D> Set(
+    D d, VFromD<D> dest, VFromD<Half<D>> v) {
+  static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");
+
+  const AdjustSimdTagToMinVecPow2<Half<D>> dh;
+  HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) {
+    (void)dh;
+    return Set<kIndex>(dest, v);
+  }
+  else {
+    const size_t slide_up_amt =
+        (dh.Pow2() < DFromV<decltype(dest)>().Pow2()) ?
+            Lanes(dh) : (Lanes(d) / 2);
+    return SlideUp(dest, ResizeBitCast(d, v), slide_up_amt);
+  }
+}
+
 }  // namespace detail
 
 // ------------------------------ SlideUpLanes
@@ -3145,39 +3313,36 @@ HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
 // ------------------------------ ConcatUpperLower
 template <class D, class V>
 HWY_API V ConcatUpperLower(D d, const V hi, const V lo) {
-  const size_t half = Lanes(d) / 2;
-  const V hi_down = detail::SlideDown(hi, half);
-  return detail::SlideUp(lo, hi_down, half);
+  const auto lo_lower = detail::Get<0>(d, lo);
+  return detail::Set<0>(d, hi, lo_lower);
 }
 
 // ------------------------------ ConcatLowerLower
 template <class D, class V>
 HWY_API V ConcatLowerLower(D d, const V hi, const V lo) {
-  return detail::SlideUp(lo, hi, Lanes(d) / 2);
+  const auto hi_lower = detail::Get<0>(d, hi);
+  return detail::Set<1>(d, lo, hi_lower);
 }
 
 // ------------------------------ ConcatUpperUpper
 template <class D, class V>
 HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) {
-  const size_t half = Lanes(d) / 2;
-  const V hi_down = detail::SlideDown(hi, half);
-  const V lo_down = detail::SlideDown(lo, half);
-  return detail::SlideUp(lo_down, hi_down, half);
+  const auto lo_upper = detail::Get<1>(d, lo);
+  return detail::Set<0>(d, hi, lo_upper);
 }
 
 // ------------------------------ ConcatLowerUpper
 template <class D, class V>
 HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) {
-  const size_t half = Lanes(d) / 2;
-  const V lo_down = detail::SlideDown(lo, half);
-  return detail::SlideUp(lo_down, hi, half);
+  const auto lo_upper = detail::Get<1>(d, lo);
+  const auto hi_lower = detail::Get<0>(d, hi);
+  return detail::Set<1>(d, ResizeBitCast(d, lo_upper), hi_lower);
 }
 
 // ------------------------------ Combine
 template <class D2, class V>
 HWY_API VFromD<D2> Combine(D2 d2, const V hi, const V lo) {
-  return detail::SlideUp(detail::Ext(d2, lo), detail::Ext(d2, hi),
-                         Lanes(d2) / 2);
+  return detail::Set<1>(d2, ResizeBitCast(d2, lo), hi);
 }
 
 // ------------------------------ ZeroExtendVector
@@ -3224,8 +3389,9 @@ HWY_API VFromD<Half<DFromV<V>>> LowerHalf(const V v) {
 }
 
 template <class DH>
-HWY_API VFromD<DH> UpperHalf(const DH d2, const VFromD<Twice<DH>> v) {
-  return LowerHalf(d2, detail::SlideDown(v, Lanes(d2)));
+HWY_API VFromD<DH> UpperHalf(const DH /*d2*/, const VFromD<Twice<DH>> v) {
+  const Twice<DH> d;
+  return detail::Get<1>(d, v);
 }
 
 // ================================================== SWIZZLE
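
Taken together, the Concat*/Combine/UpperHalf ops above now decompose into at
most one Get plus one Set (a single vget/vset or slide) instead of a paired
slide-down and slide-up. Below is a minimal usage sketch of the affected
public API, assuming a static-dispatch Highway build targeting RVV; the
function and parameter names are illustrative, not part of the patch:

#include <stdint.h>

#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace example {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;

// ConcatLowerLower(d, hi, lo) returns lo's lower half in the lower half of
// the result and hi's lower half in the upper half. With the change above,
// RVV lowers this to detail::Get<0>(d, hi) (one vget, or Trunc for
// fractional LMUL) followed by detail::Set<1>(d, lo, hi_lower) (one vset or
// vslideup), rather than the previous slide-down plus slide-up.
void ConcatLowerLowerExample(const uint32_t* HWY_RESTRICT in_lo,
                             const uint32_t* HWY_RESTRICT in_hi,
                             uint32_t* HWY_RESTRICT out) {
  const hn::ScalableTag<uint32_t> d;
  const auto lo = hn::LoadU(d, in_lo);
  const auto hi = hn::LoadU(d, in_hi);
  hn::StoreU(hn::ConcatLowerLower(d, hi, lo), d, out);
}

}  // namespace HWY_NAMESPACE
}  // namespace example
HWY_AFTER_NAMESPACE();

For LMUL >= 2, Get/Set map directly to the __riscv_vget/__riscv_vset
register-group intrinsics, which take no AVL operand and thus need no
vsetvli (hence the /* no AVL */ comments). The _VIRT and _SMALLEST variants
cover fractional LMUL, where no such intrinsics exist, via Trunc/Ext plus a
single slide or tail-undisturbed vmv.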