Skip to content

Commit 7f4782f

Browse files
JonathanC-ARMxnnpack-bot
authored andcommitted
Copybara import of the project:
-- c69ccdb by Gian Marco Iodice <[email protected]>: Prototype: Add support for fp16 iGEMM with SME2 - Initial prototype to enable fp16 iGEMM with SME2 in conv2d Signed-off-by: Gian Marco Iodice <[email protected]> -- a3537a1 by Gian Marco Iodice <[email protected]>: Include missing files Signed-off-by: Gian Marco Iodice <[email protected]> -- 232826c by Gian Marco Iodice <[email protected]>: Update FP16 iGEMM based on review comments Signed-off-by: Gian Marco Iodice <[email protected]> -- 03bccaa by Jonathan Clohessy <[email protected]>: Updated FP16 iGemm Review with Fixes Signed-off-by: Jonathan Clohessy <[email protected]> -- 9cd6e88 by Jonathan Clohessy <[email protected]>: Fix rebase issues Signed-off-by: Jonathan Clohessy <[email protected]> -- 7eb618d by Misha Gutman <[email protected]>: Added multiple_of to handle all multiples in reductions simply. No significant performance loss: bench/sum_bf16_fp32_4x32_avx512bf16/real_time [256x1x256x1] 1.720µ ± 0% 1.719µ ± 17% ~ (p=0.485 n=6) bench/sum_fp16_fp32_4x32_avx512fp16/real_time [256x1x256x1] 1.744µ ± 3% 1.753µ ± 14% ~ (p=0.310 n=6) bench/sum_uint8_int32_4x64_avx512bw/real_time [256x1x256x1] 1.218µ ± 1% 1.216µ ± 17% ~ (p=0.818 n=6) bench/sum_int8_int32_4x64_avx512bw/real_time [256x1x256x1] 1.217µ ± 0% 1.216µ ± 15% ~ (p=0.699 n=6) bench/sum_fp32_4x16_avx512f/real_time [256x1x256x1] 2.263µ ± 1% 2.268µ ± 0% ~ (p=0.394 n=6) bench/sum_fp32_4x8_avx2/real_time [256x1x256x1] 4.342µ ± 0% 4.357µ ± 0% ~ (p=0.065 n=6) bench/sum_uint8_int32_4x32_avx2/real_time [256x1x256x1] 2.221µ ± 0% 2.285µ ± 8% ~ (p=0.065 n=6) bench/sum_int8_int32_4x32_avx2/real_time [256x1x256x1] 2.219µ ± 1% 2.279µ ± 2% +2.70% (p=0.002 n=6) bench/sum_fp16_fp32_4x16_f16c/real_time [256x1x256x1] 2.344µ ± 0% 2.345µ ± 7% ~ (p=0.485 n=6) bench/sum_uint8_int32_4x16_sse41/real_time [256x1x256x1] 4.318µ ± 0% 4.328µ ± 0% +0.22% (p=0.015 n=6) bench/sum_int8_int32_4x16_sse41/real_time [256x1x256x1] 4.319µ ± 0% 4.325µ ± 1% ~ (p=0.394 n=6) bench/sum_fp32_4x4_sse2/real_time [256x1x256x1] 8.790µ ± 0% 8.795µ ± 0% ~ (p=0.394 n=6) bench/sum_uint8_int32_4x16_sse2/real_time [256x1x256x1] 3.966µ ± 0% 3.995µ ± 0% +0.73% (p=0.002 n=6) bench/sum_int8_int32_4x16_sse2/real_time [256x1x256x1] 5.382µ ± 1% 5.410µ ± 1% +0.52% (p=0.041 n=6) bench/sum_uint8_int32_4x16_ssse3/real_time [256x1x256x1] 3.977µ ± 0% 3.994µ ± 1% +0.44% (p=0.004 n=6) bench/sum_int8_int32_4x16_ssse3/real_time [256x1x256x1] 5.373µ ± 0% 5.412µ ± 2% +0.72% (p=0.002 n=6) PiperOrigin-RevId: 821549068 -- e5cb8c0 by Misha Gutman <[email protected]>: Changed K1_1 strategy for f32 to go with single accumulator and maximally long multiple, this significantly improved performance. Since contiguous case tiles became different from discontiguous changed the naming to not include tiles information. bench/sum_fp32_4x16_avx512f/real_time [256x1x256x1] 2.259µ ± 1% bench/sum_fp32_4x8_avx2/real_time [256x1x256x1] 4.339µ ± 0% bench/sum_fp32_4x4_sse2/real_time [256x1x256x1] 8.787µ ± 1% bench/sum_fp32/real_time [256x1x256x1] 3.255µ ± 7% bench/sum_fp32_avx512f/real_time [256x1x256x1] 1.441µ ± 17% bench/sum_fp32_avx2/real_time [256x1x256x1] 1.761µ ± 14% bench/sum_fp32_sse2/real_time [256x1x256x1] 3.435µ ± 13% bench/sum_fp32/real_time [256x1x256x1] 3.261µ ± 13% bench/sum_bf16_fp32_4x32_avx512bf16/real_time [256x1x256x1] 1.722µ ± 1% bench/sum_bf16_fp32_avx512bf16/real_time [256x1x256x1] 1.703µ ± 1% bench/sum_fp16_fp32_4x32_avx512fp16/real_time [256x1x256x1] 1.749µ ± 0% bench/sum_fp16_fp32_avx512fp16/real_time [256x1x256x1] 1.744µ ± 0% bench/sum_fp16_fp32_4x16_f16c/real_time [256x1x256x1] 2.341µ ± 1% bench/sum_fp16_fp32_f16c/real_time [256x1x256x1] 1.652µ ± 7% PiperOrigin-RevId: 821556723 -- aeeca5d by Dillon Sharlet <[email protected]>: Remove threadpool library and just build threadpool.cc as part of subgraph PiperOrigin-RevId: 821566586 -- 7304027 by Dillon Sharlet <[email protected]>: Disable SME when msan is enabled PiperOrigin-RevId: 821694771 -- 89a72e3 by Dillon Sharlet <[email protected]>: Don't bother disabling KleidiAI if using YNNPACK This causes builds to fail, and it's harmless to leave it enabled. PiperOrigin-RevId: 821704594 -- 0c5edfc by Dillon Sharlet <[email protected]>: Disable SME on older Apple compilers PiperOrigin-RevId: 821708108 -- 9b29972 by Dillon Sharlet <[email protected]>: Fix usage of `sv{ld,st}1_hor_vnum_za32` According to the ACLE documentation, this increments *both* the slice and the pointer by `vnum` vectors. This usage of it treated it as if it only incremented the pointer to read from/write to by 1 vector (but did not change the slice). This is interesting because this code worked on QEMU, but fails on real (Apple M4) hardware. I think this indicates there is a bug in the implementation of these instructions in QEMU. PiperOrigin-RevId: 821730217 -- 0d3dc09 by Dillon Sharlet <[email protected]>: Fix correctness of dot benchmarks for transpose_a kernels PiperOrigin-RevId: 821808685 -- 4b73eb1 by Pedro Gonnet <[email protected]>: Update `pthreadpool` dependency. PiperOrigin-RevId: 821857188 -- 66d084b by Dillon Sharlet <[email protected]>: Fix flaky quantize tests PiperOrigin-RevId: 821867761 -- 6fc5696 by Quentin Khan <[email protected]>: Add missing `gemm_config` `.element_size` initializations. PiperOrigin-RevId: 821984759 -- 923b7f9 by Jonathan Clohessy <[email protected]>: Fix build issues and guard against sme2 specific path Signed-off-by: Jonathan Clohessy <[email protected]> FUTURE_COPYBARA_INTEGRATE_REVIEW=#9005 from JonathanC-ARM:f16_igemm 56ee7cb PiperOrigin-RevId: 821598958
1 parent 645aa39 commit 7f4782f

