Skip to content

Commit

Permalink
fix: correct performance problem with arm function, it was due to Vec…
Browse files Browse the repository at this point in the history
…tor128.shuffle (DO NOT USE)
  • Loading branch information
lemire committed May 28, 2024
1 parent bc1272b commit c0f1a09
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ You can print the content of a vector register like so:
## Performance tips

- Be careful: `Vector128.Shuffle` is not the same as `Ssse3.Shuffle`, nor is it the same as `Avx2.Shuffle`. Prefer the platform-specific intrinsics (`Ssse3.Shuffle`, `Avx2.Shuffle`) over `Vector128.Shuffle`.
- Similarly, `Vector128.Shuffle` is not the same as `AdvSimd.Arm64.VectorTableLookup`; use the latter.

## More reading

Expand Down
1 change: 0 additions & 1 deletion benchmark/Benchmark.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ public unsafe void Utf8ValidationRealDataScalar()
}
}


[Benchmark]
[BenchmarkCategory("arm64")]
public unsafe void SIMDUtf8ValidationRealDataArm64()
Expand Down
15 changes: 6 additions & 9 deletions src/UTF8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,6 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
int asciibytes = 0; // number of ascii bytes in the block (could also be called n1)
int contbytes = 0; // number of continuation bytes in the block
int n4 = 0; // number of 4-byte sequences that start in this block

for (; processedLength + 16 <= inputLength; processedLength += 16)
{

Expand All @@ -817,9 +816,10 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
{
// Contains non-ASCII characters, we need to do non-trivial processing
Vector128<byte> prev1 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 1));
Vector128<byte> byte_1_high = Vector128.Shuffle(shuf1, AdvSimd.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f);
Vector128<byte> byte_1_low = Vector128.Shuffle(shuf2, (prev1 & v0f));
Vector128<byte> byte_2_high = Vector128.Shuffle(shuf3, AdvSimd.ShiftRightLogical(currentBlock.AsUInt16(), 4).AsByte() & v0f);
// Vector128.Shuffle vs AdvSimd.Arm64.VectorTableLookup: prefer the latter!!!
Vector128<byte> byte_1_high = AdvSimd.Arm64.VectorTableLookup(shuf1, AdvSimd.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f);
Vector128<byte> byte_1_low = AdvSimd.Arm64.VectorTableLookup (shuf2, (prev1 & v0f));
Vector128<byte> byte_2_high = AdvSimd.Arm64.VectorTableLookup (shuf3, AdvSimd.ShiftRightLogical(currentBlock.AsUInt16(), 4).AsByte() & v0f);
Vector128<byte> sc = AdvSimd.And(AdvSimd.And(byte_1_high, byte_1_low), byte_2_high);
Vector128<byte> prev2 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 2));
Vector128<byte> prev3 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 3));
Expand Down Expand Up @@ -849,13 +849,11 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
}
prevIncomplete = AdvSimd.SubtractSaturate(currentBlock, maxValue);
Vector128<sbyte> largestcont = Vector128.Create((sbyte)-65); // -65 => 0b10111111
contbytes += 16 - AdvSimd.Arm64.AddAcross(AdvSimd.CompareGreaterThan(Vector128.AsSByte(currentBlock), largestcont)).ToScalar();
contbytes += -AdvSimd.Arm64.AddAcross(AdvSimd.CompareLessThanOrEqual(Vector128.AsSByte(currentBlock), largestcont)).ToScalar();
Vector128<byte> fourthByteMinusOne = Vector128.Create((byte)(0b11110000u - 1));
n4 += (int)(AdvSimd.Arm64.AddAcross(AdvSimd.SubtractSaturate(currentBlock, fourthByteMinusOne)).ToScalar());
}

asciibytes -= (int)AdvSimd.Arm64.AddAcross(AdvSimd.CompareGreaterThanOrEqual(currentBlock, v80)).ToScalar();

asciibytes -= (sbyte)AdvSimd.Arm64.AddAcross(AdvSimd.CompareLessThan(currentBlock, v80)).ToScalar();
}

int totalbyte = processedLength - start_point;
Expand Down Expand Up @@ -886,7 +884,6 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
}
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment + TailUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment + TailScalarCodeUnitCountAdjustment;

return pInputBuffer + inputLength;
}
public unsafe static byte* GetPointerToFirstInvalidByte(byte* pInputBuffer, int inputLength, out int Utf16CodeUnitCountAdjustment, out int ScalarCodeUnitCountAdjustment)
Expand Down

0 comments on commit c0f1a09

Please sign in to comment.