Skip to content

Commit

Permalink
Fix double shufti reporting false positives
Browse files Browse the repository at this point in the history
Double shufti used to shift one of its match vectors by one position, which
lost one character at the end of every vector. The vacated position was filled
with a magic value indicating a match. As a result, if the first char of a
pattern fell on the last char of a vector, double shufti would assume the
second character was present and report a match.
This patch fixes it by keeping the previous vector and feeding its data into
the new one when it is shifted, so that no data is lost.

Signed-off-by: Yoan Picchi <[email protected]>
  • Loading branch information
ypicchi-arm committed Jan 15, 2025
1 parent 4f09e78 commit 3351539
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 35 deletions.
17 changes: 12 additions & 5 deletions src/nfa/arm/shufti.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const SuperVector<S> blockSingleMask(SuperVector<S> mask_lo, SuperVector<S> mask

template <uint16_t S>
static really_inline
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars) {
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> *inout_t1, SuperVector<S> chars) {

const SuperVector<S> low4bits = SuperVector<S>::dup_u8(0xf);
SuperVector<S> chars_lo = chars & low4bits;
Expand All @@ -57,18 +57,25 @@ SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi,
c1_lo.print8("c1_lo");
SuperVector<S> c1_hi = mask1_hi.template pshufb<true>(chars_hi);
c1_hi.print8("c1_hi");
SuperVector<S> t1 = c1_lo | c1_hi;
t1.print8("t1");
SuperVector<S> new_t1 = c1_lo | c1_hi;
// t1 is the match mask for the first char of the patterns
new_t1.print8("t1");

SuperVector<S> c2_lo = mask2_lo.template pshufb<true>(chars_lo);
c2_lo.print8("c2_lo");
SuperVector<S> c2_hi = mask2_hi.template pshufb<true>(chars_hi);
c2_hi.print8("c2_hi");
SuperVector<S> t2 = c2_lo | c2_hi;
// t2 is the match mask for the second char of the patterns
t2.print8("t2");
t2.template vshr_128_imm<1>().print8("t2.vshr_128(1)");
SuperVector<S> t = t1 | (t2.template vshr_128_imm<1>());

// Offset t1 so it aligns with t2. The hole created by the offset is filled
// with the last element of the previous t1 so no info is lost.
// Where bits with value 0 line up, it indicates a match.
SuperVector<S> t = (new_t1.alignr(*inout_t1, S-1)) | t2;
t.print8("t");

*inout_t1 = new_t1;

return !t.eq(SuperVector<S>::Ones());
}
17 changes: 12 additions & 5 deletions src/nfa/ppc64el/shufti.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const SuperVector<S> blockSingleMask(SuperVector<S> mask_lo, SuperVector<S> mask

template <uint16_t S>
static really_inline
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars) {
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> *inout_t1, SuperVector<S> chars) {

const SuperVector<S> low4bits = SuperVector<S>::dup_u8(0xf);
SuperVector<S> chars_lo = chars & low4bits;
Expand All @@ -59,18 +59,25 @@ SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi,
c1_lo.print8("c1_lo");
SuperVector<S> c1_hi = mask1_hi.template pshufb<true>(chars_hi);
c1_hi.print8("c1_hi");
SuperVector<S> t1 = c1_lo | c1_hi;
t1.print8("t1");
SuperVector<S> new_t1 = c1_lo | c1_hi;
// t1 is the match mask for the first char of the patterns
new_t1.print8("t1");

SuperVector<S> c2_lo = mask2_lo.template pshufb<true>(chars_lo);
c2_lo.print8("c2_lo");
SuperVector<S> c2_hi = mask2_hi.template pshufb<true>(chars_hi);
c2_hi.print8("c2_hi");
SuperVector<S> t2 = c2_lo | c2_hi;
// t2 is the match mask for the second char of the patterns
t2.print8("t2");
t2.template vshr_128_imm<1>().print8("t2.vshr_128(1)");
SuperVector<S> t = t1 | (t2.template vshr_128_imm<1>());

// Offset t1 so it aligns with t2. The hole created by the offset is filled
// with the last element of the previous t1 so no info is lost.
// Where bits with value 0 line up, it indicates a match.
SuperVector<S> t = (new_t1.alignr(*inout_t1, S-1)) | t2;
t.print8("t");

*inout_t1 = new_t1;

return t.eq(SuperVector<S>::Ones());
}
19 changes: 12 additions & 7 deletions src/nfa/shufti_simd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static really_inline
const SuperVector<S> blockSingleMask(SuperVector<S> mask_lo, SuperVector<S> mask_hi, SuperVector<S> chars);
template <uint16_t S>
static really_inline
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars);
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> *inout_first_char_mask, SuperVector<S> chars);

