Skip to content

Commit

Permalink
Merge pull request #42 from simdutf/sse_validation
Browse files Browse the repository at this point in the history
Sse validation
  • Loading branch information
Nick-Nuon authored Jun 12, 2024
2 parents 6ff99ee + a992e33 commit 75dc197
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 59 deletions.
8 changes: 7 additions & 1 deletion benchmark/Benchmark.cs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,13 @@ public unsafe void SIMDUtf8ValidationRealDataSse()
{
if (allLinesUtf8 != null)
{
RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByteSse);
RunUtf8ValidationBenchmark(allLinesUtf8, (byte* pInputBuffer, int inputLength) =>
{
int dummyUtf16CodeUnitCountAdjustment, dummyScalarCountAdjustment;
// Call the method with additional out parameters within the lambda.
// You must handle these additional out parameters inside the lambda, as they cannot be passed back through the delegate.
return SimdUnicode.UTF8.GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength, out dummyUtf16CodeUnitCountAdjustment, out dummyScalarCountAdjustment);
});
}
}

Expand Down
189 changes: 137 additions & 52 deletions src/UTF8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace SimdUnicode
public static class UTF8
{


// Returns &inputBuffer[inputLength] if the input buffer is valid.
/// <summary>
/// Given an input buffer <paramref name="pInputBuffer"/> of byte length <paramref name="inputLength"/>,
Expand All @@ -35,11 +36,10 @@ public static class UTF8
{
return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength);
}*/
// if (Ssse3.IsSupported)
// {
// return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength);
// }
// return GetPointerToFirstInvalidByteScalar(pInputBuffer, inputLength);
if (Ssse3.IsSupported)
{
return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength,out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment);
}

return GetPointerToFirstInvalidByteScalar(pInputBuffer, inputLength, out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment);

Expand Down Expand Up @@ -471,15 +471,13 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
return (utfadjust, scalaradjust);
}

public unsafe static byte* GetPointerToFirstInvalidByteSse(byte* pInputBuffer, int inputLength)
public unsafe static byte* GetPointerToFirstInvalidByteSse(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
{

int processedLength = 0;
int TempUtf16CodeUnitCountAdjustment = 0;
int TempScalarCountAdjustment = 0;

if (pInputBuffer == null || inputLength <= 0)
{
utf16CodeUnitCountAdjustment = 0;
scalarCountAdjustment = 0;
return pInputBuffer;
}
if (inputLength > 128)
Expand All @@ -503,24 +501,24 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust

if (processedLength + 16 < inputLength)
{
// We still have work to do!
Vector128<byte> prevInputBlock = Vector128<byte>.Zero;

Vector128<byte> maxValue = Vector128.Create(
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 0b11110000 - 1, 0b11100000 - 1, 0b11000000 - 1);
Vector128<byte> prevIncomplete = Sse2.SubtractSaturate(prevInputBlock, maxValue);

Vector128<byte> prevIncomplete = Sse3.SubtractSaturate(prevInputBlock, maxValue);

Vector128<byte> shuf1 = Vector128.Create(TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
Vector128<byte> shuf1 = Vector128.Create(
TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS,
TOO_SHORT | OVERLONG_2,
TOO_SHORT,
TOO_SHORT | OVERLONG_3 | SURROGATE,
TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4);

Vector128<byte> shuf2 = Vector128.Create(CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
Vector128<byte> shuf2 = Vector128.Create(
CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
CARRY | OVERLONG_2,
CARRY,
CARRY,
Expand All @@ -536,7 +534,8 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE,
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000);
Vector128<byte> shuf3 = Vector128.Create(TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
Vector128<byte> shuf3 = Vector128.Create(
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4,
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE,
Expand All @@ -548,24 +547,71 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
Vector128<byte> fourthByte = Vector128.Create((byte)(0b11110000u - 0x80));
Vector128<byte> v0f = Vector128.Create((byte)0x0F);
Vector128<byte> v80 = Vector128.Create((byte)0x80);
/****
* So we want to count the number of 4-byte sequences,
* the number of 4-byte sequences, 3-byte sequences, and
* the number of 2-byte sequences.
* We can do it indirectly. We know how many bytes in total
* we have (length). Let us assume that the length covers
* only complete sequences (we need to adjust otherwise).
* We have that
* length = 4 * n4 + 3 * n3 + 2 * n2 + n1
* where n1 is the number of 1-byte sequences (ASCII),
* n2 is the number of 2-byte sequences, n3 is the number
* of 3-byte sequences, and n4 is the number of 4-byte sequences.
*
* Let ncon be the number of continuation bytes, then we have
* length = n4 + n3 + n2 + ncon + n1
*
* We can solve for n2 and n3 in terms of the other variables:
* n3 = n1 - 2 * n4 + 2 * ncon - length
* n2 = -2 * n1 + n4 - 4 * ncon + 2 * length
* Thus we only need to count the number of continuation bytes,
* the number of ASCII bytes and the number of 4-byte sequences.
*/
////////////
// The *block* here is what begins at processedLength and ends
// at processedLength/16*16 or when an error occurs.
///////////
int start_point = processedLength;

// The block goes from processedLength to processedLength/16*16.
int asciibytes = 0; // number of ascii bytes in the block (could also be called n1)
int contbytes = 0; // number of continuation bytes in the block
int n4 = 0; // number of 4-byte sequences that start in this block
for (; processedLength + 16 <= inputLength; processedLength += 16)
{

Vector128<byte> currentBlock = Sse2.LoadVector128(pInputBuffer + processedLength);

int mask = Sse2.MoveMask(currentBlock);
Vector128<byte> currentBlock = Avx.LoadVector128(pInputBuffer + processedLength);
int mask = Sse42.MoveMask(currentBlock);
if (mask == 0)
{
// We have an ASCII block, no need to process it, but
// we need to check if the previous block was incomplete.
if (Sse2.MoveMask(prevIncomplete) != 0)
//

if (!Sse41.TestZ(prevIncomplete, prevIncomplete))
{
return SimdUnicode.UTF8.RewindAndValidateWithErrors(processedLength, pInputBuffer + processedLength, inputLength - processedLength, ref TempUtf16CodeUnitCountAdjustment, ref TempScalarCountAdjustment);
int off = processedLength >= 3 ? processedLength - 3 : processedLength;
byte* invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(16 - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
// So the code is correct up to invalidBytePointer
if (invalidBytePointer < pInputBuffer + processedLength)
{
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
}
else
{
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
}
int totalbyteasciierror = processedLength - start_point;
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, totalbyteasciierror);
return invalidBytePointer;
}
prevIncomplete = Vector128<byte>.Zero;
}
else
else // Contains non-ASCII characters, we need to do non-trivial processing
{
// Use SubtractSaturate to effectively compare if bytes in block are greater than markers.
// Contains non-ASCII characters, we need to do non-trivial processing
Vector128<byte> prev1 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 1));
Vector128<byte> byte_1_high = Ssse3.Shuffle(shuf1, Sse2.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f);
Expand All @@ -575,54 +621,93 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
Vector128<byte> prev2 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 2));
Vector128<byte> prev3 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 3));
prevInputBlock = currentBlock;

Vector128<byte> isThirdByte = Sse2.SubtractSaturate(prev2, thirdByte);
Vector128<byte> isFourthByte = Sse2.SubtractSaturate(prev3, fourthByte);
Vector128<byte> must23 = Sse2.Or(isThirdByte, isFourthByte);
Vector128<byte> must23As80 = Sse2.And(must23, v80);
Vector128<byte> error = Sse2.Xor(must23As80, sc);
if (Sse2.MoveMask(error) != 0)

if (!Sse42.TestZ(error, error))
{
return SimdUnicode.UTF8.RewindAndValidateWithErrors(processedLength, pInputBuffer + processedLength, inputLength - processedLength, ref TempUtf16CodeUnitCountAdjustment, ref TempScalarCountAdjustment);

byte* invalidBytePointer;
if (processedLength == 0)
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
}
if (invalidBytePointer < pInputBuffer + processedLength)
{
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
}
else
{
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
}
int total_bytes_processed = (int)(invalidBytePointer - (pInputBuffer + start_point));
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, total_bytes_processed);
return invalidBytePointer;
}
prevIncomplete = Sse2.SubtractSaturate(currentBlock, maxValue);

prevIncomplete = Sse3.SubtractSaturate(currentBlock, maxValue);

contbytes += (int)Popcnt.PopCount((uint)Sse42.MoveMask(byte_2_high));
// We use two instructions (SubtractSaturate and MoveMask) to update n4, with one arithmetic operation.
n4 += (int)Popcnt.PopCount((uint)Sse42.MoveMask(Sse42.SubtractSaturate(currentBlock, fourthByte)));
}

