Skip to content

Commit

Permalink
Hand-unroll the SIMD dot product loop (#380)
Browse files Browse the repository at this point in the history
* Improve SIMD vector dot product and its test
  • Loading branch information
blambov authored Dec 31, 2024
1 parent 6b4fc38 commit 7b78c9e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,54 +22,58 @@
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.util.Random;
import java.util.concurrent.TimeUnit;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
@Warmup(iterations = 2, time = 5)
@Measurement(iterations = 3, time = 10)
@Fork(warmups = 1, value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=true"})
@Measurement(iterations = 3, time = 5)
@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=true"})
@State(Scope.Thread)
public class SimilarityBench {

static VectorFloat<?> A_4 = TestUtil.randomVector(new Random(), 4);
static VectorFloat<?> B_4 = TestUtil.randomVector(new Random(), 4);
static VectorFloat<?> A_8 = TestUtil.randomVector(new Random(), 8);
static VectorFloat<?> B_8 = TestUtil.randomVector(new Random(), 8);
static VectorFloat<?> A_16 = TestUtil.randomVector(new Random(), 16);
static VectorFloat<?> B_16 = TestUtil.randomVector(new Random(), 16);
@Param({"4", "8", "16", "1024"})
int size = 1024;

VectorFloat<?> A, B;

static

@Benchmark
@BenchmarkMode(Mode.Throughput)
@Threads(8)
public void testDotProduct_4(Blackhole bh) {
bh.consume(VectorUtil.dotProduct(A_4, B_4));
@Setup(Level.Trial)
public void setUp()
{
A = TestUtil.randomVector(new Random(), size);
B = TestUtil.randomVector(new Random(), size);
}

@Benchmark
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
@Benchmark
@Threads(8)
public void testDotProduct_8(Blackhole bh) {
bh.consume(VectorUtil.dotProduct(A_8, B_8));
public void testDotProduct8(Blackhole bh) {
bh.consume(VectorUtil.dotProduct(A, B));
}


@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Benchmark
@BenchmarkMode(Mode.Throughput)
@Threads(8)
public void testDotProduct_16(Blackhole bh) {
bh.consume(VectorUtil.dotProduct(A_16, B_16));
@Threads(1)
public void testDotProduct1(Blackhole bh) {
bh.consume(VectorUtil.dotProduct(A, B));
}




public static void main(String[] args) throws Exception {
org.openjdk.jmh.Main.main(args);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ static float sum(ArrayVectorFloat vector) {
var sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length());

// Process the vectorized part
// Process the remainder
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i);
sum = sum.add(a);
Expand Down Expand Up @@ -207,28 +207,64 @@ static float dotProduct256(ArrayVectorFloat v1, int v1offset, ArrayVectorFloat v
return res;
}

static float dotProductPreferred(ArrayVectorFloat v1, int v1offset, ArrayVectorFloat v2, int v2offset, int length) {
static float dotProductPreferred(ArrayVectorFloat va, int vaoffset, ArrayVectorFloat vb, int vboffset, int length) {
if (length == FloatVector.SPECIES_PREFERRED.length())
return dotPreferred(v1, v1offset, v2, v2offset);
return dotPreferred(va, vaoffset, vb, vboffset);

final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(length);
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
FloatVector sum0 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
FloatVector sum1 = sum0;
FloatVector a0, a1, b0, b1;

int i = 0;
// Process the vectorized part
for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1.get(), v1offset + i);
FloatVector b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2.get(), v2offset + i);
sum = a.fma(b, sum);
int vectorLength = FloatVector.SPECIES_PREFERRED.length();

// Unrolled vector loop; for dot product from L1 cache, an unroll factor of 2 generally suffices.
// If we are going to be getting data that's further down the hierarchy but not fetched off disk/network,
// we might want to unroll further, e.g. to 8 (4 sets of a,b,sum with 3-ahead reads seems to work best).
if (length >= vectorLength * 2)
{
length -= vectorLength * 2;
a0 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset + vectorLength * 0);
b0 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset + vectorLength * 0);
a1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset + vectorLength * 1);
b1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset + vectorLength * 1);
vaoffset += vectorLength * 2;
vboffset += vectorLength * 2;
while (length >= vectorLength * 2)
{
// All instructions in the main loop have no dependencies between them and can be executed in parallel.
length -= vectorLength * 2;
sum0 = a0.fma(b0, sum0);
a0 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset + vectorLength * 0);
b0 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset + vectorLength * 0);
sum1 = a1.fma(b1, sum1);
a1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset + vectorLength * 1);
b1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset + vectorLength * 1);
vaoffset += vectorLength * 2;
vboffset += vectorLength * 2;
}
sum0 = a0.fma(b0, sum0);
sum1 = a1.fma(b1, sum1);
}
sum0 = sum0.add(sum1);

float res = sum.reduceLanes(VectorOperators.ADD);
// Process the remaining few vectors
while (length >= vectorLength) {
length -= vectorLength;
FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset);
FloatVector b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset);
vaoffset += vectorLength;
vboffset += vectorLength;
sum0 = a.fma(b, sum0);
}

float resVec = sum0.reduceLanes(VectorOperators.ADD);
float resTail = 0;

// Process the tail
for (; i < length; ++i)
res += v1.get(v1offset + i) * v2.get(v2offset + i);
for (; length > 0; --length)
resTail += va.get(vaoffset++) * vb.get(vboffset++);

return res;
return resVec + resTail;
}

static float cosineSimilarity(ArrayVectorFloat v1, ArrayVectorFloat v2) {
Expand Down

0 comments on commit 7b78c9e

Please sign in to comment.