#if defined(VS_SIMDE_BACKEND)
#include "x86/shufti.hpp"
Expand Down Expand Up @@ -82,11 +82,13 @@ const u8 *revBlock(SuperVector<S> mask_lo, SuperVector<S> mask_hi, SuperVector<S

template <uint16_t S>
static really_inline
const u8 *fwdBlockDouble(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars, const u8 *buf) {
const u8 *fwdBlockDouble(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> *prev_first_char_mask, SuperVector<S> chars, const u8 *buf) {

SuperVector<S> mask = blockDoubleMask(mask1_lo, mask1_hi, mask2_lo, mask2_hi, chars);
SuperVector<S> mask = blockDoubleMask(mask1_lo, mask1_hi, mask2_lo, mask2_hi, prev_first_char_mask, chars);

return first_zero_match_inverted<S>(buf, mask);
// By shifting first_char_mask instead of the legacy t2 mask, we would report
// on the second char instead of the first. We offset the buf to compensate.
return first_zero_match_inverted<S>(buf-1, mask);
}

template <uint16_t S>
Expand Down Expand Up @@ -216,25 +218,28 @@ const u8 *shuftiDoubleExecReal(m128 mask1_lo, m128 mask1_hi, m128 mask2_lo, m128
__builtin_prefetch(d + 2*64);
__builtin_prefetch(d + 3*64);
__builtin_prefetch(d + 4*64);

SuperVector<S> first_char_mask = SuperVector<S>::Ones();
DEBUG_PRINTF("start %p end %p \n", d, buf_end);
assert(d < buf_end);
if (d + S <= buf_end) {
// peel off first part to cacheline boundary
DEBUG_PRINTF("until aligned %p \n", ROUNDUP_PTR(d, S));
if (!ISALIGNED_N(d, S)) {
SuperVector<S> chars = SuperVector<S>::loadu(d);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, d);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, &first_char_mask, chars, d);
DEBUG_PRINTF("rv %p \n", rv);
if (rv) return rv;
d = ROUNDUP_PTR(d, S);
}

first_char_mask = SuperVector<S>::Ones();
while(d + S <= buf_end) {
__builtin_prefetch(d + 64);
DEBUG_PRINTF("d %p \n", d);

SuperVector<S> chars = SuperVector<S>::load(d);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, d);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, &first_char_mask, chars, d);
if (rv) return rv;
d += S;
}
Expand All @@ -245,7 +250,7 @@ const u8 *shuftiDoubleExecReal(m128 mask1_lo, m128 mask1_hi, m128 mask2_lo, m128

if (d != buf_end) {
SuperVector<S> chars = SuperVector<S>::loadu(d);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, d);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, &first_char_mask, chars, d);
DEBUG_PRINTF("rv %p \n", rv);
if (rv && rv < buf_end) return rv;
}
Expand Down
36 changes: 23 additions & 13 deletions src/nfa/shufti_sve.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,46 +153,54 @@ const u8 *rshuftiExec(m128 mask_lo, m128 mask_hi, const u8 *buf,
static really_inline
svbool_t doubleMatched(svuint8_t mask1_lo, svuint8_t mask1_hi,
svuint8_t mask2_lo, svuint8_t mask2_hi,
const u8 *buf, const svbool_t pg) {
svuint8_t inout_t1, const u8 *buf, const svbool_t pg) {
svuint8_t vec = svld1_u8(pg, buf);

svuint8_t chars_lo = svand_x(svptrue_b8(), vec, (uint8_t)0xf);
svuint8_t chars_hi = svlsr_x(svptrue_b8(), vec, 4);

svuint8_t c1_lo = svtbl(mask1_lo, chars_lo);
svuint8_t c1_hi = svtbl(mask1_hi, chars_hi);
svuint8_t t1 = svorr_x(svptrue_b8(), c1_lo, c1_hi);
svuint8_t new_t1 = svorr_z(svptrue_b8(), c1_lo, c1_hi);

svuint8_t c2_lo = svtbl(mask2_lo, chars_lo);
svuint8_t c2_hi = svtbl(mask2_hi, chars_hi);
svuint8_t t2 = svext(svorr_z(pg, c2_lo, c2_hi), svdup_u8(0), 1);
svuint8_t t2 = svorr_x(svptrue_b8(), c2_lo, c2_hi);

svuint8_t t = svorr_x(svptrue_b8(), t1, t2);
// shift t1 left by one and feed in the last element from the previous t1
uint8_t last_elem = svlastb(svptrue_b8(), inout_t1);
svuint8_t merged_t1 = svinsr(new_t1, last_elem);
svuint8_t t = svorr_x(svptrue_b8(), merged_t1, t2);
inout_t1 = new_t1;

return svnot_z(svptrue_b8(), svcmpeq(svptrue_b8(), t, (uint8_t)0xff));
}

static really_inline
const u8 *dshuftiOnce(svuint8_t mask1_lo, svuint8_t mask1_hi,
svuint8_t mask2_lo, svuint8_t mask2_hi,
const u8 *buf, const u8 *buf_end) {
svuint8_t inout_t1, const u8 *buf, const u8 *buf_end) {
DEBUG_PRINTF("start %p end %p\n", buf, buf_end);
assert(buf < buf_end);
DEBUG_PRINTF("l = %td\n", buf_end - buf);
svbool_t pg = svwhilelt_b8_s64(0, buf_end - buf);
svbool_t matched = doubleMatched(mask1_lo, mask1_hi, mask2_lo, mask2_hi,
buf, pg);
return accelSearchCheckMatched(buf, matched);
inout_t1, buf, pg);
// doubleMatched returns the match position of the second char, but here we
// return the position of the first char, hence the buffer offset
return accelSearchCheckMatched(buf - 1, matched);
}

