Skip to content

Commit

Permalink
Merge pull request #2362 from lsrcz:concat
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 710909340
  • Loading branch information
copybara-github committed Dec 31, 2024
2 parents f754bd6 + 914cb69 commit 306e46d
Showing 1 changed file with 181 additions and 15 deletions.
196 changes: 181 additions & 15 deletions hwy/ops/rvv-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,26 @@ namespace detail { // for code folding
X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \
X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP)

#define HWY_RVV_FOREACH_08_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \
X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \
X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP)

#define HWY_RVV_FOREACH_16_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \
X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \
X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP)

#define HWY_RVV_FOREACH_32_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \
X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \
X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP)

#define HWY_RVV_FOREACH_64_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \
X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \
X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP)

// LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH.
#define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \
Expand Down Expand Up @@ -275,6 +295,35 @@ namespace detail { // for code folding
HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \
HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP)

// GET/SET + VIRT
#define HWY_RVV_FOREACH_08_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \
X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \
X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP)

#define HWY_RVV_FOREACH_16_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \
X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP)

#define HWY_RVV_FOREACH_32_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP)

#define HWY_RVV_FOREACH_64_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP)

// For the smallest LMUL for each SEW, similar to the LowerHalf operator, we
// provide the Get and Set operator that returns the same vector type.
#define HWY_RVV_FOREACH_08_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP)

#define HWY_RVV_FOREACH_16_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP)

#define HWY_RVV_FOREACH_32_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP)

#define HWY_RVV_FOREACH_64_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \
X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP)

// EXT + VIRT
#define HWY_RVV_FOREACH_08_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \
HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \
Expand Down Expand Up @@ -3123,6 +3172,125 @@ HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL)
#undef HWY_RVV_SLIDE_UP
#undef HWY_RVV_SLIDE_DOWN

#define HWY_RVV_GET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \
return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH( \
v, kIndex); /* no AVL */ \
}
#define HWY_RVV_GET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
SHIFT, MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \
HWY_IF_CONSTEXPR(kIndex == 0) { return Trunc(v); } \
else { \
return Trunc(SlideDown( \
v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \
SHIFT - 1){}))); \
} \
}
#define HWY_RVV_GET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
SHIFT, MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \
HWY_IF_CONSTEXPR(kIndex == 0) { return v; } \
else { \
return SlideDown( \
v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \
SHIFT){}) / \
2); \
} \
}
HWY_RVV_FOREACH(HWY_RVV_GET, Get, get, _GET_SET)
HWY_RVV_FOREACH(HWY_RVV_GET_VIRT, Get, get, _GET_SET_VIRT)
HWY_RVV_FOREACH(HWY_RVV_GET_SMALLEST, Get, get, _GET_SET_SMALLEST)
#undef HWY_RVV_GET
#undef HWY_RVV_GET_VIRT
#undef HWY_RVV_GET_SMALLEST

template <size_t kIndex, class D>
static HWY_INLINE HWY_MAYBE_UNUSED VFromD<AdjustSimdTagToMinVecPow2<Half<D>>>
Get(D d, VFromD<D> v) {
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");

const AdjustSimdTagToMinVecPow2<Half<decltype(d)>> dh;
HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) {
(void)dh;
return Get<kIndex>(v);
}
else {
const size_t slide_down_amt =
(dh.Pow2() < DFromV<decltype(v)>().Pow2()) ? Lanes(dh) : (Lanes(d) / 2);
return ResizeBitCast(dh, SlideDown(v, slide_down_amt));
}
}

