@@ -412,11 +412,11 @@ static void init_pqs8_qc8w_gemm_config(void) {
412412 pqs8_qc8w_gemm_config .init .qs8_qc8w =
413413 xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params ;
414414 pqs8_qc8w_gemm_config .pack_weights_and_biases =
415- xnn_pack_kai_qs8_qc8w_weights_and_biases_sme2 ;
415+ xnn_pack_kai_qs8_qc8w_weights_and_biases_sme ;
416416 pqs8_qc8w_gemm_config .packed_stride_weights_and_biases =
417- xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme2 ;
417+ xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme ;
418418 pqs8_qc8w_gemm_config .pack_igemm_goki =
419- (xnn_pack_conv_goki_w_fn )xnn_pack_kai_qs8_conv_goki_w_sme2 ;
419+ (xnn_pack_conv_goki_w_fn )xnn_pack_kai_qs8_conv_goki_w_sme ;
420420 pqs8_qc8w_gemm_config .pack_igemm_kgo =
421421 (xnn_pack_conv_kgo_w_fn )xnn_pack_qs8_conv_kgo_w ;
422422 pqs8_qc8w_gemm_config .pack_deconv_goki =
@@ -426,6 +426,39 @@ static void init_pqs8_qc8w_gemm_config(void) {
426426 pqs8_qc8w_gemm_config .nr = nr ;
427427 pqs8_qc8w_gemm_config .log2_kr = 2 ;
428428#endif // XNN_ENABLE_ARM_SME2
429+ } else if (XNN_ENABLE_ARM_SME && (hardware_config -> arch_flags & xnn_arch_arm_sme )) {
430+ #if XNN_ENABLE_ARM_SME
431+ const size_t mr =
432+ xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme_get_mr ();
433+ const size_t nr =
434+ xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme_get_nr ();
435+ pqs8_qc8w_gemm_config .arch = xnn_arch_arm_sme ;
436+ pqs8_qc8w_gemm_config .minmax .gemm [XNN_MR_TO_INDEX (mr )] = XNN_INIT_HMP_GEMM_UKERNEL (xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32c4__neonsme );
437+ pqs8_qc8w_gemm_config .minmax .igemm [XNN_MR_TO_INDEX (mr )] =
438+ xnn_init_hmp_packed_igemm_ukernel (
439+ (xnn_packed_lhs_igemm_ukernel_fn )
440+ xnn_pqs8_qc8w_igemm_minmax_fp32_ukernel_32x32c4__neonsme );
441+ pqs8_qc8w_gemm_config .init .qs8_qc8w =
442+ xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params ;
443+ pqs8_qc8w_gemm_config .pack_weights_and_biases =
444+ xnn_pack_kai_qs8_qc8w_weights_and_biases_sme ;
445+ pqs8_qc8w_gemm_config .packed_stride_weights_and_biases =
446+ xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme ;
447+ pqs8_qc8w_gemm_config .pack_igemm_goki =
448+ (xnn_pack_conv_goki_w_fn )xnn_pack_kai_qs8_conv_goki_w_sme ;
449+ pqs8_qc8w_gemm_config .pack_igemm_kgo =
450+ (xnn_pack_conv_kgo_w_fn )xnn_pack_qs8_conv_kgo_w ;
451+ pqs8_qc8w_gemm_config .pack_deconv_goki =
452+ (xnn_pack_deconv_goki_w_fn )xnn_pack_qs8_deconv_goki_w ;
453+ pqs8_qc8w_gemm_config .mr = mr ;
454+ pqs8_qc8w_gemm_config .mr_packed = mr ;
455+ pqs8_qc8w_gemm_config .nr = nr ;
456+ pqs8_qc8w_gemm_config .log2_kr = 2 ;
457+ #endif // XNN_ENABLE_ARM_SME
458+
459+ }
460+ else {
461+ /* No action */
429462 }
430463 assert (pqs8_qc8w_gemm_config .mr <= XNN_MAX_MR );
431464#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
0 commit comments