From 6c0f3966d995adc0716a26d70b5c7be36d3be05d Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Thu, 16 Nov 2023 17:00:19 -0600 Subject: [PATCH] Use fma in SIMD Euclidean/cosine --- .../io/github/jbellis/jvector/vector/SimdOps.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index c042e99a2..b0ad99119 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -269,8 +269,8 @@ static float cosineSimilarity(float[] v1, float[] v2) { var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1, i); var b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2, i); vsum = a.fma(b, vsum); - vaMagnitude = vaMagnitude.add(a.mul(a)); - vbMagnitude = vbMagnitude.add(b.mul(b)); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); } float sum = vsum.reduceLanes(VectorOperators.ADD); @@ -379,7 +379,7 @@ static float squareDistance64(float[] v1, int v1offset, float[] v2, int v2offset FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_64, v1, v1offset + i); FloatVector b = FloatVector.fromArray(FloatVector.SPECIES_64, v2, v2offset + i); var diff = a.sub(b); - sum = sum.add(diff.mul(diff)); + sum = diff.fma(diff, sum); } float res = sum.reduceLanes(VectorOperators.ADD); @@ -406,7 +406,7 @@ static float squareDistance128(float[] v1, int v1offset, float[] v2, int v2offse FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_128, v1, v1offset + i); FloatVector b = FloatVector.fromArray(FloatVector.SPECIES_128, v2, v2offset + i); var diff = a.sub(b); - sum = sum.add(diff.mul(diff)); + sum = diff.fma(diff, sum); } float res = sum.reduceLanes(VectorOperators.ADD); @@ -434,7 +434,7 @@ static float squareDistance256(float[] v1, int v1offset, float[] v2, int v2offse FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_256, v1, v1offset + i); FloatVector b = FloatVector.fromArray(FloatVector.SPECIES_256, v2, v2offset + i); var diff = a.sub(b); - sum = sum.add(diff.mul(diff)); + sum = diff.fma(diff, sum); } float res = sum.reduceLanes(VectorOperators.ADD); @@ -462,7 +462,7 @@ static float squareDistancePreferred(float[] v1, int v1offset, float[] v2, int v FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1, v1offset + i); FloatVector b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2, v2offset + i); var diff = a.sub(b); - sum = sum.add(diff.mul(diff)); + sum = diff.fma(diff, sum); } float res = sum.reduceLanes(VectorOperators.ADD);