@@ -450,11 +450,11 @@ static void init_pqs8_qc8w_gemm_config(void) {
450450 pqs8_qc8w_gemm_config .init .qs8_qc8w =
451451 xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params ;
452452 pqs8_qc8w_gemm_config .pack_weights_and_biases =
453- xnn_pack_kai_qs8_qc8w_weights_and_biases_sme2 ;
453+ xnn_pack_kai_qs8_qc8w_weights_and_biases_sme ;
454454 pqs8_qc8w_gemm_config .packed_stride_weights_and_biases =
455- xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme2 ;
455+ xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme ;
456456 pqs8_qc8w_gemm_config .pack_igemm_goki =
457- (xnn_pack_conv_goki_w_fn )xnn_pack_kai_qs8_conv_goki_w_sme2 ;
457+ (xnn_pack_conv_goki_w_fn )xnn_pack_kai_qs8_conv_goki_w_sme ;
458458 pqs8_qc8w_gemm_config .pack_igemm_kgo =
459459 (xnn_pack_conv_kgo_w_fn )xnn_pack_qs8_conv_kgo_w ;
460460 pqs8_qc8w_gemm_config .pack_deconv_goki =
@@ -464,6 +464,39 @@ static void init_pqs8_qc8w_gemm_config(void) {
464464 pqs8_qc8w_gemm_config .nr = nr ;
465465 pqs8_qc8w_gemm_config .log2_kr = 2 ;
466466#endif // XNN_ENABLE_ARM_SME2
467+ } else if (XNN_ENABLE_ARM_SME && (hardware_config -> arch_flags & xnn_arch_arm_sme )) {
468+ #if XNN_ENABLE_ARM_SME
469+ const size_t mr =
470+ xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme_get_mr ();
471+ const size_t nr =
472+ xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme_get_nr ();
473+ pqs8_qc8w_gemm_config .arch = xnn_arch_arm_sme ;
474+ pqs8_qc8w_gemm_config .minmax .gemm [XNN_MR_TO_INDEX (mr )] = XNN_INIT_HMP_GEMM_UKERNEL (xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme );
475+ pqs8_qc8w_gemm_config .minmax .igemm [XNN_MR_TO_INDEX (mr )] =
476+ xnn_init_hmp_packed_igemm_ukernel (
477+ (xnn_packed_lhs_igemm_ukernel_fn )
478+ xnn_pqs8_qc8w_igemm_minmax_fp32_ukernel_32x32c4__neonsme );
479+ pqs8_qc8w_gemm_config .init .qs8_qc8w =
480+ xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params ;
481+ pqs8_qc8w_gemm_config .pack_weights_and_biases =
482+ xnn_pack_kai_qs8_qc8w_weights_and_biases_sme ;
483+ pqs8_qc8w_gemm_config .packed_stride_weights_and_biases =
484+ xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme ;
485+ pqs8_qc8w_gemm_config .pack_igemm_goki =
486+ (xnn_pack_conv_goki_w_fn )xnn_pack_kai_qs8_conv_goki_w_sme ;
487+ pqs8_qc8w_gemm_config .pack_igemm_kgo =
488+ (xnn_pack_conv_kgo_w_fn )xnn_pack_qs8_conv_kgo_w ;
489+ pqs8_qc8w_gemm_config .pack_deconv_goki =
490+ (xnn_pack_deconv_goki_w_fn )xnn_pack_qs8_deconv_goki_w ;
491+ pqs8_qc8w_gemm_config .mr = mr ;
492+ pqs8_qc8w_gemm_config .mr_packed = mr ;
493+ pqs8_qc8w_gemm_config .nr = nr ;
494+ pqs8_qc8w_gemm_config .log2_kr = 2 ;
495+ #endif // XNN_ENABLE_ARM_SME
496+
497+ }
498+ else {
499+ /* No action */
467500 }
468501 assert (pqs8_qc8w_gemm_config .mr <= XNN_MAX_MR );
469502#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
0 commit comments