25 files changed

+1020
-115
lines changed

build_srcs.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ MICROKERNEL_DEFS = [
284284
"src/x64-transposec/x64-transposec.inc",
285285
"src/x8-pack-lh/x8-pack-lh.inc",
286286
"src/x8-pack-lh/x8-pack-lh-igemm.inc",
287+
"src/x16-pack-lh/x16-pack-lh-igemm.inc",
287288
"src/x8-packq/x8-packq.inc",
288289
"src/x8-packw/x8-packw.inc",
289290
"src/x8-transposec/x8-transposec.inc",

cmake/gen/neonsme2_microkernels.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ SET(PROD_NEONSME2_MICROKERNEL_SRCS
1414
src/pf16-gemm/pf16-gemm-32x32c2-minmax-neonsme2.c
1515
src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c
1616
src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c
17+
src/pf16-f16-f16-igemm/pf16-f16-f16-igemm-32x32c2-minmax-neonsme2.c
1718
src/pqs8-f32-qc8w-igemm/pqs8-f32-qc8w-igemm-32x32c4-minmax-neonsme2.c
1819
src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-1x32c4-minmax-neonsme2.c
1920
src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-32x32c4-minmax-neonsme2.c
2021
src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x64c4-neonsme2.c
2122
src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-16x64c4-neonsme2.c
2223
src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x64c4-neonsme2.c
2324
src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x64c4-neonsme2.c
25+
src/x16-pack-lh/x16-packlh-igemm-neonsme2.c
2426
src/x8-pack-lh/x8-packlh-igemm-neonsme2.c
2527
src/x8-pack-lh/x8-packlh-neonsme2.c
2628
src/x16-pack-lh/x16-packlh-neonsme2.c)

gen/neonsme2_microkernels.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ PROD_NEONSME2_MICROKERNEL_SRCS = [
1010
"src/pf16-gemm/pf16-gemm-32x32c2-minmax-neonsme2.c",
1111
"src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c",
1212
"src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c",
13+
"src/pf16-f16-f16-igemm/pf16-f16-f16-igemm-32x32c2-minmax-neonsme2.c",
1314
"src/pqs8-f32-qc8w-igemm/pqs8-f32-qc8w-igemm-32x32c4-minmax-neonsme2.c",
1415
"src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-1x32c4-minmax-neonsme2.c",
1516
"src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-32x32c4-minmax-neonsme2.c",
1617
"src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x64c4-neonsme2.c",
1718
"src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-16x64c4-neonsme2.c",
1819
"src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x64c4-neonsme2.c",
1920
"src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x64c4-neonsme2.c",
21+
"src/x16-pack-lh/x16-packlh-igemm-neonsme2.c",
2022
"src/x8-pack-lh/x8-packlh-igemm-neonsme2.c",
2123
"src/x8-pack-lh/x8-packlh-neonsme2.c",
2224
"src/x16-pack-lh/x16-packlh-neonsme2.c",

scripts/generate-tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ tools/generate-gemm-test.py --spec test/qs8-qc4w-gemm-minmax-fp32.yaml --output-
4949
tools/generate-gemm-test.py --spec test/qs8-qc8w-gemm-minmax-fp32.yaml --output-test test/qs8-qc8w-gemm-minmax-fp32.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-2.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-3.cc --output-bench bench/qs8-qc8w-gemm-fp32.cc &
5050

5151
### Tests for IGEMM micro-kernels
52+
tools/generate-gemm-test.py --spec test/pf16-f16-igemm-minmax.yaml --output-test test/pf16-f16-igemm-minmax.cc &
5253
tools/generate-gemm-test.py --spec test/f16-igemm-minmax.yaml --output-test test/f16-igemm-minmax.cc &
5354
tools/generate-gemm-test.py --spec test/f16-f32acc-igemm-minmax.yaml --output-test test/f16-f32acc-igemm-minmax.cc &
5455

src/configs/gemm-config.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,17 @@ static void init_pf16_gemm_config(void) {
333333
pf16_gemm_config.arch = xnn_arch_arm_sme2;
334334
pf16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_pf16_gemm_minmax_ukernel_1x32c2__neonsme2);
335335
pf16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(mr)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_pf16_gemm_minmax_ukernel_32x32c2__neonsme2);
336+
pf16_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(mr)] =
337+
xnn_init_hmp_packed_igemm_ukernel(
338+
(xnn_packed_lhs_igemm_ukernel_fn)
339+
xnn_pf16_f16_igemm_minmax_fp16_ukernel_32x32c2__neonsme2);
336340
pf16_gemm_config.init.f16 = xnn_init_f16_minmax_scalar_params;
337341
pf16_gemm_config.pack_weights_and_biases = xnn_pack_kai_f16_weights_and_biases;
338342
pf16_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_f16_weights_and_biases;
343+
pf16_gemm_config.pack_igemm_goki =
344+
(xnn_pack_conv_goki_w_fn)xnn_pack_kai_f16_conv_goki_w_sme2;
345+
pf16_gemm_config.pack_igemm_kgo =
346+
(xnn_pack_conv_kgo_w_fn)xnn_pack_f16_conv_kgo_w;
339347
pf16_gemm_config.mr = mr;
340348
pf16_gemm_config.mr_packed = mr;
341349
pf16_gemm_config.nr = nr;
@@ -5586,6 +5594,7 @@ const struct xnn_gemm_config* xnn_init_pf16_gemm_config() {
55865594
return NULL;
55875595
}
55885596
XNN_INIT_ONCE(pf16_gemm);
5597+
55895598
return pf16_gemm_config.mr ? &pf16_gemm_config : NULL;
55905599
}
55915600