static really_inline
const u8 *dshuftiLoopBody(svuint8_t mask1_lo, svuint8_t mask1_hi,
svuint8_t mask2_lo, svuint8_t mask2_hi,
const u8 *buf) {
svuint8_t inout_t1, const u8 *buf) {
DEBUG_PRINTF("start %p end %p\n", buf, buf + svcntb());
svbool_t matched = doubleMatched(mask1_lo, mask1_hi, mask2_lo, mask2_hi,
buf, svptrue_b8());
return accelSearchCheckMatched(buf, matched);
inout_t1, buf, svptrue_b8());
// doubleMatched returns the match position of the second char, but here we
// return the position of the first char, hence the buffer offset
return accelSearchCheckMatched(buf - 1, matched);
}

static really_inline
Expand All @@ -201,29 +209,31 @@ const u8 *dshuftiSearch(svuint8_t mask1_lo, svuint8_t mask1_hi,
const u8 *buf, const u8 *buf_end) {
assert(buf < buf_end);
size_t len = buf_end - buf;
svuint8_t inout_t1 = svdup_u8(0xff);
if (len <= svcntb()) {
return dshuftiOnce(mask1_lo, mask1_hi,
mask2_lo, mask2_hi, buf, buf_end);
mask2_lo, mask2_hi, inout_t1, buf, buf_end);
}
// peel off first part to align to the vector size
const u8 *aligned_buf = ROUNDUP_PTR(buf, svcntb_pat(SV_POW2));
assert(aligned_buf < buf_end);
if (buf != aligned_buf) {
const u8 *ptr = dshuftiLoopBody(mask1_lo, mask1_hi,
mask2_lo, mask2_hi, buf);
mask2_lo, mask2_hi, inout_t1, buf);
if (ptr) return ptr;
}
buf = aligned_buf;
size_t loops = (buf_end - buf) / svcntb();
DEBUG_PRINTF("loops %zu \n", loops);
for (size_t i = 0; i < loops; i++, buf += svcntb()) {
const u8 *ptr = dshuftiLoopBody(mask1_lo, mask1_hi,
mask2_lo, mask2_hi, buf);
mask2_lo, mask2_hi, inout_t1, buf);
if (ptr) return ptr;
}
DEBUG_PRINTF("buf %p buf_end %p \n", buf, buf_end);
return buf == buf_end ? NULL : dshuftiLoopBody(mask1_lo, mask1_hi,
mask2_lo, mask2_hi,
inout_t1,
buf_end - svcntb());
}

