Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various masked operations #2428

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,9 @@ types, and on SVE/RVV.

* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`.

* <code>V **MaskedOr**(M m, V a, V b)</code>: returns `a[i] | b[i]`
or `zero` if `m[i]` is false.

The following three-argument functions may be more efficient than assembling
them from 2-argument functions:

Expand Down Expand Up @@ -2491,6 +2494,24 @@ more efficient on some targets.
* <code>T **ReduceMin**(D, V v)</code>: returns the minimum of all lanes.
* <code>T **ReduceMax**(D, V v)</code>: returns the maximum of all lanes.

### Masked reductions

**Note**: Horizontal operations (across lanes of the same vector) such as
reductions are slower than normal SIMD operations and are typically used outside
critical loops.

All ops in this section ignore lanes where `mask=false`. These are equivalent
to, and potentially more efficient than, `GetLane(SumOfLanes(d,
IfThenElseZero(m, v)))` etc. The result is implementation-defined when all mask
elements are false.

* <code>T **MaskedReduceSum**(D, M m, V v)</code>: returns the sum of all lanes
jan-wassenberg marked this conversation as resolved.
Show resolved Hide resolved
where `m[i]` is `true`.
* <code>T **MaskedReduceMin**(D, M m, V v)</code>: returns the minimum of all
lanes where `m[i]` is `true`.
* <code>T **MaskedReduceMax**(D, M m, V v)</code>: returns the maximum of all
lanes where `m[i]` is `true`.

### Crypto

Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
Expand Down
47 changes: 47 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_x(m, a, b); \
}
// User-specified mask. Mask=false value is zero.
#define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_z(m, a, b); \
}

#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
Expand Down Expand Up @@ -763,6 +769,9 @@ HWY_API V Or(const V a, const V b) {
return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ MaskedOr
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedOr, orr)

// ------------------------------ Xor

namespace detail {
Expand Down Expand Up @@ -1678,6 +1687,7 @@ namespace detail {
return sv##OP##_##CHAR##BITS(pg, v); \
}

// TODO: Remove SumOfLanesM in favor of using MaskedReduceSum
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv)
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv)

Expand Down Expand Up @@ -1725,6 +1735,25 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
return detail::MaxOfLanesM(detail::MakeMask(d), v);
}

#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
#else
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

template <class D, class M>
HWY_API TFromD<D> MaskedReduceSum(D /*d*/, M m, VFromD<D> v) {
return detail::SumOfLanesM(m, v);
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D /*d*/, M m, VFromD<D> v) {
return detail::MinOfLanesM(m, v);
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D /*d*/, M m, VFromD<D> v) {
return detail::MaxOfLanesM(m, v);
}

// ------------------------------ SumOfLanes

template <class D, HWY_IF_LANES_GT_D(D, 1)>
Expand Down Expand Up @@ -5056,6 +5085,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) {
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float");
return IfThenElse(IsNegative(v), yes, no);
}
// ------------------------------ IfNegativeThenNegOrUndefIfZero
jan-wassenberg marked this conversation as resolved.
Show resolved Hide resolved

#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#else
#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#endif

#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \
}

HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg)

#undef HWY_SVE_NEG_IF

// ------------------------------ AverageRound (ShiftRight)

Expand Down Expand Up @@ -6587,6 +6633,7 @@ HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount,
#undef HWY_SVE_IF_NOT_EMULATED_D
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGMVV
#undef HWY_SVE_RETV_ARGMVV_Z
#undef HWY_SVE_RETV_ARGMV_Z
#undef HWY_SVE_RETV_ARGMV
#undef HWY_SVE_RETV_ARGPV
Expand Down
26 changes: 26 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,28 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
}
#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8

#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
#else
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

template <class D, class M>
HWY_API TFromD<D> MaskedReduceSum(D d, M m, VFromD<D> v) {
return ReduceSum(d, IfThenElseZero(m, v));
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D d, M m, VFromD<D> v) {
return ReduceMin(d, IfThenElse(m, v, Set(d, hwy::PositiveInfOrHighestValue <TFromD<D>>())));
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D d, M m, VFromD<D> v) {
return ReduceMax(d, IfThenElse(m, v, Set(d, hwy::NegativeInfOrLowestValue<TFromD<D>>())));
}

#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR

// ------------------------------ IsEitherNaN
#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_IS_EITHER_NAN
Expand Down Expand Up @@ -7568,6 +7590,10 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

template <class V, class M>
HWY_API V MaskedOr(M m, V a, V b) {
return IfThenElseZero(m, Or(a, b));
}
// ------------------------------ AllBits1/AllBits0
#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLONES
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/rvv-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4755,6 +4755,8 @@ HWY_API T ReduceMax(D d, const VFromD<D> v) {

#undef HWY_RVV_REDUCE

// TODO: add MaskedReduceSum/Min/Max

// ------------------------------ SumOfLanes

template <class D, HWY_IF_LANES_GT_D(D, 1)>
Expand Down
23 changes: 23 additions & 0 deletions hwy/tests/logical_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,28 @@ HWY_NOINLINE void TestAllTestBit() {
ForIntegerTypes(ForPartialVectors<TestTestBit>());
}

struct TestMaskedOr {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const MFromD<D> all_true = MaskTrue(d);
const auto v1 = Iota(d, 1);
const auto v2 = Iota(d, 2);

HWY_ASSERT_VEC_EQ(d, Or(v2, v1), MaskedOr(all_true, v1, v2));

const MFromD<D> first_five = FirstN(d, 5);
const Vec<D> v0 = Zero(d);

const Vec<D> v1_exp = IfThenElse(first_five, Or(v2, v1), v0);

HWY_ASSERT_VEC_EQ(d, v1_exp, MaskedOr(first_five, v1, v2));
}
};

HWY_NOINLINE void TestAllMaskedLogical() {
ForAllTypes(ForPartialVectors<TestMaskedOr>());
}

struct TestAllBits {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
Expand Down Expand Up @@ -185,6 +207,7 @@ HWY_BEFORE_TEST(HwyLogicalTest);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllMaskedLogical);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllBits);

HWY_AFTER_TEST();
Expand Down
126 changes: 126 additions & 0 deletions hwy/tests/reduction_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,128 @@ HWY_NOINLINE void TestAllSumsOf8() {
ForGEVectors<64, TestSumsOf8>()(uint8_t());
}

struct TestMaskedReduceSum {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;

using TI = MakeSigned<T>;
const Rebind<TI, D> di;
const Vec<D> v2 = Iota(d, 2);

const size_t N = Lanes(d);
auto bool_lanes = AllocateAligned<TI>(N);
HWY_ASSERT(bool_lanes);

for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
T expected = 0;
for (size_t i = 0; i < N; ++i) {
bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
if (bool_lanes[i]) {
expected += ConvertScalarTo<T>(i + 2);
}
}

const auto mask_i = Load(di, bool_lanes.get());
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(di)));

// If all elements are disabled the result is implementation defined
if (AllFalse(d, mask)) {
continue;
}

HWY_ASSERT_EQ(expected, MaskedReduceSum(d, mask, v2));
}
}
};

HWY_NOINLINE void TestAllMaskedReduceSum() {
ForAllTypes(ForPartialVectors<TestMaskedReduceSum>());
}

struct TestMaskedReduceMin {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;

using TI = MakeSigned<T>;
const Rebind<TI, D> di;
const Vec<D> v2 = Iota(d, 2);

const size_t N = Lanes(d);
auto bool_lanes = AllocateAligned<TI>(N);
HWY_ASSERT(bool_lanes);

for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
T expected =
ConvertScalarTo<T>(N + 3); // larger than any values in the vector
for (size_t i = 0; i < N; ++i) {
bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
if (bool_lanes[i]) {
if (expected > ConvertScalarTo<T>(i + 2)) {
expected = ConvertScalarTo<T>(i + 2);
}
}
}

const auto mask_i = Load(di, bool_lanes.get());
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(di)));

// If all elements are disabled the result is implementation defined
if (AllFalse(d, mask)) {
continue;
}

HWY_ASSERT_EQ(expected, MaskedReduceMin(d, mask, v2));
}
}
};

HWY_NOINLINE void TestAllMaskedReduceMin() {
ForAllTypes(ForPartialVectors<TestMaskedReduceMin>());
}

struct TestMaskedReduceMax {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;

using TI = MakeSigned<T>;
const Rebind<TI, D> di;
const Vec<D> v2 = Iota(d, 2);

const size_t N = Lanes(d);
auto bool_lanes = AllocateAligned<TI>(N);
HWY_ASSERT(bool_lanes);

for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
T expected = 0;
for (size_t i = 0; i < N; ++i) {
bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
if (bool_lanes[i]) {
if (expected < ConvertScalarTo<T>(i + 2)) {
expected = ConvertScalarTo<T>(i + 2);
}
}
}

const auto mask_i = Load(di, bool_lanes.get());
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(di)));

// If all elements are disabled the result is implementation defined
if (AllFalse(d, mask)) {
continue;
}

HWY_ASSERT_EQ(expected, MaskedReduceMax(d, mask, v2));
}
}
};

HWY_NOINLINE void TestAllMaskedReduceMax() {
ForAllTypes(ForPartialVectors<TestMaskedReduceMax>());
}

} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
Expand All @@ -367,6 +489,10 @@ HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMinMaxOfLanes);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf2);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf4);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf8);

HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceSum);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceMin);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceMax);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down
Loading