Skip to content

Commit 7b78c9e

Browse files
authored
Hand-unroll the SIMD dot product loop (#380)
* Improve SIMD vector dot product and its test
1 parent 6b4fc38 commit 7b78c9e

File tree

2 files changed

+80
-40
lines changed

2 files changed

+80
-40
lines changed

jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/SimilarityBench.java

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,54 +22,58 @@
2222
import org.openjdk.jmh.annotations.Benchmark;
2323
import org.openjdk.jmh.annotations.BenchmarkMode;
2424
import org.openjdk.jmh.annotations.Fork;
25+
import org.openjdk.jmh.annotations.Level;
2526
import org.openjdk.jmh.annotations.Measurement;
2627
import org.openjdk.jmh.annotations.Mode;
28+
import org.openjdk.jmh.annotations.OutputTimeUnit;
29+
import org.openjdk.jmh.annotations.Param;
30+
import org.openjdk.jmh.annotations.Scope;
31+
import org.openjdk.jmh.annotations.Setup;
32+
import org.openjdk.jmh.annotations.State;
2733
import org.openjdk.jmh.annotations.Threads;
2834
import org.openjdk.jmh.annotations.Warmup;
2935
import org.openjdk.jmh.infra.Blackhole;
3036

3137
import java.util.Random;
38+
import java.util.concurrent.TimeUnit;
3239

40+
@BenchmarkMode(Mode.Throughput)
41+
@OutputTimeUnit(TimeUnit.SECONDS)
3342
@Warmup(iterations = 2, time = 5)
34-
@Measurement(iterations = 3, time = 10)
35-
@Fork(warmups = 1, value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=true"})
43+
@Measurement(iterations = 3, time = 5)
44+
@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=true"})
45+
@State(Scope.Thread)
3646
public class SimilarityBench {
3747

38-
static VectorFloat<?> A_4 = TestUtil.randomVector(new Random(), 4);
39-
static VectorFloat<?> B_4 = TestUtil.randomVector(new Random(), 4);
40-
static VectorFloat<?> A_8 = TestUtil.randomVector(new Random(), 8);
41-
static VectorFloat<?> B_8 = TestUtil.randomVector(new Random(), 8);
42-
static VectorFloat<?> A_16 = TestUtil.randomVector(new Random(), 16);
43-
static VectorFloat<?> B_16 = TestUtil.randomVector(new Random(), 16);
48+
@Param({"4", "8", "16", "1024"})
49+
int size = 1024;
4450

51+
VectorFloat<?> A, B;
4552

46-
static
47-
48-
@Benchmark
49-
@BenchmarkMode(Mode.Throughput)
50-
@Threads(8)
51-
public void testDotProduct_4(Blackhole bh) {
52-
bh.consume(VectorUtil.dotProduct(A_4, B_4));
53+
@Setup(Level.Trial)
54+
public void setUp()
55+
{
56+
A = TestUtil.randomVector(new Random(), size);
57+
B = TestUtil.randomVector(new Random(), size);
5358
}
5459

55-
@Benchmark
5660
@BenchmarkMode(Mode.Throughput)
61+
@OutputTimeUnit(TimeUnit.SECONDS)
62+
@Benchmark
5763
@Threads(8)
58-
public void testDotProduct_8(Blackhole bh) {
59-
bh.consume(VectorUtil.dotProduct(A_8, B_8));
64+
public void testDotProduct8(Blackhole bh) {
65+
bh.consume(VectorUtil.dotProduct(A, B));
6066
}
6167

6268

69+
@BenchmarkMode(Mode.AverageTime)
70+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
6371
@Benchmark
64-
@BenchmarkMode(Mode.Throughput)
65-
@Threads(8)
66-
public void testDotProduct_16(Blackhole bh) {
67-
bh.consume(VectorUtil.dotProduct(A_16, B_16));
72+
@Threads(1)
73+
public void testDotProduct1(Blackhole bh) {
74+
bh.consume(VectorUtil.dotProduct(A, B));
6875
}
6976

70-
71-
72-
7377
public static void main(String[] args) throws Exception {
7478
org.openjdk.jmh.Main.main(args);
7579
}

jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ static float sum(ArrayVectorFloat vector) {
4141
var sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
4242
int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length());
4343

44-
// Process the vectorized part
44+
// Process the vectorized part (full SPECIES_PREFERRED-width chunks up to loopBound)
4545
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
4646
FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i);
4747
sum = sum.add(a);
@@ -207,28 +207,64 @@ static float dotProduct256(ArrayVectorFloat v1, int v1offset, ArrayVectorFloat v
207207
return res;
208208
}
209209

210-
static float dotProductPreferred(ArrayVectorFloat v1, int v1offset, ArrayVectorFloat v2, int v2offset, int length) {
210+
static float dotProductPreferred(ArrayVectorFloat va, int vaoffset, ArrayVectorFloat vb, int vboffset, int length) {
211211
if (length == FloatVector.SPECIES_PREFERRED.length())
212-
return dotPreferred(v1, v1offset, v2, v2offset);
212+
return dotPreferred(va, vaoffset, vb, vboffset);
213213

214-
final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(length);
215-
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
214+
FloatVector sum0 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
215+
FloatVector sum1 = sum0;
216+
FloatVector a0, a1, b0, b1;
216217

217-
int i = 0;
218-
// Process the vectorized part
219-
for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
220-
FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1.get(), v1offset + i);
221-
FloatVector b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2.get(), v2offset + i);
222-
sum = a.fma(b, sum);
218+
int vectorLength = FloatVector.SPECIES_PREFERRED.length();
219+
220+
// Unrolled vector loop; for dot product from L1 cache, an unroll factor of 2 generally suffices.
221+
// If the data is coming from further down the memory hierarchy (but is not being fetched off disk/network),
222+
// we might want to unroll further, e.g. to 8 (4 sets of a,b,sum with 3-ahead reads seems to work best).
223+
if (length >= vectorLength * 2)
224+
{
225+
length -= vectorLength * 2;
226+
a0 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset + vectorLength * 0);
227+
b0 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset + vectorLength * 0);
228+
a1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset + vectorLength * 1);
229+
b1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset + vectorLength * 1);
230+
vaoffset += vectorLength * 2;
231+
vboffset += vectorLength * 2;
232+
while (length >= vectorLength * 2)
233+
{
234+
// Within a single iteration, the two FMAs and the loads for the next iteration do not depend on each other and can be executed in parallel.
235+
length -= vectorLength * 2;
236+
sum0 = a0.fma(b0, sum0);
237+
a0 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset + vectorLength * 0);
238+
b0 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset + vectorLength * 0);
239+
sum1 = a1.fma(b1, sum1);
240+
a1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset + vectorLength * 1);
241+
b1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset + vectorLength * 1);
242+
vaoffset += vectorLength * 2;
243+
vboffset += vectorLength * 2;
244+
}
245+
sum0 = a0.fma(b0, sum0);
246+
sum1 = a1.fma(b1, sum1);
223247
}
248+
sum0 = sum0.add(sum1);
224249

225-
float res = sum.reduceLanes(VectorOperators.ADD);
250+
// Process the remaining few vectors
251+
while (length >= vectorLength) {
252+
length -= vectorLength;
253+
FloatVector a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, va.get(), vaoffset);
254+
FloatVector b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vb.get(), vboffset);
255+
vaoffset += vectorLength;
256+
vboffset += vectorLength;
257+
sum0 = a.fma(b, sum0);
258+
}
259+
260+
float resVec = sum0.reduceLanes(VectorOperators.ADD);
261+
float resTail = 0;
226262

227263
// Process the tail
228-
for (; i < length; ++i)
229-
res += v1.get(v1offset + i) * v2.get(v2offset + i);
264+
for (; length > 0; --length)
265+
resTail += va.get(vaoffset++) * vb.get(vboffset++);
230266

231-
return res;
267+
return resVec + resTail;
232268
}
233269

234270
static float cosineSimilarity(ArrayVectorFloat v1, ArrayVectorFloat v2) {

0 commit comments

Comments
 (0)