src/configs/pack-lh-config.c

Lines changed: 89 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,27 @@ static struct xnn_pack_lh_config x8_pack_lh_config = {0};
2020
static struct xnn_pack_lh_config x16_pack_lh_config = {0};
2121
static struct xnn_pack_lh_config x32_pack_lh_config = {0};
2222
static struct xnn_pack_lh_config x8_igemm_pack_lh_config = {0};
23+
static struct xnn_pack_lh_config x16_igemm_pack_lh_config = {0};
2324

2425
XNN_INIT_ONCE_GUARD(qp8_pack_lh);
2526
XNN_INIT_ONCE_GUARD(x8_pack_lh);
2627
XNN_INIT_ONCE_GUARD(x16_pack_lh);
2728
XNN_INIT_ONCE_GUARD(x32_pack_lh);
2829
XNN_INIT_ONCE_GUARD(x8_igemm_pack_lh);
30+
XNN_INIT_ONCE_GUARD(x16_igemm_pack_lh);
2931

3032
static void init_qp8_pack_lh_config(void) {
3133
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
32-
qp8_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2;
34+
qp8_pack_lh_config.pack_lh_fn =
35+
(xnn_pack_lh_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2;
3336
#else
34-
qp8_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__scalar_u1;
37+
qp8_pack_lh_config.pack_lh_fn =
38+
(xnn_pack_lh_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__scalar_u1;
3539
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
36-
qp8_pack_lh_config.size_fn = (xnn_pack_lh_size_fn)xnn_x8_packq_f32qp8_packed_size;
37-
qp8_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn)xnn_x8_packq_f32qp8_packed_offset;
40+
qp8_pack_lh_config.size_fn =
41+
(xnn_pack_lh_size_fn)xnn_x8_packq_f32qp8_packed_size;
42+
qp8_pack_lh_config.offset_fn =
43+
(xnn_pack_lh_offset_fn)xnn_x8_packq_f32qp8_packed_offset;
3844
qp8_pack_lh_config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT;
3945
qp8_pack_lh_config.log2_packed_element_size = 0;
4046
}
@@ -51,13 +57,17 @@ const struct xnn_pack_lh_config* xnn_init_qp8_pack_lh_config() {
5157

5258
static void init_x32_pack_lh_config(void) {
5359
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
54-
#if XNN_ENABLE_ARM_SME2 || XNN_ENABLE_ARM_SME
55-
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
60+
#if XNN_ENABLE_ARM_SME2
61+
const struct xnn_hardware_config* hardware_config =
62+
xnn_init_hardware_config();
5663
assert(hardware_config != NULL);
57-
if (hardware_config->arch_flags & xnn_arch_arm_sme) {
58-
x32_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn) xnn_x32_pack_lh_ukernel__neonsme;
59-
x32_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x32_pack_lh_size__neonsme;
60-
x32_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x32_pack_lh_offset__neonsme;
64+
if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
65+
x32_pack_lh_config.pack_lh_fn =
66+
(xnn_pack_lh_ukernel_fn)xnn_x32_pack_lh_ukernel__neonsme;
67+
x32_pack_lh_config.size_fn =
68+
(xnn_pack_lh_size_fn)xnn_x32_pack_lh_size__neonsme;
69+
x32_pack_lh_config.offset_fn =
70+
(xnn_pack_lh_offset_fn)xnn_x32_pack_lh_offset__neonsme;
6171
}
6272
#endif // XNN_ENABLE_ARM_SME2 || XNN_ENABLE_ARM_SME
6373
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
@@ -67,7 +77,8 @@ static void init_x32_pack_lh_config(void) {
6777
}
6878

6979
const struct xnn_pack_lh_config* xnn_init_x32_pack_lh_config() {
70-
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
80+
const struct xnn_hardware_config* hardware_config =
81+
xnn_init_hardware_config();
7182
if (hardware_config == NULL) {
7283
return NULL;
7384
}
@@ -78,12 +89,16 @@ const struct xnn_pack_lh_config* xnn_init_x32_pack_lh_config() {
7889
static void init_x16_pack_lh_config(void) {
7990
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
8091
#if XNN_ENABLE_ARM_SME2
81-
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
92+
const struct xnn_hardware_config* hardware_config =
93+
xnn_init_hardware_config();
8294
assert(hardware_config != NULL);
83-
if (hardware_config->arch_flags & xnn_arch_arm_sme2) {
84-
x16_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn) xnn_x16_pack_lh_ukernel__neonsme2;
85-
x16_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x16_pack_lh_size__neonsme2;
86-
x16_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x16_pack_lh_offset__neonsme2;
95+
if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
96+
x16_pack_lh_config.pack_lh_fn =
97+
(xnn_pack_lh_ukernel_fn)xnn_x16_pack_lh_ukernel__neonsme2;
98+
x16_pack_lh_config.size_fn =
99+
(xnn_pack_lh_size_fn)xnn_x16_pack_lh_size__neonsme2;
100+
x16_pack_lh_config.offset_fn =
101+
(xnn_pack_lh_offset_fn)xnn_x16_pack_lh_offset__neonsme2;
87102
}
88103
#endif // XNN_ENABLE_ARM_SME2
89104
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
@@ -93,7 +108,8 @@ static void init_x16_pack_lh_config(void) {
93108
}
94109

95110
const struct xnn_pack_lh_config* xnn_init_x16_pack_lh_config() {
96-
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
111+
const struct xnn_hardware_config* hardware_config =
112+
xnn_init_hardware_config();
97113
if (hardware_config == NULL) {
98114
return NULL;
99115
}
@@ -104,12 +120,16 @@ const struct xnn_pack_lh_config* xnn_init_x16_pack_lh_config() {
104120
static void init_x8_pack_lh_config(void) {
105121
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
106122
#if XNN_ENABLE_ARM_SME2
107-
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
123+
const struct xnn_hardware_config* hardware_config =
124+
xnn_init_hardware_config();
108125
assert(hardware_config != NULL);
109-
if (hardware_config->arch_flags & xnn_arch_arm_sme2) {
110-
x8_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn) xnn_x8_pack_lh_ukernel__neonsme2;
111-
x8_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x8_pack_lh_size__neonsme2;
112-
x8_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x8_pack_lh_offset__neonsme2;
126+
if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
127+
x8_pack_lh_config.pack_lh_fn =
128+
(xnn_pack_lh_ukernel_fn)xnn_x8_pack_lh_ukernel__neonsme2;
129+
x8_pack_lh_config.size_fn =
130+
(xnn_pack_lh_size_fn)xnn_x8_pack_lh_size__neonsme2;
131+
x8_pack_lh_config.offset_fn =
132+
(xnn_pack_lh_offset_fn)xnn_x8_pack_lh_offset__neonsme2;
113133
}
114134
#endif // XNN_ENABLE_ARM_SME2
115135
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
@@ -119,7 +139,8 @@ static void init_x8_pack_lh_config(void) {
119139
}
120140

121141
const struct xnn_pack_lh_config* xnn_init_x8_pack_lh_config() {
122-
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
142+
const struct xnn_hardware_config* hardware_config =
143+
xnn_init_hardware_config();
123144
if (hardware_config == NULL) {
124145
return NULL;
125146
}
@@ -128,17 +149,21 @@ const struct xnn_pack_lh_config* xnn_init_x8_pack_lh_config() {
128149
}
129150

130151
static void init_x8_igemm_pack_lh_config(void) {
131-
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
132-
#if XNN_ENABLE_ARM_SME2
133-
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
134-
assert(hardware_config != NULL);
135-
if (hardware_config->arch_flags & xnn_arch_arm_sme2) {
136-
x8_igemm_pack_lh_config.pack_lh_for_igemm_fn = (xnn_pack_lh_igemm_ukernel_fn) xnn_x8_pack_lh_ukernel__igemm_neonsme2;
137-
x8_igemm_pack_lh_config.size_for_igemm_fn = (xnn_pack_lh_igemm_size_fn) xnn_x8_pack_lh_size__igemm_neonsme2;
138-
x8_igemm_pack_lh_config.offset_for_igemm_fn = (xnn_pack_lh_igemm_offset_fn) xnn_x8_pack_lh_offset__igemm_neonsme2;
139-
}
140-
#endif // XNN_ENABLE_ARM_SME2
141-
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
152+
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
153+
#if XNN_ENABLE_ARM_SME2
154+
const struct xnn_hardware_config* hardware_config =
155+
xnn_init_hardware_config();
156+
assert(hardware_config != NULL);
157+
if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
158+
x8_igemm_pack_lh_config.pack_lh_for_igemm_fn =
159+
(xnn_pack_lh_igemm_ukernel_fn)xnn_x8_pack_lh_ukernel__igemm_neonsme2;
160+
x8_igemm_pack_lh_config.size_for_igemm_fn =
161+
(xnn_pack_lh_igemm_size_fn)xnn_x8_pack_lh_size__igemm_neonsme2;
162+
x8_igemm_pack_lh_config.offset_for_igemm_fn =
163+
(xnn_pack_lh_igemm_offset_fn)xnn_x8_pack_lh_offset__igemm_neonsme2;
164+
}
165+
#endif // XNN_ENABLE_ARM_SME2
166+
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
142167
x8_igemm_pack_lh_config.log2_input_element_size = 0;
143168
x8_igemm_pack_lh_config.log2_packed_element_size = 0;
144169
}
@@ -152,3 +177,33 @@ const struct xnn_pack_lh_config* xnn_init_x8_igemm_pack_lh_config() {
152177
XNN_INIT_ONCE(x8_igemm_pack_lh);
153178
return &x8_igemm_pack_lh_config;
154179
}
180+
181+
static void init_x16_igemm_pack_lh_config(void) {
182+
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
183+
#if XNN_ENABLE_ARM_SME2
184+
const struct xnn_hardware_config* hardware_config =
185+
xnn_init_hardware_config();
186+
assert(hardware_config != NULL);
187+
if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
188+
x16_igemm_pack_lh_config.pack_lh_for_igemm_fn =
189+
(xnn_pack_lh_igemm_ukernel_fn)xnn_x16_pack_lh_ukernel__igemm_neonsme2;
190+
x16_igemm_pack_lh_config.size_for_igemm_fn =
191+
(xnn_pack_lh_igemm_size_fn)xnn_x16_pack_lh_size__igemm_neonsme2;
192+
x16_igemm_pack_lh_config.offset_for_igemm_fn =
193+
(xnn_pack_lh_igemm_offset_fn)xnn_x16_pack_lh_offset__igemm_neonsme2;
194+
}
195+
#endif // XNN_ENABLE_ARM_SME2
196+
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
197+
x16_igemm_pack_lh_config.log2_input_element_size = 1;
198+
x16_igemm_pack_lh_config.log2_packed_element_size = 1;
199+
}
200+
201+
const struct xnn_pack_lh_config* xnn_init_x16_igemm_pack_lh_config() {
202+
const struct xnn_hardware_config* hardware_config =
203+
xnn_init_hardware_config();
204+
if (hardware_config == NULL) {
205+
return NULL;
206+
}
207+
XNN_INIT_ONCE(x16_igemm_pack_lh);
208+
return &x16_igemm_pack_lh_config;
209+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
#include <stddef.h>
7+
#include <stdio.h>
8+
9+
#include "src/xnnpack/microparams.h"
10+
11+
#if XNN_ENABLE_KLEIDIAI
12+
#include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
13+
#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h"
14+
#endif // XNN_ENABLE_KLEIDIAI
15+
16+
size_t xnn_pf16_f16_igemm_minmax_fp16_ukernel_32x32c2__neonsme2_get_mr(void) {
17+
#if XNN_ENABLE_KLEIDIAI
18+
return kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa();
19+
#else
20+
assert(
21+
"Calling wrapped KleidiAI function, but XNNPACK was compiled without "
22+
"`XNN_ENABLE_KLEIDIAI`." &&
23+
0);
24+
return 0;
25+
#endif // XNN_ENABLE_KLEIDIAI
26+
}
27+
28+
size_t xnn_pf16_f16_igemm_minmax_fp16_ukernel_32x32c2__neonsme2_get_nr(void) {
29+
#if XNN_ENABLE_KLEIDIAI
30+
return kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa();
31+
#else
32+
assert(
33+
"Calling wrapped KleidiAI function, but XNNPACK was compiled without "
34+
"`XNN_ENABLE_KLEIDIAI`." &&
35+
0);
36+
return 0;
37+
#endif // XNN_ENABLE_KLEIDIAI
38+
}
39+
40+
void xnn_pf16_f16_igemm_minmax_fp16_ukernel_32x32c2__neonsme2(
41+
size_t mr, size_t nc, size_t kc, size_t ks, const void* packed_lhs,
42+
const void* restrict w, xnn_float16* restrict c, size_t cm_stride,
43+
const struct xnn_f16_minmax_params* params) {
44+
#if XNN_ENABLE_KLEIDIAI
45+
const size_t kai_kr = 2;
46+
const size_t k = ks * round_up(kc, kai_kr);
47+
48+
kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(
49+
mr, nc, k, packed_lhs, w, c, cm_stride * sizeof(xnn_float16),
50+
sizeof(xnn_float16), xnn_float16_to_float(params->scalar.min),
51+
xnn_float16_to_float(params->scalar.max));
52+
#else
53+
assert(
54+
"Calling wrapped KleidiAI function, but XNNPACK was compiled without "
55+
"`XNN_ENABLE_KLEIDIAI`." &&
56+
0);
57+
#endif // XNN_ENABLE_KLEIDIAI
58+
}

0 commit comments

Comments
 (0)