Commit 214798d

JonathanC-ARM authored and xnnpack-bot committed
Copybara import of the project:
-- c69ccdb by Gian Marco Iodice <[email protected]>:
Prototype: Add support for fp16 iGEMM with SME2
- Initial prototype to enable fp16 iGEMM with SME2 in conv2d
Signed-off-by: Gian Marco Iodice <[email protected]>

-- a3537a1 by Gian Marco Iodice <[email protected]>:
Include missing files
Signed-off-by: Gian Marco Iodice <[email protected]>

-- 232826c by Gian Marco Iodice <[email protected]>:
Update FP16 iGEMM based on review comments
Signed-off-by: Gian Marco Iodice <[email protected]>

-- 03bccaa by Jonathan Clohessy <[email protected]>:
Updated FP16 iGemm Review with Fixes
Signed-off-by: Jonathan Clohessy <[email protected]>

-- 9cd6e88 by Jonathan Clohessy <[email protected]>:
Fix rebase issues
Signed-off-by: Jonathan Clohessy <[email protected]>

-- 7eb618d by Misha Gutman <[email protected]>:
Added multiple_of to handle all multiples in reductions simply.
No significant performance loss:
bench/sum_bf16_fp32_4x32_avx512bf16/real_time [256x1x256x1] 1.720µ ± 0% 1.719µ ± 17% ~ (p=0.485 n=6)
bench/sum_fp16_fp32_4x32_avx512fp16/real_time [256x1x256x1] 1.744µ ± 3% 1.753µ ± 14% ~ (p=0.310 n=6)
bench/sum_uint8_int32_4x64_avx512bw/real_time [256x1x256x1] 1.218µ ± 1% 1.216µ ± 17% ~ (p=0.818 n=6)
bench/sum_int8_int32_4x64_avx512bw/real_time [256x1x256x1] 1.217µ ± 0% 1.216µ ± 15% ~ (p=0.699 n=6)
bench/sum_fp32_4x16_avx512f/real_time [256x1x256x1] 2.263µ ± 1% 2.268µ ± 0% ~ (p=0.394 n=6)
bench/sum_fp32_4x8_avx2/real_time [256x1x256x1] 4.342µ ± 0% 4.357µ ± 0% ~ (p=0.065 n=6)
bench/sum_uint8_int32_4x32_avx2/real_time [256x1x256x1] 2.221µ ± 0% 2.285µ ± 8% ~ (p=0.065 n=6)
bench/sum_int8_int32_4x32_avx2/real_time [256x1x256x1] 2.219µ ± 1% 2.279µ ± 2% +2.70% (p=0.002 n=6)
bench/sum_fp16_fp32_4x16_f16c/real_time [256x1x256x1] 2.344µ ± 0% 2.345µ ± 7% ~ (p=0.485 n=6)
bench/sum_uint8_int32_4x16_sse41/real_time [256x1x256x1] 4.318µ ± 0% 4.328µ ± 0% +0.22% (p=0.015 n=6)
bench/sum_int8_int32_4x16_sse41/real_time [256x1x256x1] 4.319µ ± 0% 4.325µ ± 1% ~ (p=0.394 n=6)
bench/sum_fp32_4x4_sse2/real_time [256x1x256x1] 8.790µ ± 0% 8.795µ ± 0% ~ (p=0.394 n=6)
bench/sum_uint8_int32_4x16_sse2/real_time [256x1x256x1] 3.966µ ± 0% 3.995µ ± 0% +0.73% (p=0.002 n=6)
bench/sum_int8_int32_4x16_sse2/real_time [256x1x256x1] 5.382µ ± 1% 5.410µ ± 1% +0.52% (p=0.041 n=6)
bench/sum_uint8_int32_4x16_ssse3/real_time [256x1x256x1] 3.977µ ± 0% 3.994µ ± 1% +0.44% (p=0.004 n=6)
bench/sum_int8_int32_4x16_ssse3/real_time [256x1x256x1] 5.373µ ± 0% 5.412µ ± 2% +0.72% (p=0.002 n=6)
PiperOrigin-RevId: 821549068

-- e5cb8c0 by Misha Gutman <[email protected]>:
Changed K1_1 strategy for f32 to go with single accumulator and maximally long multiple, this significantly improved performance.
Since contiguous case tiles became different from discontiguous changed the naming to not include tiles information.
bench/sum_fp32_4x16_avx512f/real_time [256x1x256x1] 2.259µ ± 1%
bench/sum_fp32_4x8_avx2/real_time [256x1x256x1] 4.339µ ± 0%
bench/sum_fp32_4x4_sse2/real_time [256x1x256x1] 8.787µ ± 1%
bench/sum_fp32/real_time [256x1x256x1] 3.255µ ± 7%
bench/sum_fp32_avx512f/real_time [256x1x256x1] 1.441µ ± 17%
bench/sum_fp32_avx2/real_time [256x1x256x1] 1.761µ ± 14%
bench/sum_fp32_sse2/real_time [256x1x256x1] 3.435µ ± 13%
bench/sum_fp32/real_time [256x1x256x1] 3.261µ ± 13%
bench/sum_bf16_fp32_4x32_avx512bf16/real_time [256x1x256x1] 1.722µ ± 1%
bench/sum_bf16_fp32_avx512bf16/real_time [256x1x256x1] 1.703µ ± 1%
bench/sum_fp16_fp32_4x32_avx512fp16/real_time [256x1x256x1] 1.749µ ± 0%
bench/sum_fp16_fp32_avx512fp16/real_time [256x1x256x1] 1.744µ ± 0%
bench/sum_fp16_fp32_4x16_f16c/real_time [256x1x256x1] 2.341µ ± 1%
bench/sum_fp16_fp32_f16c/real_time [256x1x256x1] 1.652µ ± 7%
PiperOrigin-RevId: 821556723

-- aeeca5d by Dillon Sharlet <[email protected]>:
Remove threadpool library and just build threadpool.cc as part of subgraph
PiperOrigin-RevId: 821566586

-- 7304027 by Dillon Sharlet <[email protected]>:
Disable SME when msan is enabled
PiperOrigin-RevId: 821694771

-- 89a72e3 by Dillon Sharlet <[email protected]>:
Don't bother disabling KleidiAI if using YNNPACK
This causes builds to fail, and it's harmless to leave it enabled.
PiperOrigin-RevId: 821704594

-- 0c5edfc by Dillon Sharlet <[email protected]>:
Disable SME on older Apple compilers
PiperOrigin-RevId: 821708108

-- 9b29972 by Dillon Sharlet <[email protected]>:
Fix usage of `sv{ld,st}1_hor_vnum_za32`
According to the ACLE documentation, this increments *both* the slice and the pointer by `vnum` vectors. This usage of it treated it as if it only incremented the pointer to read from/write to by 1 vector (but did not change the slice).
This is interesting because this code worked on QEMU, but fails on real (Apple M4) hardware. I think this indicates there is a bug in the implementation of these instructions in QEMU.
PiperOrigin-RevId: 821730217

-- 0d3dc09 by Dillon Sharlet <[email protected]>:
Fix correctness of dot benchmarks for transpose_a kernels
PiperOrigin-RevId: 821808685

-- 4b73eb1 by Pedro Gonnet <[email protected]>:
Update `pthreadpool` dependency.
PiperOrigin-RevId: 821857188

-- 66d084b by Dillon Sharlet <[email protected]>:
Fix flaky quantize tests
PiperOrigin-RevId: 821867761

-- 6fc5696 by Quentin Khan <[email protected]>:
Add missing `gemm_config` `.element_size` initializations.
PiperOrigin-RevId: 821984759

-- 923b7f9 by Jonathan Clohessy <[email protected]>:
Fix build issues and guard against sme2 specific path
Signed-off-by: Jonathan Clohessy <[email protected]>

-- 06a44d2 by Jonathan Clohessy <[email protected]>:
Refactor Convolution to new structure and fix build failures
Signed-off-by: Jonathan Clohessy <[email protected]>

-- 175903d by Jonathan Clohessy <[email protected]>:
Remove unused gemm config structure init
Signed-off-by: Jonathan Clohessy <[email protected]>

FUTURE_COPYBARA_INTEGRATE_REVIEW=#9005 from JonathanC-ARM:f16_igemm 175903d
PiperOrigin-RevId: 821598958
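
The `sv{ld,st}1_hor_vnum_za32` fix in 9b29972 turns on one detail: `vnum` offsets the ZA slice as well as the memory pointer. The block below is a toy model in plain C, not ACLE code. The tile dimensions, the type `toy_za32_tile` and the function `toy_ld1_hor_vnum` are all hypothetical, and the wrap-around of the slice index is an assumption of the model. It only illustrates the "both advance together" behaviour that the commit message describes.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Toy dimensions for illustration only; real SME vector lengths differ.
enum { kWordsPerVector = 4, kSlices = 4 };

// Hypothetical stand-in for one 32-bit ZA tile: kSlices horizontal slices.
typedef struct {
  uint32_t slice[kSlices][kWordsPerVector];
} toy_za32_tile;

// Toy model of a horizontal load with a vnum offset. Following the commit
// message, vnum moves BOTH the destination slice and the source pointer:
//   destination slice = slice + vnum (wrapped to the tile, model assumption)
//   source address    = ptr + vnum whole vectors
static void toy_ld1_hor_vnum(toy_za32_tile* za, uint32_t slice,
                             const uint32_t* ptr, int64_t vnum) {
  assert(vnum >= 0);  // the model only covers non-negative offsets
  const size_t dst = (size_t)((slice + (uint64_t)vnum) % kSlices);
  memcpy(za->slice[dst], ptr + vnum * kWordsPerVector,
         sizeof(za->slice[dst]));
}
```

In this model, loading with `slice = 0, vnum = 2` fills slice 2 from the third vector in memory and leaves slice 0 untouched. Code that expects only the pointer to move would then read slice 0 and find stale data.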
1 parent 53a85b0 commit 214798d

27 files changed, +1088 -115 lines changed

build_srcs.bzl

Lines changed: 1 addition & 0 deletions
@@ -284,6 +284,7 @@ MICROKERNEL_DEFS = [
  "src/x64-transposec/x64-transposec.inc",
  "src/x8-pack-lh/x8-pack-lh.inc",
  "src/x8-pack-lh/x8-pack-lh-igemm.inc",
+ "src/x16-pack-lh/x16-pack-lh-igemm.inc",
  "src/x8-packq/x8-packq.inc",
  "src/x8-packw/x8-packw.inc",
  "src/x8-transposec/x8-transposec.inc",

cmake/gen/neonsme2_microkernels.cmake

Lines changed: 2 additions & 0 deletions
@@ -14,13 +14,15 @@ SET(PROD_NEONSME2_MICROKERNEL_SRCS
  src/pf16-gemm/pf16-gemm-32x32c2-minmax-neonsme2.c
  src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c
  src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c
+ src/pf16-f16-f16-igemm/pf16-f16-f16-igemm-32x32c2-minmax-neonsme2.c
  src/pqs8-f32-qc8w-igemm/pqs8-f32-qc8w-igemm-32x32c4-minmax-neonsme2.c
  src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-1x32c4-minmax-neonsme2.c
  src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-32x32c4-minmax-neonsme2.c
  src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x64c4-neonsme2.c
  src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-16x64c4-neonsme2.c
  src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x64c4-neonsme2.c
  src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x64c4-neonsme2.c
+ src/x16-pack-lh/x16-packlh-igemm-neonsme2.c
  src/x8-pack-lh/x8-packlh-igemm-neonsme2.c
  src/x8-pack-lh/x8-packlh-neonsme2.c
  src/x16-pack-lh/x16-packlh-neonsme2.c)

gen/neonsme2_microkernels.bzl

Lines changed: 2 additions & 0 deletions
@@ -10,13 +10,15 @@ PROD_NEONSME2_MICROKERNEL_SRCS = [
  "src/pf16-gemm/pf16-gemm-32x32c2-minmax-neonsme2.c",
  "src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c",
  "src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c",
+ "src/pf16-f16-f16-igemm/pf16-f16-f16-igemm-32x32c2-minmax-neonsme2.c",
  "src/pqs8-f32-qc8w-igemm/pqs8-f32-qc8w-igemm-32x32c4-minmax-neonsme2.c",
  "src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-1x32c4-minmax-neonsme2.c",
  "src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-32x32c4-minmax-neonsme2.c",
  "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x64c4-neonsme2.c",
  "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-16x64c4-neonsme2.c",
  "src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x64c4-neonsme2.c",
  "src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x64c4-neonsme2.c",
+ "src/x16-pack-lh/x16-packlh-igemm-neonsme2.c",
  "src/x8-pack-lh/x8-packlh-igemm-neonsme2.c",
  "src/x8-pack-lh/x8-packlh-neonsme2.c",
  "src/x16-pack-lh/x16-packlh-neonsme2.c",

include/xnnpack.h

Lines changed: 40 additions & 0 deletions
@@ -3049,6 +3049,46 @@ enum xnn_status xnn_create_convolution2d_nhwc_f16(
  xnn_weights_cache_t weights_cache,
  xnn_operator_t* convolution_op_out);

+ enum xnn_status xnn_create_convolution2d_nhwc_pf16(
+ uint32_t input_padding_top,
+ uint32_t input_padding_right,
+ uint32_t input_padding_bottom,
+ uint32_t input_padding_left,
+ uint32_t kernel_height,
+ uint32_t kernel_width,
+ uint32_t subsampling_height,
+ uint32_t subsampling_width,
+ uint32_t dilation_height,
+ uint32_t dilation_width,
+ uint32_t groups,
+ size_t group_input_channels,
+ size_t group_output_channels,
+ size_t input_channel_stride,
+ size_t output_channel_stride,
+ const void* kernel,
+ const void* bias,
+ float output_min,
+ float output_max,
+ uint32_t flags,
+ xnn_weights_cache_t weights_cache,
+ xnn_operator_t* convolution_op_out);
+
+ enum xnn_status xnn_reshape_convolution2d_nhwc_pf16(
+ xnn_operator_t convolution_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ size_t* workspace_size,
+ size_t* output_height_out,
+ size_t* output_width_out,
+ pthreadpool_t threadpool);
+
+ enum xnn_status xnn_setup_convolution2d_nhwc_pf16(
+ xnn_operator_t convolution_op,
+ void* workspace,
+ const void* input,
+ void* output);
+
  enum xnn_status xnn_reshape_convolution2d_nhwc_f16(
  xnn_operator_t convolution_op,
  size_t batch_size,
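
The new `xnn_reshape_convolution2d_nhwc_pf16` entry point reports `output_height_out` and `output_width_out` for a given input size. The sketch below is not XNNPACK code. It shows the standard convolution output-size arithmetic that such values conventionally follow, using the padding, kernel, subsampling and dilation parameters from the create call. The helper name `conv_output_dim` is hypothetical, and it ignores anything `flags` might change (for example a same-padding mode).

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical helper (not an XNNPACK function): output extent along one
// spatial axis of a padded, strided (subsampled), dilated convolution.
static size_t conv_output_dim(size_t input, uint32_t pad_before,
                              uint32_t pad_after, uint32_t kernel,
                              uint32_t subsampling, uint32_t dilation) {
  assert(kernel >= 1 && subsampling >= 1 && dilation >= 1);
  const size_t padded = input + pad_before + pad_after;
  // A dilated kernel spans (kernel - 1) * dilation + 1 input positions.
  const size_t effective_kernel = (size_t)(kernel - 1) * dilation + 1;
  if (padded < effective_kernel) {
    return 0;  // the kernel does not fit even once
  }
  return (padded - effective_kernel) / subsampling + 1;
}
```

For example, a 224-pixel axis with padding 1 on each side, a 3-wide kernel and subsampling 2 gives `(226 - 3) / 2 + 1 = 112` output positions.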

scripts/generate-tests.sh

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ tools/generate-gemm-test.py --spec test/qs8-qc4w-gemm-minmax-fp32.yaml --output-
  tools/generate-gemm-test.py --spec test/qs8-qc8w-gemm-minmax-fp32.yaml --output-test test/qs8-qc8w-gemm-minmax-fp32.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-2.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-3.cc --output-bench bench/qs8-qc8w-gemm-fp32.cc &

  ### Tests for IGEMM micro-kernels
+ tools/generate-gemm-test.py --spec test/pf16-f16-igemm-minmax.yaml --output-test test/pf16-f16-igemm-minmax.cc &
  tools/generate-gemm-test.py --spec test/f16-igemm-minmax.yaml --output-test test/f16-igemm-minmax.cc &
  tools/generate-gemm-test.py --spec test/f16-f32acc-igemm-minmax.yaml --output-test test/f16-f32acc-igemm-minmax.cc &

src/configs/gemm-config.c

Lines changed: 9 additions & 0 deletions
@@ -333,9 +333,17 @@ static void init_pf16_gemm_config(void) {
  pf16_gemm_config.arch = xnn_arch_arm_sme2;
  pf16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_pf16_gemm_minmax_ukernel_1x32c2__neonsme2);
  pf16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(mr)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_pf16_gemm_minmax_ukernel_32x32c2__neonsme2);
+ pf16_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(mr)] =
+ xnn_init_hmp_packed_igemm_ukernel(
+ (xnn_packed_lhs_igemm_ukernel_fn)
+ xnn_pf16_f16_igemm_minmax_fp16_ukernel_32x32c2__neonsme2);
  pf16_gemm_config.init.f16 = xnn_init_f16_minmax_scalar_params;
  pf16_gemm_config.pack_weights_and_biases = xnn_pack_kai_f16_weights_and_biases;
  pf16_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_f16_weights_and_biases;
+ pf16_gemm_config.pack_igemm_goki =
+ (xnn_pack_conv_goki_w_fn)xnn_pack_kai_f16_conv_goki_w_sme2;
+ pf16_gemm_config.pack_igemm_kgo =
+ (xnn_pack_conv_kgo_w_fn)xnn_pack_f16_conv_kgo_w;
  pf16_gemm_config.mr = mr;
  pf16_gemm_config.mr_packed = mr;
  pf16_gemm_config.nr = nr;
@@ -5588,6 +5596,7 @@ const struct xnn_gemm_config* xnn_init_pf16_gemm_config() {
  return NULL;
  }
  XNN_INIT_ONCE(pf16_gemm);
+
  return pf16_gemm_config.mr ? &pf16_gemm_config : NULL;
  }

src/configs/pack-lh-config.c

Lines changed: 89 additions & 34 deletions
@@ -20,21 +20,27 @@ static struct xnn_pack_lh_config x8_pack_lh_config = {0};
  static struct xnn_pack_lh_config x16_pack_lh_config = {0};
  static struct xnn_pack_lh_config x32_pack_lh_config = {0};
  static struct xnn_pack_lh_config x8_igemm_pack_lh_config = {0};
+ static struct xnn_pack_lh_config x16_igemm_pack_lh_config = {0};

  XNN_INIT_ONCE_GUARD(qp8_pack_lh);
  XNN_INIT_ONCE_GUARD(x8_pack_lh);
  XNN_INIT_ONCE_GUARD(x16_pack_lh);
  XNN_INIT_ONCE_GUARD(x32_pack_lh);
  XNN_INIT_ONCE_GUARD(x8_igemm_pack_lh);
+ XNN_INIT_ONCE_GUARD(x16_igemm_pack_lh);

  static void init_qp8_pack_lh_config(void) {
  #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
- qp8_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2;
+ qp8_pack_lh_config.pack_lh_fn =
+ (xnn_pack_lh_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2;
  #else
- qp8_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__scalar_u1;
+ qp8_pack_lh_config.pack_lh_fn =
+ (xnn_pack_lh_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__scalar_u1;
  #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
- qp8_pack_lh_config.size_fn = (xnn_pack_lh_size_fn)xnn_x8_packq_f32qp8_packed_size;
- qp8_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn)xnn_x8_packq_f32qp8_packed_offset;
+ qp8_pack_lh_config.size_fn =
+ (xnn_pack_lh_size_fn)xnn_x8_packq_f32qp8_packed_size;
+ qp8_pack_lh_config.offset_fn =
+ (xnn_pack_lh_offset_fn)xnn_x8_packq_f32qp8_packed_offset;
  qp8_pack_lh_config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT;
  qp8_pack_lh_config.log2_packed_element_size = 0;
  }
@@ -51,13 +57,17 @@ const struct xnn_pack_lh_config* xnn_init_qp8_pack_lh_config() {

  static void init_x32_pack_lh_config(void) {
  #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
- #if XNN_ENABLE_ARM_SME2 || XNN_ENABLE_ARM_SME
- const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
+ #if XNN_ENABLE_ARM_SME2
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
  assert(hardware_config != NULL);
- if (hardware_config->arch_flags & xnn_arch_arm_sme) {
- x32_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn) xnn_x32_pack_lh_ukernel__neonsme;
- x32_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x32_pack_lh_size__neonsme;
- x32_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x32_pack_lh_offset__neonsme;
+ if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
+ x32_pack_lh_config.pack_lh_fn =
+ (xnn_pack_lh_ukernel_fn)xnn_x32_pack_lh_ukernel__neonsme;
+ x32_pack_lh_config.size_fn =
+ (xnn_pack_lh_size_fn)xnn_x32_pack_lh_size__neonsme;
+ x32_pack_lh_config.offset_fn =
+ (xnn_pack_lh_offset_fn)xnn_x32_pack_lh_offset__neonsme;
  }
  #endif // XNN_ENABLE_ARM_SME2 || XNN_ENABLE_ARM_SME
  #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
@@ -67,7 +77,8 @@ static void init_x32_pack_lh_config(void) {
  }

  const struct xnn_pack_lh_config* xnn_init_x32_pack_lh_config() {
- const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
  if (hardware_config == NULL) {
  return NULL;
  }
@@ -78,12 +89,16 @@ const struct xnn_pack_lh_config* xnn_init_x32_pack_lh_config() {
  static void init_x16_pack_lh_config(void) {
  #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
  #if XNN_ENABLE_ARM_SME2
- const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
  assert(hardware_config != NULL);
- if (hardware_config->arch_flags & xnn_arch_arm_sme2) {
- x16_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn) xnn_x16_pack_lh_ukernel__neonsme2;
- x16_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x16_pack_lh_size__neonsme2;
- x16_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x16_pack_lh_offset__neonsme2;
+ if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
+ x16_pack_lh_config.pack_lh_fn =
+ (xnn_pack_lh_ukernel_fn)xnn_x16_pack_lh_ukernel__neonsme2;
+ x16_pack_lh_config.size_fn =
+ (xnn_pack_lh_size_fn)xnn_x16_pack_lh_size__neonsme2;
+ x16_pack_lh_config.offset_fn =
+ (xnn_pack_lh_offset_fn)xnn_x16_pack_lh_offset__neonsme2;
  }
  #endif // XNN_ENABLE_ARM_SME2
  #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
@@ -93,7 +108,8 @@ static void init_x16_pack_lh_config(void) {
  }

  const struct xnn_pack_lh_config* xnn_init_x16_pack_lh_config() {
- const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
  if (hardware_config == NULL) {
  return NULL;
  }
@@ -104,12 +120,16 @@ const struct xnn_pack_lh_config* xnn_init_x16_pack_lh_config() {
  static void init_x8_pack_lh_config(void) {
  #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
  #if XNN_ENABLE_ARM_SME2
- const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
  assert(hardware_config != NULL);
- if (hardware_config->arch_flags & xnn_arch_arm_sme2) {
- x8_pack_lh_config.pack_lh_fn = (xnn_pack_lh_ukernel_fn) xnn_x8_pack_lh_ukernel__neonsme2;
- x8_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x8_pack_lh_size__neonsme2;
- x8_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x8_pack_lh_offset__neonsme2;
+ if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
+ x8_pack_lh_config.pack_lh_fn =
+ (xnn_pack_lh_ukernel_fn)xnn_x8_pack_lh_ukernel__neonsme2;
+ x8_pack_lh_config.size_fn =
+ (xnn_pack_lh_size_fn)xnn_x8_pack_lh_size__neonsme2;
+ x8_pack_lh_config.offset_fn =
+ (xnn_pack_lh_offset_fn)xnn_x8_pack_lh_offset__neonsme2;
  }
  #endif // XNN_ENABLE_ARM_SME2
  #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
@@ -119,7 +139,8 @@ static void init_x8_pack_lh_config(void) {
  }

  const struct xnn_pack_lh_config* xnn_init_x8_pack_lh_config() {
- const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
  if (hardware_config == NULL) {
  return NULL;
  }
@@ -128,17 +149,21 @@ const struct xnn_pack_lh_config* xnn_init_x8_pack_lh_config() {
  }

  static void init_x8_igemm_pack_lh_config(void) {
- #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
- #if XNN_ENABLE_ARM_SME2
- const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
- assert(hardware_config != NULL);
- if (hardware_config->arch_flags & xnn_arch_arm_sme2) {
- x8_igemm_pack_lh_config.pack_lh_for_igemm_fn = (xnn_pack_lh_igemm_ukernel_fn) xnn_x8_pack_lh_ukernel__igemm_neonsme2;
- x8_igemm_pack_lh_config.size_for_igemm_fn = (xnn_pack_lh_igemm_size_fn) xnn_x8_pack_lh_size__igemm_neonsme2;
- x8_igemm_pack_lh_config.offset_for_igemm_fn = (xnn_pack_lh_igemm_offset_fn) xnn_x8_pack_lh_offset__igemm_neonsme2;
- }
- #endif // XNN_ENABLE_ARM_SME2
- #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
+ #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
+ #if XNN_ENABLE_ARM_SME2
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
+ assert(hardware_config != NULL);
+ if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
+ x8_igemm_pack_lh_config.pack_lh_for_igemm_fn =
+ (xnn_pack_lh_igemm_ukernel_fn)xnn_x8_pack_lh_ukernel__igemm_neonsme2;
+ x8_igemm_pack_lh_config.size_for_igemm_fn =
+ (xnn_pack_lh_igemm_size_fn)xnn_x8_pack_lh_size__igemm_neonsme2;
+ x8_igemm_pack_lh_config.offset_for_igemm_fn =
+ (xnn_pack_lh_igemm_offset_fn)xnn_x8_pack_lh_offset__igemm_neonsme2;
+ }
+ #endif // XNN_ENABLE_ARM_SME2
+ #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
  x8_igemm_pack_lh_config.log2_input_element_size = 0;
  x8_igemm_pack_lh_config.log2_packed_element_size = 0;
  }
@@ -152,3 +177,33 @@ const struct xnn_pack_lh_config* xnn_init_x8_igemm_pack_lh_config() {
  XNN_INIT_ONCE(x8_igemm_pack_lh);
  return &x8_igemm_pack_lh_config;
  }
+
+ static void init_x16_igemm_pack_lh_config(void) {
+ #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
+ #if XNN_ENABLE_ARM_SME2
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
+ assert(hardware_config != NULL);
+ if ((hardware_config->arch_flags & xnn_arch_arm_sme2)) {
+ x16_igemm_pack_lh_config.pack_lh_for_igemm_fn =
+ (xnn_pack_lh_igemm_ukernel_fn)xnn_x16_pack_lh_ukernel__igemm_neonsme2;
+ x16_igemm_pack_lh_config.size_for_igemm_fn =
+ (xnn_pack_lh_igemm_size_fn)xnn_x16_pack_lh_size__igemm_neonsme2;
+ x16_igemm_pack_lh_config.offset_for_igemm_fn =
+ (xnn_pack_lh_igemm_offset_fn)xnn_x16_pack_lh_offset__igemm_neonsme2;
+ }
+ #endif // XNN_ENABLE_ARM_SME2
+ #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
+ x16_igemm_pack_lh_config.log2_input_element_size = 1;
+ x16_igemm_pack_lh_config.log2_packed_element_size = 1;
+ }
+
+ const struct xnn_pack_lh_config* xnn_init_x16_igemm_pack_lh_config() {
+ const struct xnn_hardware_config* hardware_config =
+ xnn_init_hardware_config();
+ if (hardware_config == NULL) {
+ return NULL;
+ }
+ XNN_INIT_ONCE(x16_igemm_pack_lh);
+ return &x16_igemm_pack_lh_config;
+ }
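
The new `x16_igemm_pack_lh` config follows the same shape as the existing ones: a zero-initialized static struct, a once-guard, an init function that fills in function pointers only when the hardware flag is set, and an accessor that runs the init exactly once. The block below is a self-contained sketch of that pattern, not XNNPACK code. Every `toy_*` name is hypothetical, and `pthread_once` is used on the assumption that it behaves comparably to `XNN_INIT_ONCE`. The log2 element size of 1 mirrors the 2-byte fp16 elements in the diff.

```c
#include <assert.h>
#include <pthread.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical stand-ins for the config struct and the hardware flag word.
typedef size_t (*toy_size_fn)(size_t m, size_t k);

struct toy_pack_config {
  toy_size_fn size_fn;  // stays NULL when the feature is unavailable
  uint8_t log2_input_element_size;
  uint8_t log2_packed_element_size;
};

enum { TOY_ARCH_SME2 = 1 };
static uint32_t toy_arch_flags = 0;  // would come from hardware detection

// Toy size function: m * k elements of 2 bytes each (log2 size = 1).
static size_t toy_x16_size(size_t m, size_t k) { return (m * k) << 1; }

static struct toy_pack_config toy_x16_config;  // zero-initialized
static pthread_once_t toy_x16_guard = PTHREAD_ONCE_INIT;

static void init_toy_x16_config(void) {
  // Function pointers are set only when the hardware flag is present.
  if (toy_arch_flags & TOY_ARCH_SME2) {
    toy_x16_config.size_fn = toy_x16_size;
  }
  // Element sizes are set unconditionally, as in the diff above.
  toy_x16_config.log2_input_element_size = 1;
  toy_x16_config.log2_packed_element_size = 1;
}

// Accessor: runs the init function exactly once, even across threads.
static const struct toy_pack_config* toy_get_x16_config(void) {
  pthread_once(&toy_x16_guard, init_toy_x16_config);
  return &toy_x16_config;
}
```

Callers of such an accessor would need to check the function pointer before using it, since the struct is returned even when the hardware path was not enabled.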
