Skip to content

Commit

Permalink
some cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
lemire committed May 28, 2024
1 parent 5a99bb2 commit 92cd232
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 460 deletions.
63 changes: 32 additions & 31 deletions benchmark/Benchmark.cs
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,17 @@ public unsafe void SIMDUtf8ValidationRealData()
{
if (allLinesUtf8 != null)
{
// RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByte);
RunUtf8ValidationBenchmark(allLinesUtf8, (byte* pInputBuffer, int inputLength) =>
{
int dummyUtf16CodeUnitCountAdjustment, dummyScalarCountAdjustment;
// Call the method with additional out parameters within the lambda.
// You must handle these additional out parameters inside the lambda, as they cannot be passed back through the delegate.
return SimdUnicode.UTF8.GetPointerToFirstInvalidByte(pInputBuffer, inputLength, out dummyUtf16CodeUnitCountAdjustment, out dummyScalarCountAdjustment);
});
}
}

[Benchmark]
// [BenchmarkCategory("scalar")]
// public unsafe void Utf8ValidationRealDataScalar()
// {
// if (allLinesUtf8 != null)
// {
// RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByteScalar);
// }
// }

[BenchmarkCategory("scalar")]
public unsafe void Utf8ValidationRealDataScalar()
{
Expand All @@ -220,18 +217,33 @@ public unsafe void SIMDUtf8ValidationRealDataArm64()
{
if (allLinesUtf8 != null)
{
RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByteArm64);
RunUtf8ValidationBenchmark(allLinesUtf8, (byte* pInputBuffer, int inputLength) =>
{
int dummyUtf16CodeUnitCountAdjustment, dummyScalarCountAdjustment;
// Call the method with additional out parameters within the lambda.
// You must handle these additional out parameters inside the lambda, as they cannot be passed back through the delegate.
return SimdUnicode.UTF8.GetPointerToFirstInvalidByteArm64(pInputBuffer, inputLength, out dummyUtf16CodeUnitCountAdjustment, out dummyScalarCountAdjustment);
});
}

}
// [Benchmark]
// [BenchmarkCategory("avx")]
// public unsafe void SIMDUtf8ValidationRealDataAvx2()
// {
// if (allLinesUtf8 != null)
// {
// RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2);
// }
// }

[Benchmark]
[BenchmarkCategory("avx")]
public unsafe void SIMDUtf8ValidationRealDataAvx2()
{
if (allLinesUtf8 != null)
{
RunUtf8ValidationBenchmark(allLinesUtf8, (byte* pInputBuffer, int inputLength) =>
{
int dummyUtf16CodeUnitCountAdjustment, dummyScalarCountAdjustment;
// Call the method with additional out parameters within the lambda.
// You must handle these additional out parameters inside the lambda, as they cannot be passed back through the delegate.
return SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2(pInputBuffer, inputLength, out dummyUtf16CodeUnitCountAdjustment, out dummyScalarCountAdjustment);
});
}
}

[Benchmark]
[BenchmarkCategory("sse")]
public unsafe void SIMDUtf8ValidationRealDataSse()
Expand All @@ -241,17 +253,6 @@ public unsafe void SIMDUtf8ValidationRealDataSse()
RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByteSse);
}
}
/*
// TODO: enable this benchmark when the AVX-512 implementation is ready
[Benchmark]
[BenchmarkCategory("avx512")]
public unsafe void SIMDUtf8ValidationRealDataAvx512()
{
if (allLinesUtf8 != null)
{
RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512);
}
}*/

}
public class Program
Expand Down
97 changes: 69 additions & 28 deletions src/UTF8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -715,16 +715,19 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
return pInputBuffer + inputLength;
}

