
Commit

Merge pull request #2430 from cambridgeconsultants:cc_up_set_load_store_count_operations

PiperOrigin-RevId: 721470975
copybara-github committed Jan 30, 2025
2 parents 8e1476f + 6fe29f9 commit e96d4d3
Showing 7 changed files with 384 additions and 0 deletions.
20 changes: 20 additions & 0 deletions g3doc/quick_reference.md
@@ -429,6 +429,10 @@ for comparisons, for example `Lt` instead of `operator<`.
the result, with `t0` in the least-significant (lowest-indexed) lane of each
128-bit block and `tK` in the most-significant (highest-indexed) lane of
each 128-bit block: `{t0, t1, ..., tK}`
* <code>V **MaskedSetOr**(V no, M m, T a)</code>: returns an N-lane vector with
  lane `i` equal to `a` if `m[i]` is true, else `no[i]`.
* <code>V **MaskedSet**(D d, M m, T a)</code>: returns an N-lane vector with
  lane `i` equal to `a` if `m[i]` is true, else 0.

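For illustration only (not part of this commit), a minimal sketch of how the masked set ops might be used, assuming the usual `hwy/highway.h` include and a static-dispatch namespace alias; the function name `ReplaceNegatives` is hypothetical:

```c++
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Lanes of `v` that are negative receive the sentinel -1; all other lanes
// keep their original value.
hn::Vec<hn::ScalableTag<int32_t>> ReplaceNegatives(
    hn::Vec<hn::ScalableTag<int32_t>> v) {
  const hn::ScalableTag<int32_t> d;
  const auto is_neg = hn::Lt(v, hn::Zero(d));
  // MaskedSetOr(no, m, a): lane i is `a` where m[i] is true, else no[i].
  return hn::MaskedSetOr(v, is_neg, int32_t{-1});
}
```
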
### Getting/setting lanes

@@ -1065,6 +1069,10 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
leading zeros in each lane. For any lanes where ```a[i]``` is zero,
```sizeof(TFromV<V>) * 8``` is returned in the corresponding result lanes.

* `V`: `{u,i}` \
  <code>V **MaskedLeadingZeroCount**(M m, V a)</code>: returns the result of
  `LeadingZeroCount` in lanes where `m[i]` is true, and zero in the remaining
  lanes (see the sketch after this list).

* `V`: `{u,i}` \
<code>V **TrailingZeroCount**(V a)</code>: returns the number of
trailing zeros in each lane. For any lanes where ```a[i]``` is zero,
@@ -1079,6 +1087,10 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
```HighestValue<MakeSigned<TFromV<V>>>()``` is returned in the
corresponding result lanes.

* <code>bool **AllBits1**(V v)</code>: returns whether all bits of `v` are set.

* <code>bool **AllBits0**(V v)</code>: returns whether all bits of `v` are clear.

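As a hedged sketch (again not part of this commit), the masked count and the whole-vector bit tests above might be combined as follows; `Demo` is a hypothetical name:

```c++
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Counts leading zeros only where the mask (up to the first three lanes) is
// set; other result lanes are zero. Then checks two whole-vector bit patterns.
bool Demo() {
  const hn::ScalableTag<uint32_t> d;
  const auto v = hn::Set(d, 2u);  // each lane has 30 leading zeros
  const auto m = hn::FirstN(d, 3);
  const auto counts = hn::MaskedLeadingZeroCount(m, v);  // {30, 30, 30, 0, ..}
  (void)counts;
  // Zero(d) has no bits set; Not(Zero(d)) has every bit set.
  return hn::AllBits0(hn::Zero(d)) && hn::AllBits1(hn::Not(hn::Zero(d)));
}
```
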
The following operate on individual bits within each lane. Note that the
non-operator functions (`And` instead of `&`) must be used for floating-point
types, and on SVE/RVV.
@@ -1593,6 +1605,10 @@ aligned memory at indices which are not a multiple of the vector length):
lanes from `p` to the first (lowest-index) lanes of the result vector and
fills the remaining lanes with `no`. Like LoadN, this does not fault.

* <code>Vec&lt;D&gt; **InsertIntoUpper**(D d, T* p, V v)</code>: loads `Lanes(d)/2`
  lanes from `p` into the upper half of the result vector and copies the lower
  half of `v` into the lower lanes.

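A minimal usage sketch (not part of this commit), assuming a target with more than one lane per vector; `RefillUpper` is a hypothetical name:

```c++
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Keeps the lower half of `v` and fills the upper half with Lanes(d)/2
// elements loaded (unaligned) from `p`.
hn::Vec<hn::ScalableTag<float>> RefillUpper(hn::Vec<hn::ScalableTag<float>> v,
                                            float* p) {
  const hn::ScalableTag<float> d;
  return hn::InsertIntoUpper(d, p, v);
}
```
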
#### Store

* <code>void **Store**(Vec&lt;D&gt; v, D, T* aligned)</code>: copies `v[i]`
@@ -1632,6 +1648,10 @@ aligned memory at indices which are not a multiple of the vector length):
StoreN does not modify any memory past
`p + HWY_MIN(Lanes(d), max_lanes_to_store) - 1`.

* <code>void **TruncateStore**(Vec&lt;D&gt; v, D d, T* HWY_RESTRICT p)</code>:
  truncates each lane of `v` to type `T` and stores the results to `p`. This is
  similar to performing `TruncateTo` followed by `StoreU`.

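A minimal sketch of the narrowing store (not part of this commit); `StoreLowBytes` is a hypothetical name, and `out` must have room for `Lanes(d)` bytes:

```c++
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Stores only the low byte of each uint32_t lane of `v` to `out`, as if by
// TruncateTo to uint8_t followed by StoreU.
void StoreLowBytes(hn::Vec<hn::ScalableTag<uint32_t>> v, uint8_t* out) {
  const hn::ScalableTag<uint32_t> d;
  hn::TruncateStore(v, d, out);
}
```
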
#### Interleaved

* <code>void **LoadInterleaved2**(D, const T* p, Vec&lt;D&gt;&amp; v0,
68 changes: 68 additions & 0 deletions hwy/ops/arm_sve-inl.h
@@ -427,6 +427,27 @@ using VFromD = decltype(Set(D(), TFromD<D>()));

using VBF16 = VFromD<ScalableTag<bfloat16_t>>;

// ------------------------------ MaskedSetOr/MaskedSet

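// Merging form: svdup_n_*_m writes `op` to active lanes and keeps the
// corresponding lanes of `no` where the predicate is false.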
#define HWY_SVE_MASKED_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_m(no, m, op); \
}

HWY_SVE_FOREACH(HWY_SVE_MASKED_SET_OR, MaskedSetOr, dup_n)
#undef HWY_SVE_MASKED_SET_OR

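// Zeroing form: svdup_n_*_z writes `op` to active lanes and zeros lanes where
// the predicate is false.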
#define HWY_SVE_MASKED_SET(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_z(m, op); \
}

HWY_SVE_FOREACH(HWY_SVE_MASKED_SET, MaskedSet, dup_n)
#undef HWY_SVE_MASKED_SET

// ------------------------------ Zero

template <class D>
@@ -2257,6 +2278,37 @@ HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {

#endif // HWY_TARGET != HWY_SVE2_128

// Truncate to smaller size and store
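// SVE st1b/st1h/st1w store only the low 8/16/32 bits of each active element,
// so the narrowing happens within the store itself.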
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \
template <size_t N, int kPow2> \
HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \
const HWY_SVE_D(BASE, BITS, N, kPow2) d, \
HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \
sv##OP##_##CHAR##BITS(detail::PTrue(d), detail::NativeLanePointer(p), v); \
}

#define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8)
#define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16)
#define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32)

HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, TruncateStore, st1w)

#undef HWY_SVE_STORE_TRUNCATED

// ------------------------------ Load/Store

// SVE only requires lane alignment, not natural alignment of the entire
@@ -6442,6 +6494,22 @@ HWY_API V HighestSetBitIndex(V v) {
return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v)));
}

#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

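// svclz_*_z counts leading zeros in active lanes and zeros inactive lanes;
// the result is then cast back to the input lane type.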
#define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
const DFromV<decltype(v)> d; \
return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v)); \
}

HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount,
clz)
#undef HWY_SVE_MASKED_LEADING_ZERO_COUNT

// ================================================== END MACROS
#undef HWY_SVE_ALL_PTRUE
#undef HWY_SVE_D
92 changes: 92 additions & 0 deletions hwy/ops/generic_ops-inl.h
@@ -97,6 +97,21 @@ HWY_API Vec<D> Inf(D d) {
return BitCast(d, Set(du, max_x2 >> 1));
}

// ------------------------------ MaskedSetOr/MaskedSet

template <class V, typename T = TFromV<V>, typename D = DFromV<V>,
typename M = MFromD<D>>
HWY_API V MaskedSetOr(V no, M m, T a) {
D d;
return IfThenElse(m, Set(d, a), no);
}

template <class D, typename V = VFromD<D>, typename M = MFromD<D>,
typename T = TFromD<D>>
HWY_API V MaskedSet(D d, M m, T a) {
return IfThenElseZero(m, Set(d, a));
}

// ------------------------------ ZeroExtendResizeBitCast

// The implementation of detail::ZeroExtendResizeBitCast for the HWY_EMU128
@@ -336,6 +351,21 @@ HWY_API Mask<DTo> DemoteMaskTo(DTo d_to, DFrom d_from, Mask<DFrom> m) {

#endif // HWY_NATIVE_DEMOTE_MASK_TO

// ------------------------------ InsertIntoUpper
#if (defined(HWY_NATIVE_LOAD_HIGHER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_LOAD_HIGHER
#undef HWY_NATIVE_LOAD_HIGHER
#else
#define HWY_NATIVE_LOAD_HIGHER
#endif
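// Generic fallback: load the new upper half from memory, then recombine it
// with the lower half of the existing vector.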
template <class D, typename T, class V = VFromD<D>, HWY_IF_LANES_GT_D(D, 1)>
HWY_API V InsertIntoUpper(D d, T* p, V a) {
Half<D> dh;
const VFromD<decltype(dh)> b = LoadU(dh, p);
return Combine(d, b, LowerHalf(a));
}
#endif // HWY_NATIVE_LOAD_HIGHER

// ------------------------------ CombineMasks

#if (defined(HWY_NATIVE_COMBINE_MASKS) == defined(HWY_TARGET_TOGGLE))
@@ -2659,6 +2689,24 @@ HWY_API void StoreN(VFromD<D> v, D d, T* HWY_RESTRICT p,

#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ TruncateStore
#if (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

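// Generic fallback: truncate to the narrower lane type, then store unaligned.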
template <class D, class T, HWY_IF_T_SIZE_GT_D(D, sizeof(T)),
HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API void TruncateStore(VFromD<D> v, const D /*d*/, T* HWY_RESTRICT p) {
using DTo = Rebind<T, D>;
DTo dsmall;
StoreU(TruncateTo(dsmall, v), dsmall, p);
}

#endif // (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ Scatter

#if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE))
@@ -3886,6 +3934,21 @@ HWY_API V TrailingZeroCount(V v) {
}
#endif // HWY_NATIVE_LEADING_ZERO_COUNT

// ------------------------------ MaskedLeadingZeroCount
#if (defined(HWY_NATIVE_MASKED_LEADING_ZERO_COUNT) == \
defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), class M>
HWY_API V MaskedLeadingZeroCount(M m, V v) {
return IfThenElseZero(m, LeadingZeroCount(v));
}
#endif // HWY_NATIVE_MASKED_LEADING_ZERO_COUNT

// ------------------------------ AESRound

// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes.
@@ -7442,6 +7505,35 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

// ------------------------------ AllBits1/AllBits0
#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLONES
#undef HWY_NATIVE_ALLONES
#else
#define HWY_NATIVE_ALLONES
#endif

template <class V>
HWY_API bool AllBits1(V a) {
const RebindToUnsigned<DFromV<V>> du;
using TU = TFromD<decltype(du)>;
return AllTrue(du, Eq(BitCast(du, a), Set(du, hwy::HighestValue<TU>())));
}
#endif // HWY_NATIVE_ALLONES

#if (defined(HWY_NATIVE_ALLZEROS) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLZEROS
#undef HWY_NATIVE_ALLZEROS
#else
#define HWY_NATIVE_ALLZEROS
#endif

template <class V>
HWY_API bool AllBits0(V a) {
DFromV<V> d;
return AllTrue(d, Eq(a, Zero(d)));
}
#endif // HWY_NATIVE_ALLZEROS

// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined
43 changes: 43 additions & 0 deletions hwy/tests/count_test.cc
@@ -132,6 +132,48 @@ HWY_NOINLINE void TestAllLeadingZeroCount() {
ForIntegerTypes(ForPartialVectors<TestLeadingZeroCount>());
}

struct TestMaskedLeadingZeroCount {
template <class T, class D>
HWY_ATTR_NO_MSAN HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;
using TU = MakeUnsigned<T>;
const RebindToUnsigned<decltype(d)> du;
size_t N = Lanes(d);
const MFromD<D> first_3 = FirstN(d, 3);
auto data = AllocateAligned<T>(N);
auto lzcnt = AllocateAligned<T>(N);
HWY_ASSERT(data && lzcnt);

constexpr T kNumOfBitsInT = static_cast<T>(sizeof(T) * 8);
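// Expected: input 2 has (kNumOfBitsInT - 2) leading zeros in the first three
// (masked-on) lanes; masked-off lanes must be zero.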
for (size_t j = 0; j < N; j++) {
if (j < 3) {
lzcnt[j] = static_cast<T>(kNumOfBitsInT - 2);
} else {
lzcnt[j] = static_cast<T>(0);
}
}
HWY_ASSERT_VEC_EQ(
d, lzcnt.get(),
MaskedLeadingZeroCount(first_3, Set(d, static_cast<T>(2))));

for (size_t j = 0; j < N; j++) {
if (j < 3) {
lzcnt[j] = static_cast<T>(1);
} else {
lzcnt[j] = static_cast<T>(0);
}
}
HWY_ASSERT_VEC_EQ(
d, lzcnt.get(),
MaskedLeadingZeroCount(
first_3, BitCast(d, Set(du, TU{1} << (kNumOfBitsInT - 2)))));
}
};

HWY_NOINLINE void TestAllMaskedLeadingZeroCount() {
ForIntegerTypes(ForPartialVectors<TestMaskedLeadingZeroCount>());
}

template <class T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))>
static HWY_INLINE T TrailingZeroCountOfValue(T val) {
@@ -303,6 +345,7 @@ namespace {
HWY_BEFORE_TEST(HwyCountTest);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllPopulationCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllLeadingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllMaskedLeadingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllTrailingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllHighestSetBitIndex);
HWY_AFTER_TEST();
28 changes: 28 additions & 0 deletions hwy/tests/logical_test.cc
@@ -146,6 +146,32 @@ HWY_NOINLINE void TestAllTestBit() {
ForIntegerTypes(ForPartialVectors<TestTestBit>());
}

struct TestAllBits {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
auto v0s = Zero(d);
HWY_ASSERT(AllBits0(v0s));
auto v1s = Not(v0s);
HWY_ASSERT(AllBits1(v1s));
const size_t kNumBits = sizeof(T) * 8;
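// Lanes with exactly one or two bits set are neither all-ones nor all-zeros.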
for (size_t i = 0; i < kNumBits; ++i) {
const Vec<D> bit1 = Set(d, static_cast<T>(1ull << i));
const Vec<D> bit2 = Set(d, static_cast<T>(1ull << ((i + 1) % kNumBits)));
const Vec<D> bits12 = Or(bit1, bit2);
HWY_ASSERT(!AllBits1(bit1));
HWY_ASSERT(!AllBits0(bit1));
HWY_ASSERT(!AllBits1(bit2));
HWY_ASSERT(!AllBits0(bit2));
HWY_ASSERT(!AllBits1(bits12));
HWY_ASSERT(!AllBits0(bits12));
}
}
};

HWY_NOINLINE void TestAllAllBits() {
ForIntegerTypes(ForPartialVectors<TestAllBits>());
}

} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
@@ -159,6 +185,8 @@ HWY_BEFORE_TEST(HwyLogicalTest);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllBits);

HWY_AFTER_TEST();
} // namespace
} // namespace hwy