Expand Down
45 changes: 40 additions & 5 deletions src/nfa/x86/shufti.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const SuperVector<S> blockSingleMask(SuperVector<S> mask_lo, SuperVector<S> mask

template <uint16_t S>
static really_inline
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars) {
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> *inout_c1, SuperVector<S> chars) {

const SuperVector<S> low4bits = SuperVector<S>::dup_u8(0xf);
SuperVector<S> chars_lo = chars & low4bits;
Expand All @@ -57,18 +57,53 @@ SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi,
c1_lo.print8("c1_lo");
SuperVector<S> c1_hi = mask1_hi.pshufb(chars_hi);
c1_hi.print8("c1_hi");
SuperVector<S> c1 = c1_lo | c1_hi;
c1.print8("c1");
SuperVector<S> new_c1 = c1_lo | c1_hi;
// c1 is the match mask for the first char of the patterns
new_c1.print8("c1");

SuperVector<S> c2_lo = mask2_lo.pshufb(chars_lo);
c2_lo.print8("c2_lo");
SuperVector<S> c2_hi = mask2_hi.pshufb(chars_hi);
c2_hi.print8("c2_hi");
SuperVector<S> c2 = c2_lo | c2_hi;
// c2 is the match mask for the second char of the patterns
c2.print8("c2");
c2.template vshr_128_imm<1>().print8("c2.vshr_128(1)");
SuperVector<S> c = c1 | (c2.template vshr_128_imm<1>());

// We want to shift the whole vector left by 1 and insert the last element of inout_c1.
// The lack of direct instructions to insert, extract or concatenate vectors makes this
// process complicated, so we resort to a store and load for now.
uint8_t tmp_buf[2*S];
SuperVector<S> offset_c1;
switch(S) {
case 16:
_mm_storeu_si128(reinterpret_cast<m128 *>(&tmp_buf[0]), inout_c1->u.v128[0]);
_mm_storeu_si128(reinterpret_cast<m128 *>(&tmp_buf[S]), new_c1.u.v128[0]);
offset_c1 = SuperVector<S>(_mm_loadu_si128(reinterpret_cast<const m128 *>(&tmp_buf[S-1])));
break;
#ifdef HAVE_AVX2
case 32:
_mm256_storeu_si256(reinterpret_cast<m256 *>(&tmp_buf[0]), inout_c1->u.v256[0]);
_mm256_storeu_si256(reinterpret_cast<m256 *>(&tmp_buf[S]), new_c1.u.v256[0]);
offset_c1 = SuperVector<S>(_mm256_loadu_si256(reinterpret_cast<const m256 *>(&tmp_buf[S-1])));
break;
#endif
#ifdef HAVE_AVX512
case 64:
_mm512_storeu_si512(reinterpret_cast<m512 *>(&tmp_buf[0]), inout_c1->u.v512[0]);
_mm512_storeu_si512(reinterpret_cast<m512 *>(&tmp_buf[S]), new_c1.u.v512[0]);
offset_c1 = SuperVector<S>(_mm512_load_si512(reinterpret_cast<const m512 *>(&tmp_buf[S-1])));
break;
#endif
}
offset_c1.print8("offset c1");

// Offset c1 so it aligns with c2. The hole created by the offset is filled
// with the last element of the previous c1 so no info is lost.
// Where bits with value 0 line up, it indicates a match.
SuperVector<S> c = offset_c1 | c2;
c.print8("c");

*inout_c1 = new_c1;

return c.eq(SuperVector<S>::Ones());
}

0 comments on commit 3351539

Please sign in to comment.