@@ -41,7 +41,7 @@ static float sum(ArrayVectorFloat vector) {
4141 var sum = FloatVector .zero (FloatVector .SPECIES_PREFERRED );
4242 int vectorizedLength = FloatVector .SPECIES_PREFERRED .loopBound (vector .length ());
4343
44- // Process the vectorized part
44+ // Process the vectorized part (full SPECIES_PREFERRED-sized chunks below loopBound)
4545 for (int i = 0 ; i < vectorizedLength ; i += FloatVector .SPECIES_PREFERRED .length ()) {
4646 FloatVector a = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , vector .get (), i );
4747 sum = sum .add (a );
@@ -207,28 +207,64 @@ static float dotProduct256(ArrayVectorFloat v1, int v1offset, ArrayVectorFloat v
207207 return res ;
208208 }
209209
210- static float dotProductPreferred (ArrayVectorFloat v1 , int v1offset , ArrayVectorFloat v2 , int v2offset , int length ) {
210+ static float dotProductPreferred (ArrayVectorFloat va , int vaoffset , ArrayVectorFloat vb , int vboffset , int length ) {
211211 if (length == FloatVector .SPECIES_PREFERRED .length ())
212- return dotPreferred (v1 , v1offset , v2 , v2offset );
212+ return dotPreferred (va , vaoffset , vb , vboffset );
213213
214- final int vectorizedLength = FloatVector .SPECIES_PREFERRED .loopBound (length );
215- FloatVector sum = FloatVector .zero (FloatVector .SPECIES_PREFERRED );
214+ FloatVector sum0 = FloatVector .zero (FloatVector .SPECIES_PREFERRED );
215+ FloatVector sum1 = sum0 ;
216+ FloatVector a0 , a1 , b0 , b1 ;
216217
217- int i = 0 ;
218- // Process the vectorized part
219- for (; i < vectorizedLength ; i += FloatVector .SPECIES_PREFERRED .length ()) {
220- FloatVector a = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , v1 .get (), v1offset + i );
221- FloatVector b = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , v2 .get (), v2offset + i );
222- sum = a .fma (b , sum );
218+ int vectorLength = FloatVector .SPECIES_PREFERRED .length ();
219+
220+ // Unrolled vector loop; for dot product from L1 cache, an unroll factor of 2 generally suffices.
221+ // If the data comes from further down the memory hierarchy (but is not fetched off disk/network),
222+ // we might want to unroll further, e.g. to 8 (4 sets of a,b,sum with 3-ahead reads seems to work best).
223+ if (length >= vectorLength * 2 )
224+ {
225+ length -= vectorLength * 2 ;
226+ a0 = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , va .get (), vaoffset + vectorLength * 0 );
227+ b0 = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , vb .get (), vboffset + vectorLength * 0 );
228+ a1 = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , va .get (), vaoffset + vectorLength * 1 );
229+ b1 = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , vb .get (), vboffset + vectorLength * 1 );
230+ vaoffset += vectorLength * 2 ;
231+ vboffset += vectorLength * 2 ;
232+ while (length >= vectorLength * 2 )
233+ {
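234+ // Within one iteration no instruction depends on another's result: each FMA consumes the vectors loaded in the previous iteration (or in the prologue above), so the loads and FMAs can execute in parallel.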
235+ length -= vectorLength * 2 ;
236+ sum0 = a0 .fma (b0 , sum0 );
237+ a0 = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , va .get (), vaoffset + vectorLength * 0 );
238+ b0 = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , vb .get (), vboffset + vectorLength * 0 );
239+ sum1 = a1 .fma (b1 , sum1 );
240+ a1 = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , va .get (), vaoffset + vectorLength * 1 );
241+ b1 = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , vb .get (), vboffset + vectorLength * 1 );
242+ vaoffset += vectorLength * 2 ;
243+ vboffset += vectorLength * 2 ;
244+ }
245+ sum0 = a0 .fma (b0 , sum0 );
246+ sum1 = a1 .fma (b1 , sum1 );
223247 }
248+ sum0 = sum0 .add (sum1 );
224249
225- float res = sum .reduceLanes (VectorOperators .ADD );
250+ // Process any full vectors left over after the unrolled loop (at most one with an unroll factor of 2)
251+ while (length >= vectorLength ) {
252+ length -= vectorLength ;
253+ FloatVector a = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , va .get (), vaoffset );
254+ FloatVector b = FloatVector .fromArray (FloatVector .SPECIES_PREFERRED , vb .get (), vboffset );
255+ vaoffset += vectorLength ;
256+ vboffset += vectorLength ;
257+ sum0 = a .fma (b , sum0 );
258+ }
259+
260+ float resVec = sum0 .reduceLanes (VectorOperators .ADD );
261+ float resTail = 0 ;
226262
227263 // Process the tail
228- for (; i < length ; ++ i )
229- res += v1 .get (v1offset + i ) * v2 .get (v2offset + i );
264+ for (; length > 0 ; -- length )
265+ resTail += va .get (vaoffset ++ ) * vb .get (vboffset ++ );
230266
231- return res ;
267+ return resVec + resTail ;
232268 }
233269
234270 static float cosineSimilarity (ArrayVectorFloat v1 , ArrayVectorFloat v2 ) {
0 commit comments