public unsafe static byte* GetPointerToFirstInvalidByteArm64(byte* pInputBuffer, int inputLength)
public unsafe static byte* GetPointerToFirstInvalidByteArm64(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
{
int processedLength = 0;
int TempUtf16CodeUnitCountAdjustment = 0;
int TempScalarCountAdjustment = 0;

int utf16CodeUnitCountAdjustment = 0, scalarCountAdjustment = 0;
int TailScalarCodeUnitCountAdjustment = 0;
int TailUtf16CodeUnitCountAdjustment = 0;

if (pInputBuffer == null || inputLength <= 0)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer;
}
if (inputLength > 128)
Expand Down Expand Up @@ -793,18 +796,32 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
Vector128<byte> v0f = Vector128.Create((byte)0x0F);
Vector128<byte> v80 = Vector128.Create((byte)0x80);
// Performance note: we could process 64 bytes at a time for better speed in some cases.
int start_point = processedLength;

// The block goes from processedLength to processedLength/16*16.
int asciibytes = 0; // number of ascii bytes in the block (could also be called n1)
int contbytes = 0; // number of continuation bytes in the block
int n4 = 0; // number of 4-byte sequences that start in this block

for (; processedLength + 16 <= inputLength; processedLength += 16)
{

Vector128<byte> currentBlock = AdvSimd.LoadVector128(pInputBuffer + processedLength);

if (AdvSimd.Arm64.MaxAcross(currentBlock).ToScalar() > 127)
if (AdvSimd.Arm64.MaxAcross(currentBlock).ToScalar() <= 127)
{
// We have an ASCII block, no need to process it, but
// we need to check if the previous block was incomplete.
if (AdvSimd.Arm64.MaxAcross(prevIncomplete).ToScalar() != 0)
{
return SimdUnicode.UTF8.RewindAndValidateWithErrors(processedLength, pInputBuffer + processedLength, inputLength - processedLength, ref utf16CodeUnitCountAdjustment, ref scalarCountAdjustment);
int totalbyteasciierror = processedLength - start_point;
var (utfadjustasciierror, scalaradjustasciierror) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, totalbyteasciierror);

utf16CodeUnitCountAdjustment = utfadjustasciierror;
scalarCountAdjustment = scalaradjustasciierror;

int off = processedLength >= 3 ? processedLength - 3 : processedLength;
return SimdUnicode.UTF8.RewindAndValidateWithErrors(off, pInputBuffer + off, inputLength - off, ref utf16CodeUnitCountAdjustment, ref scalarCountAdjustment);
}
prevIncomplete = Vector128<byte>.Zero;
}
Expand All @@ -829,52 +846,76 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
// hardware:
if (AdvSimd.Arm64.MaxAcross(Vector128.AsUInt32(error)).ToScalar() != 0)
{
return SimdUnicode.UTF8.RewindAndValidateWithErrors(processedLength, pInputBuffer + processedLength, inputLength - processedLength, ref utf16CodeUnitCountAdjustment, ref scalarCountAdjustment);
int off = processedLength > 32 ? processedLength - 32 : processedLength;// this does not backup ff processedlength = 32
byte* invalidBytePointer = SimdUnicode.UTF8.RewindAndValidateWithErrors(off, pInputBuffer + processedLength, inputLength - processedLength, ref TailUtf16CodeUnitCountAdjustment, ref TailScalarCodeUnitCountAdjustment);
utf16CodeUnitCountAdjustment = TailUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TailScalarCodeUnitCountAdjustment;

int totalbyteasciierror = processedLength - start_point;
var (utfadjustasciierror, scalaradjustasciierror) = calculateErrorPathadjust(start_point, processedLength, pInputBuffer, asciibytes, n4, contbytes);

utf16CodeUnitCountAdjustment += utfadjustasciierror;
scalarCountAdjustment += scalaradjustasciierror;

return invalidBytePointer;
}
prevIncomplete = AdvSimd.SubtractSaturate(currentBlock, maxValue);
if (AdvSimd.Arm64.MaxAcross(Vector128.AsUInt32(prevIncomplete)).ToScalar() != 0)
{
// We have an unterminated sequence.
var (totalbyteadjustment, i, tempascii, tempcont, tempn4) = adjustmentFactor(pInputBuffer + processedLength + 32);
processedLength -= i;
n4 += tempn4;
contbytes += tempcont;
}
Vector128<sbyte> largestcont = Vector128.Create((sbyte)-65); // -65 => 0b10111111
contbytes += 16 - AdvSimd.Arm64.AddAcross(AdvSimd.CompareGreaterThan(Vector128.AsSByte(currentBlock), largestcont)).ToScalar();
Vector128<byte> fourthByteMinusOne = Vector128.Create((byte)(0b11110000u - 1));
n4 += (int)(AdvSimd.Arm64.AddAcross(AdvSimd.SubtractSaturate(currentBlock, fourthByteMinusOne)).ToScalar());
}

asciibytes -= (int)AdvSimd.Arm64.AddAcross(AdvSimd.CompareGreaterThanOrEqual(currentBlock, v80)).ToScalar();

}

int totalbyte = processedLength - start_point;
var (utf16adjust, scalaradjust) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, totalbyte);

TempUtf16CodeUnitCountAdjustment = utf16adjust;
TempScalarCountAdjustment = scalaradjust;

}
}
// We have processed all the blocks using SIMD, we need to process the remaining bytes.

// Process the remaining bytes with the scalar function

// worst possible case is 4 bytes, where we need to backtrack 3 bytes
// 11110xxxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
if (processedLength < inputLength)
{
// We need to possibly backtrack to the start of the last code point
// worst possible case is 4 bytes, where we need to backtrack 3 bytes
// 11110xxxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
{
processedLength -= 1;
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
{
processedLength -= 1;
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
{
processedLength -= 1;
}
}
}
int TailScalarCodeUnitCountAdjustment = 0;
int TailUtf16CodeUnitCountAdjustment = 0;
byte* invalidBytePointer = SimdUnicode.UTF8.GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out TailUtf16CodeUnitCountAdjustment, out TailScalarCodeUnitCountAdjustment);

byte* invalidBytePointer = SimdUnicode.UTF8.RewindAndValidateWithErrors(32, pInputBuffer + processedLength, inputLength - processedLength, ref TailUtf16CodeUnitCountAdjustment, ref TailScalarCodeUnitCountAdjustment);
if (invalidBytePointer != pInputBuffer + inputLength)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment + TailUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment + TailScalarCodeUnitCountAdjustment;

// An invalid byte was found by the scalar function
return invalidBytePointer;
}
}
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment + TailUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment + TailScalarCodeUnitCountAdjustment;

return pInputBuffer + inputLength;
}
public unsafe static byte* GetPointerToFirstInvalidByte(byte* pInputBuffer, int inputLength, out int Utf16CodeUnitCountAdjustment, out int ScalarCodeUnitCountAdjustment)
{

// if (AdvSimd.Arm64.IsSupported)
// {
// return GetPointerToFirstInvalidByteArm64(pInputBuffer, inputLength);
// }
if (AdvSimd.Arm64.IsSupported)
{
return GetPointerToFirstInvalidByteArm64(pInputBuffer, inputLength, out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment);
}
if (Avx2.IsSupported)
{
return GetPointerToFirstInvalidByteAvx2(pInputBuffer, inputLength, out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment);
Expand Down
Loading

0 comments on commit 92cd232

Please sign in to comment.