Skip to content

Commit ab26084

Browse files
committed
Added SME1 support for int8 GEMM and IGEMM operations
1 parent 77eba01 commit ab26084

21 files changed

+6993
-914
lines changed

cmake/DownloadKleidiAI.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ ENDIF()
1818
# LINT.IfChange
1919
INCLUDE(ExternalProject)
2020
ExternalProject_Add(kleidiai
21-
URL https://github.com/ARM-software/kleidiai/archive/8ca226712975f24f13f71d04cda039a0ee9f9e2f.zip
22-
URL_HASH SHA256=42155cfc084bf1f80e9ef486470f949502ea8d1b845b2f1bebd58978a1b540aa
21+
URL https://github.com/ARM-software/kleidiai/archive/bd2e6ae060014035e25bf4986be682762c446c2d.zip
22+
URL_HASH SHA256=6a4a4e16b695fd6add6c361de1ebf3c7226f954ae103bc8d71fe6705a41cfd04
2323
SOURCE_DIR "${CMAKE_BINARY_DIR}/kleidiai-source"
2424
BINARY_DIR "${CMAKE_BINARY_DIR}/kleidiai"
2525
CONFIGURE_COMMAND ""

cmake/gen/neonsme2_microkernels.cmake

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@ SET(PROD_NEONSME2_MICROKERNEL_SRCS
2020
src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x64c4-neonsme2.c
2121
src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-16x64c4-neonsme2.c
2222
src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x64c4-neonsme2.c
23-
src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x64c4-neonsme2.c
24-
src/x8-pack-lh/x8-packlh-igemm-neonsme2.c
25-
src/x8-pack-lh/x8-packlh-neonsme2.c
23+
src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x64c4-neonsme2.c
2624
src/x16-pack-lh/x16-packlh-neonsme2.c)
2725

2826
SET(NON_PROD_NEONSME2_MICROKERNEL_SRCS)

cmake/gen/neonsme_microkernels.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
SET(PROD_NEONSME_MICROKERNEL_SRCS
1313
src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme.c
1414
src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme.c
15+
src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-32x32c4-minmax-neonsme.c
16+
src/pqs8-f32-qc8w-igemm/pqs8-f32-qc8w-igemm-32x32c4-minmax-neonsme.c
17+
src/x8-pack-lh/x8-packlh-neonsme.c
18+
src/x8-pack-lh/x8-packlh-igemm-neonsme.c
1519
src/x32-pack-lh/x32-packlh-neonsme.c)
1620

1721
SET(NON_PROD_NEONSME_MICROKERNEL_SRCS)

gen/neonsme2_microkernels.bzl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ PROD_NEONSME2_MICROKERNEL_SRCS = [
1616
"src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x64c4-neonsme2.c",
1717
"src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-16x64c4-neonsme2.c",
1818
"src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x64c4-neonsme2.c",
19-
"src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x64c4-neonsme2.c",
20-
"src/x8-pack-lh/x8-packlh-igemm-neonsme2.c",
21-
"src/x8-pack-lh/x8-packlh-neonsme2.c",
19+
"src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x64c4-neonsme2.c",
2220
"src/x16-pack-lh/x16-packlh-neonsme2.c",
2321
]
2422

gen/neonsme_microkernels.bzl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ Auto-generated file. Do not edit!
88
PROD_NEONSME_MICROKERNEL_SRCS = [
99
"src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme.c",
1010
"src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme.c",
11+
"src/pqs8-f32-qc8w-igemm/pqs8-f32-qc8w-igemm-32x32c4-minmax-neonsme.c",
12+
"src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-32x32c4-minmax-neonsme.c",
1113
"src/x32-pack-lh/x32-packlh-neonsme.c",
14+
"src/x8-pack-lh/x8-packlh-igemm-neonsme.c",
15+
"src/x8-pack-lh/x8-packlh-neonsme.c",
1216
]
1317

1418
NON_PROD_NEONSME_MICROKERNEL_SRCS = [

src/configs/gemm-config.c

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,11 +412,11 @@ static void init_pqs8_qc8w_gemm_config(void) {
412412
pqs8_qc8w_gemm_config.init.qs8_qc8w =
413413
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params;
414414
pqs8_qc8w_gemm_config.pack_weights_and_biases =
415-
xnn_pack_kai_qs8_qc8w_weights_and_biases_sme2;
415+
xnn_pack_kai_qs8_qc8w_weights_and_biases_sme;
416416
pqs8_qc8w_gemm_config.packed_stride_weights_and_biases =
417-
xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme2;
417+
xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme;
418418
pqs8_qc8w_gemm_config.pack_igemm_goki =
419-
(xnn_pack_conv_goki_w_fn)xnn_pack_kai_qs8_conv_goki_w_sme2;
419+
(xnn_pack_conv_goki_w_fn)xnn_pack_kai_qs8_conv_goki_w_sme;
420420
pqs8_qc8w_gemm_config.pack_igemm_kgo =
421421
(xnn_pack_conv_kgo_w_fn)xnn_pack_qs8_conv_kgo_w;
422422
pqs8_qc8w_gemm_config.pack_deconv_goki =
@@ -426,6 +426,39 @@ static void init_pqs8_qc8w_gemm_config(void) {
426426
pqs8_qc8w_gemm_config.nr = nr;
427427
pqs8_qc8w_gemm_config.log2_kr = 2;
428428
#endif // XNN_ENABLE_ARM_SME2
429+
} else if (XNN_ENABLE_ARM_SME && (hardware_config->arch_flags & xnn_arch_arm_sme)) {
430+
#if XNN_ENABLE_ARM_SME
431+
const size_t mr =
432+
xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme_get_mr();
433+
const size_t nr =
434+
xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme_get_nr();
435+
pqs8_qc8w_gemm_config.arch = xnn_arch_arm_sme;
436+
pqs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(mr)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme);
437+
pqs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(mr)] =
438+
xnn_init_hmp_packed_igemm_ukernel(
439+
(xnn_packed_lhs_igemm_ukernel_fn)
440+
xnn_pqs8_qc8w_igemm_minmax_fp32_ukernel_32x32c4__neonsme);
441+
pqs8_qc8w_gemm_config.init.qs8_qc8w =
442+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params;
443+
pqs8_qc8w_gemm_config.pack_weights_and_biases =
444+
xnn_pack_kai_qs8_qc8w_weights_and_biases_sme;
445+
pqs8_qc8w_gemm_config.packed_stride_weights_and_biases =
446+
xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme;
447+
pqs8_qc8w_gemm_config.pack_igemm_goki =
448+
(xnn_pack_conv_goki_w_fn)xnn_pack_kai_qs8_conv_goki_w_sme;
449+
pqs8_qc8w_gemm_config.pack_igemm_kgo =
450+
(xnn_pack_conv_kgo_w_fn)xnn_pack_qs8_conv_kgo_w;
451+
pqs8_qc8w_gemm_config.pack_deconv_goki =
452+
(xnn_pack_deconv_goki_w_fn)xnn_pack_qs8_deconv_goki_w;
453+
pqs8_qc8w_gemm_config.mr = mr;
454+
pqs8_qc8w_gemm_config.mr_packed = mr;
455+
pqs8_qc8w_gemm_config.nr = nr;
456+
pqs8_qc8w_gemm_config.log2_kr = 2;
457+
#endif // XNN_ENABLE_ARM_SME
458+
459+
}
460+
else {
461+
/* No action */
429462
}
430463
assert(pqs8_qc8w_gemm_config.mr <= XNN_MAX_MR);
431464
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI

src/configs/pack-lh-config.c

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ const struct xnn_pack_lh_config* xnn_init_x16_pack_lh_config() {
103103

104104
static void init_x8_pack_lh_config(void) {
105105
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
106-
#if XNN_ENABLE_ARM_SME2
106+
#if XNN_ENABLE_ARM_SME
107107
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
108108
assert(hardware_config != NULL);
109-
if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
110-
x8_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn) xnn_x8_pack_lh_ukernel__neonsme2;
111-
x8_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x8_pack_lh_size__neonsme2;
112-
x8_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x8_pack_lh_offset__neonsme2;
109+
if ((hardware_config->arch_flags & xnn_arch_arm_sme)) {
110+
x8_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn) xnn_x8_pack_lh_ukernel__neonsme;
111+
x8_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x8_pack_lh_size__neonsme;
112+
x8_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x8_pack_lh_offset__neonsme;
113113
}
114114
#endif // XNN_ENABLE_ARM_SME
115115
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
@@ -129,13 +129,13 @@ const struct xnn_pack_lh_config* xnn_init_x8_pack_lh_config() {
129129

130130
static void init_x8_igemm_pack_lh_config(void) {
131131
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
132-
#if XNN_ENABLE_ARM_SME2
132+
#if XNN_ENABLE_ARM_SME
133133
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
134134
assert(hardware_config != NULL);
135-
if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
136-
x8_igemm_pack_lh_config.pack_lh_for_igemm_fn = (xnn_pack_lh_igemm_ukernel_fn) xnn_x8_pack_lh_ukernel__igemm_neonsme2;
137-
x8_igemm_pack_lh_config.size_for_igemm_fn = (xnn_pack_lh_igemm_size_fn) xnn_x8_pack_lh_size__igemm_neonsme2;
138-
x8_igemm_pack_lh_config.offset_for_igemm_fn = (xnn_pack_lh_igemm_offset_fn) xnn_x8_pack_lh_offset__igemm_neonsme2;
135+
if ((hardware_config->arch_flags & xnn_arch_arm_sme)) {
136+
x8_igemm_pack_lh_config.pack_lh_for_igemm_fn = (xnn_pack_lh_igemm_ukernel_fn) xnn_x8_pack_lh_ukernel__igemm_neonsme;
137+
x8_igemm_pack_lh_config.size_for_igemm_fn = (xnn_pack_lh_igemm_size_fn) xnn_x8_pack_lh_size__igemm_neonsme;
138+
x8_igemm_pack_lh_config.offset_for_igemm_fn = (xnn_pack_lh_igemm_offset_fn) xnn_x8_pack_lh_offset__igemm_neonsme;
139139
}
140140
#endif // XNN_ENABLE_ARM_SME2
141141
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
#include <stddef.h>
7+
8+
#include "src/xnnpack/microparams.h"
9+
10+
#if XNN_ENABLE_KLEIDIAI
11+
#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h"
12+
#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h"
13+
#endif // XNN_ENABLE_KLEIDIAI
14+
15+
size_t xnn_pqs8_qc8w_igemm_minmax_fp32_ukernel_32x32c4__neonsme_get_mr(void) {
16+
#if XNN_ENABLE_KLEIDIAI
17+
return kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa();
18+
#else
19+
assert(
20+
"Calling wrapped KleidiAI function, but XNNPACK was compiled without "
21+
"`XNN_ENABLE_KLEIDIAI`." &&
22+
0);
23+
return 0;
24+
#endif // XNN_ENABLE_KLEIDIAI
25+
}
26+
27+
size_t xnn_pqs8_qc8w_igemm_minmax_fp32_ukernel_32x32c4__neonsme_get_nr(void) {
28+
#if XNN_ENABLE_KLEIDIAI
29+
return kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa();
30+
#else
31+
assert(
32+
"Calling wrapped KleidiAI function, but XNNPACK was compiled without "
33+
"`XNN_ENABLE_KLEIDIAI`." &&
34+
0);
35+
return 0;
36+
#endif // XNN_ENABLE_KLEIDIAI
37+
}
38+
39+
void xnn_pqs8_qc8w_igemm_minmax_fp32_ukernel_32x32c4__neonsme(
40+
size_t mr, size_t nc, size_t kc, size_t ks, const void* packed_lhs,
41+
const void* restrict w, int8_t* restrict c, size_t cm_stride,
42+
const union xnn_qs8_qc8w_conv_minmax_params* params) {
43+
#if XNN_ENABLE_KLEIDIAI
44+
const size_t kai_kr = 4;
45+
const size_t k = ks * round_up(kc, kai_kr);
46+
47+
// Repackage the params.
48+
struct kai_matmul_requantize32_params kai_params;
49+
kai_params.output_zero_point = params->fp32_scalar.output_zero_point;
50+
kai_params.min_value = (int8_t)params->fp32_scalar.output_min;
51+
kai_params.max_value = (int8_t)params->fp32_scalar.output_max;
52+
53+
kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa(
54+
mr, nc, k, packed_lhs, w, c, cm_stride, sizeof(int8_t), &kai_params);
55+
#else
56+
assert(
57+
"Calling wrapped KleidiAI function, but XNNPACK was compiled without "
58+
"`XNN_ENABLE_KLEIDIAI`." &&
59+
0);
60+
#endif // XNN_ENABLE_KLEIDIAI
61+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
#include <stddef.h>
7+
8+
#include "src/xnnpack/math.h"
9+
#include "src/xnnpack/microparams.h"
10+
11+
#if XNN_ENABLE_KLEIDIAI
12+
#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h"
13+
#endif // XNN_ENABLE_KLEIDIAI
14+
15+
16+
size_t xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme_get_mr() {
17+
#if XNN_ENABLE_KLEIDIAI
18+
return kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa();
19+
#else
20+
assert(
21+
"Calling KleidiAI kai_get_mr wrapper, but XNNPACK was compiled without "
22+
"`XNN_ENABLE_KLEIDIAI`." && 0);
23+
return 0;
24+
#endif // XNN_ENABLE_KLEIDIAI
25+
}
26+
27+
size_t xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme_get_nr() {
28+
#if XNN_ENABLE_KLEIDIAI
29+
return kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa();
30+
31+
#else
32+
assert(
33+
"Calling KleidiAI kai_get_nr wrapper, but XNNPACK was compiled without "
34+
"`XNN_ENABLE_KLEIDIAI`." && 0);
35+
return 0;
36+
#endif // XNN_ENABLE_KLEIDIAI
37+
}
38+
39+
// Wraps the `kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme_mopa`
40+
// GEMM microkernel with a name that is compatible with our tooling.
41+
void xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme(
42+
size_t m, size_t n, size_t k, const void* lhs_packed,
43+
const void* rhs_packed, void* dst, size_t dst_stride_row,
44+
size_t dst_stride_col,
45+
const union xnn_qs8_qc8w_conv_minmax_params* minmax_params) {
46+
#if XNN_ENABLE_KLEIDIAI
47+
struct kai_matmul_requantize32_params kai_params;
48+
kai_params.output_zero_point = minmax_params->fp32_scalar.output_zero_point;
49+
kai_params.min_value = minmax_params->fp32_scalar.output_min;
50+
kai_params.max_value = minmax_params->fp32_scalar.output_max;
51+
52+
kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa(
53+
m, n, k / sizeof(int8_t), lhs_packed, rhs_packed, dst, dst_stride_row,
54+
/*dst_stride_col=*/sizeof(int8_t), &kai_params);
55+
#else
56+
assert(
57+
"Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without "
58+
"`XNN_ENABLE_KLEIDIAI`." && 0);
59+
#endif // XNN_ENABLE_KLEIDIAI
60+
}

src/reference/packing.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2156,7 +2156,7 @@ void xnn_pack_kai_qs4_weights_and_biases(
21562156
}
21572157
}
21582158

2159-
size_t xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme2(
2159+
size_t xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme(
21602160
const struct xnn_gemm_config* gemm_config, size_t k,
21612161
size_t unused_block_size, size_t unused_k_stride, size_t extra_bytes) {
21622162
size_t ret_val =
@@ -2175,7 +2175,7 @@ void transpose_weights_x8(const int8_t* in, int8_t* out, size_t height,
21752175
}
21762176
}
21772177

2178-
void xnn_pack_kai_qs8_qc8w_weights_and_biases_sme2(
2178+
void xnn_pack_kai_qs8_qc8w_weights_and_biases_sme(
21792179
uint32_t flags, const struct xnn_gemm_config* gemm_config,
21802180
size_t input_channels, size_t output_channels, size_t groups,
21812181
size_t unused_block_size, size_t k_stride, const void* accumulator_init,
@@ -2560,7 +2560,7 @@ void xnn_pack_kai_qb4_weights_and_biases(
25602560
}
25612561
}
25622562

2563-
void xnn_pack_kai_qs8_conv_goki_w_sme2(
2563+
void xnn_pack_kai_qs8_conv_goki_w_sme(
25642564
size_t g, size_t nc, size_t ks, size_t kc, size_t nr, size_t kr, size_t sr,
25652565
const int8_t* k, const int32_t* b, const float* scale, void* packed_weights,
25662566
size_t extra_bytes, const struct xnn_qs8_packing_params* params) {

0 commit comments

Comments
 (0)