#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_API HWY_RVV_V(BASE, SEW, LMUL) \
NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) { \
return __riscv_v##OP##_v_##CHAR##SEW##LMULH##_##CHAR##SEW##LMUL( \
dest, kIndex, v); /* no AVL */ \
}
#define HWY_RVV_SET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
SHIFT, MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_API HWY_RVV_V(BASE, SEW, LMUL) \
NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) { \
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \
auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \
auto df2 = \
HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT - 1){}; \
HWY_IF_CONSTEXPR(kIndex == 0) { \
return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, Ext(d, v), \
Lanes(df2)); \
} \
else { \
return SlideUp(dest, Ext(d, v), Lanes(df2)); \
} \
}
#define HWY_RVV_SET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
SHIFT, MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_API HWY_RVV_V(BASE, SEW, LMUL) \
NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v) { \
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \
auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \
HWY_IF_CONSTEXPR(kIndex == 0) { \
return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, v, Lanes(d) / 2); \
} \
else { \
return SlideUp(dest, v, Lanes(d) / 2); \
} \
}
HWY_RVV_FOREACH(HWY_RVV_SET, Set, set, _GET_SET)
HWY_RVV_FOREACH(HWY_RVV_SET_VIRT, Set, set, _GET_SET_VIRT)
HWY_RVV_FOREACH(HWY_RVV_SET_SMALLEST, Set, set, _GET_SET_SMALLEST)
#undef HWY_RVV_SET
#undef HWY_RVV_SET_VIRT
#undef HWY_RVV_SET_SMALLEST

template <size_t kIndex, class D>
static HWY_INLINE HWY_MAYBE_UNUSED VFromD<D> Set(
D d, VFromD<D> dest, VFromD<AdjustSimdTagToMinVecPow2<Half<D>>> v) {
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");

const AdjustSimdTagToMinVecPow2<Half<decltype(d)>> dh;
HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) {
(void)dh;
return Set<kIndex>(dest, v);
}
else {
const size_t slide_up_amt =
(dh.Pow2() < DFromV<decltype(v)>().Pow2()) ? Lanes(dh) : (Lanes(d) / 2);
return SlideUp(dest, ResizeBitCast(d, v), slide_up_amt);
}
}

} // namespace detail

// ------------------------------ SlideUpLanes
Expand All @@ -3145,39 +3313,36 @@ HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
// ------------------------------ ConcatUpperLower
template <class D, class V>
HWY_API V ConcatUpperLower(D d, const V hi, const V lo) {
const size_t half = Lanes(d) / 2;
const V hi_down = detail::SlideDown(hi, half);
return detail::SlideUp(lo, hi_down, half);
const auto lo_lower = detail::Get<0>(d, lo);
return detail::Set<0>(d, hi, lo_lower);
}

// ------------------------------ ConcatLowerLower
template <class D, class V>
HWY_API V ConcatLowerLower(D d, const V hi, const V lo) {
return detail::SlideUp(lo, hi, Lanes(d) / 2);
const auto hi_lower = detail::Get<0>(d, hi);
return detail::Set<1>(d, lo, hi_lower);
}

// ------------------------------ ConcatUpperUpper
template <class D, class V>
HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) {
const size_t half = Lanes(d) / 2;
const V hi_down = detail::SlideDown(hi, half);
const V lo_down = detail::SlideDown(lo, half);
return detail::SlideUp(lo_down, hi_down, half);
const auto lo_upper = detail::Get<1>(d, lo);
return detail::Set<0>(d, hi, lo_upper);
}

// ------------------------------ ConcatLowerUpper
template <class D, class V>
HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) {
const size_t half = Lanes(d) / 2;
const V lo_down = detail::SlideDown(lo, half);
return detail::SlideUp(lo_down, hi, half);
const auto lo_upper = detail::Get<1>(d, lo);
const auto hi_lower = detail::Get<0>(d, hi);
return detail::Set<1>(d, ResizeBitCast(d, lo_upper), hi_lower);
}

// ------------------------------ Combine
template <class D2, class V>
HWY_API VFromD<D2> Combine(D2 d2, const V hi, const V lo) {
return detail::SlideUp(detail::Ext(d2, lo), detail::Ext(d2, hi),
Lanes(d2) / 2);
return detail::Set<1>(d2, ResizeBitCast(d2, lo), hi);
}

// ------------------------------ ZeroExtendVector
Expand Down Expand Up @@ -3224,8 +3389,9 @@ HWY_API VFromD<Half<DFromV<V>>> LowerHalf(const V v) {
}

template <class DH>
HWY_API VFromD<DH> UpperHalf(const DH d2, const VFromD<Twice<DH>> v) {
return LowerHalf(d2, detail::SlideDown(v, Lanes(d2)));
HWY_API VFromD<DH> UpperHalf(const DH /*d2*/, const VFromD<Twice<DH>> v) {
const Twice<DH> d;
return detail::Get<1>(d, v);
}

// ================================================== SWIZZLE
Expand Down

0 comments on commit 306e46d

Please sign in to comment.