// important: we just update asciibytes if there was no error.
// We count the number of ascii bytes in the block using just some simple arithmetic
// and no expensive operation:
asciibytes += (int)(16 - Popcnt.PopCount((uint)mask));
}
}
}
// We have processed all the blocks using SIMD, we need to process the remaining bytes.

// Process the remaining bytes with the scalar function
if (processedLength < inputLength)
{
// We need to possibly backtrack to the start of the last code point
// worst possible case is 4 bytes, where we need to backtrack 3 bytes
// 11110xxxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
{
processedLength -= 1;
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)

// We may still have an error.
if (processedLength < inputLength || !Sse42.TestZ(prevIncomplete, prevIncomplete))
{
processedLength -= 1;
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
byte* invalidBytePointer;
if (processedLength == 0)
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);

}
if (invalidBytePointer != pInputBuffer + inputLength)
{
if (invalidBytePointer < pInputBuffer + processedLength)
{
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
}
else
{
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
}
int total_bytes_processed = (int)(invalidBytePointer - (pInputBuffer + start_point));
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, total_bytes_processed);
return invalidBytePointer;
}
else
{
processedLength -= 1;
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
}
}
}
int TailScalarCodeUnitCountAdjustment = 0;
int TailUtf16CodeUnitCountAdjustment = 0;
byte* invalidBytePointer = SimdUnicode.UTF8.GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out TailUtf16CodeUnitCountAdjustment, out TailScalarCodeUnitCountAdjustment);
if (invalidBytePointer != pInputBuffer + inputLength)
{
// An invalid byte was found by the scalar function
return invalidBytePointer;
int final_total_bytes_processed = inputLength - start_point;
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, final_total_bytes_processed);
return pInputBuffer + inputLength;
}
}

return pInputBuffer + inputLength;
return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}


//
public unsafe static byte* GetPointerToFirstInvalidByteAvx2(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
{
int processedLength = 0;
Expand Down
Loading

0 comments on commit 75dc197

Please sign in to comment.