Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions bench/qs8-qc4w-gemm-fp32.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,141 @@
#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY


#if XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
// Benchmark entry point for the 1x16c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_1x16c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_1x16c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/1,
      /*nr=*/16,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_1x16c4__avx512amx)

// Benchmark entry point for the 7x16c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_7x16c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_7x16c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/7,
      /*nr=*/16,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_7x16c4__avx512amx)

// Benchmark entry point for the 16x16c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/16,
      /*nr=*/16,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx)

// Benchmark entry point for the 16x16c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32), `_prfm` variant (generated with PREFETCH=1). It only
// forwards to GEMMBenchmark; `net` is not referenced in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/16,
      /*nr=*/16,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm)

// Benchmark entry point for the 1x32c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_1x32c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_1x32c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/1,
      /*nr=*/32,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_1x32c4__avx512amx)

// Benchmark entry point for the 7x32c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_7x32c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_7x32c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/7,
      /*nr=*/32,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_7x32c4__avx512amx)

// Benchmark entry point for the 16x32c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/16,
      /*nr=*/32,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx)

// Benchmark entry point for the 16x32c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32), `_prfm` variant (generated with PREFETCH=1). It only
// forwards to GEMMBenchmark; `net` is not referenced in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/16,
      /*nr=*/32,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm)

// Benchmark entry point for the 1x64c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/1,
      /*nr=*/64,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx)

// Benchmark entry point for the 7x64c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_7x64c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_7x64c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/7,
      /*nr=*/64,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_7x64c4__avx512amx)

// Benchmark entry point for the 16x64c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32). It only forwards to GEMMBenchmark; `net` is not referenced
// in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/16,
      /*nr=*/64,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx)

// Benchmark entry point for the 16x64c4 AVX512-AMX qs8-qc4w GEMM microkernel
// (minmax, fp32), `_prfm` variant (generated with PREFETCH=1). It only
// forwards to GEMMBenchmark; `net` is not referenced in this body.
static void qs8_qc4w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm(
    benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm,
      // The qs8-qc4w AMX kernels are generated with the qs8_qc8w conv-minmax
      // params type, hence the qc8w params initializer.
      xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
      xnn_pack_qs8_qc4w_gemm_goi_w,
      /*mr=*/16,
      /*nr=*/64,
      /*kr=*/4,
      /*sr=*/1,
      /*arch_flags=*/xnn_arch_x86_avx512amx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm)
#endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)


#if XNN_ENABLE_AVX512VNNIGFNI && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
static void qs8_qc4w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnnigfni(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
Expand Down
12 changes: 12 additions & 0 deletions cmake/gen/avx512amx_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ SET(NON_PROD_AVX512AMX_MICROKERNEL_SRCS
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx-prfm.c
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx.c
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x16c4-minmax-fp32-avx512amx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x32c4-minmax-fp32-avx512amx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x64c4-minmax-fp32-avx512amx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x16c4-minmax-fp32-avx512amx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x32c4-minmax-fp32-avx512amx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x64c4-minmax-fp32-avx512amx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x16c4-minmax-fp32-avx512amx-prfm.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x16c4-minmax-fp32-avx512amx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x32c4-minmax-fp32-avx512amx-prfm.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x32c4-minmax-fp32-avx512amx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x64c4-minmax-fp32-avx512amx-prfm.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x64c4-minmax-fp32-avx512amx.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512amx.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-avx512amx.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512amx.c
Expand Down
12 changes: 12 additions & 0 deletions gen/avx512amx_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ NON_PROD_AVX512AMX_MICROKERNEL_SRCS = [
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx-prfm.c",
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx.c",
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x16c4-minmax-fp32-avx512amx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x32c4-minmax-fp32-avx512amx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x64c4-minmax-fp32-avx512amx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x16c4-minmax-fp32-avx512amx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x32c4-minmax-fp32-avx512amx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x64c4-minmax-fp32-avx512amx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x16c4-minmax-fp32-avx512amx-prfm.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x16c4-minmax-fp32-avx512amx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x32c4-minmax-fp32-avx512amx-prfm.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x32c4-minmax-fp32-avx512amx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x64c4-minmax-fp32-avx512amx-prfm.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x64c4-minmax-fp32-avx512amx.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512amx.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-avx512amx.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512amx.c",
Expand Down
15 changes: 15 additions & 0 deletions scripts/generate-qs8-gemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2046,6 +2046,21 @@ tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=16 -D NR=64 -D DATATYPE=QC4_F32 -D REQUANTIZATION= -o src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=1 -D MR=16 -D NR=64 -D DATATYPE=QC4_F32 -D REQUANTIZATION= -o src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx-prfm.c &

# DATATYPE=QS8_QC4 kernels with FP32 requantization, written to src/qs8-qc4w-gemm/gen/.
# Each NR group below emits MR=1, MR=7 and MR=16, plus a PREFETCH=1 ("-prfm") copy of MR=16.
# NR=16
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=1 -D NR=16 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x16c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=7 -D NR=16 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x16c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=16 -D NR=16 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x16c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=1 -D MR=16 -D NR=16 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x16c4-minmax-fp32-avx512amx-prfm.c &

# NR=32
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=1 -D NR=32 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x32c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=7 -D NR=32 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x32c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=16 -D NR=32 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x32c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=1 -D MR=16 -D NR=32 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x32c4-minmax-fp32-avx512amx-prfm.c &

# NR=64
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=1 -D NR=64 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x64c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=7 -D NR=64 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x64c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D MR=16 -D NR=64 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x64c4-minmax-fp32-avx512amx.c &
tools/xngen src/qs8-gemm/c4-avx512amx.c.in -D VARIANT= -D GFNI=1 -D PREFETCH=1 -D MR=16 -D NR=64 -D DATATYPE=QS8_QC4 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-16x64c4-minmax-fp32-avx512amx-prfm.c &

################################## Hexagon HVX #################################
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=1 -D NR=32 -D DATATYPE=QC8 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-hvx.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=2 -D NR=32 -D DATATYPE=QC8 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x32c4-minmax-fp32-hvx.c &
Expand Down
38 changes: 18 additions & 20 deletions src/qs8-gemm/c4-avx512amx.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ $assert NR % 16 == 0
$assert 16 <= NR <= 64
$assert 1 <= MR <= 16
$assert REQUANTIZATION == "FP32" or not REQUANTIZATION
$assert DATATYPE in ["QD8_F32", "QD8_F16", "QC8", "QC4_F32"]
$assert DATATYPE in ["QD8_F32", "QD8_F16", "QC8", "QS8_QC4", "QC4_F32"]
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
Expand All @@ -28,13 +28,11 @@ $if PREFETCH:
#include "src/xnnpack/prefetch.h"


$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F32": "qd8_f32_qc4w"}[DATATYPE]
$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8_QC4": "qs8_qc4w", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F32": "qd8_f32_qc4w"}[DATATYPE]
$REQUANTIZATION_SPEC = "_" + REQUANTIZATION.lower() if REQUANTIZATION else ""
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" if REQUANTIZATION else "scalar"
$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE]
$OUT_T = {"QC8": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F32": "float"}[DATATYPE]
$_MM_MAX_EPX8 = "_mm_max_epi8"
$_MM512_CVTXEPI32_EPI8 = "_mm512_cvtsepi32_epi8"
$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QS8_QC4": "union xnn_qs8_qc8w_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE]
$OUT_T = {"QC8": "int8_t", "QS8_QC4": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F32": "float"}[DATATYPE]
void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c4__avx512amx${"_prfm" if PREFETCH else ""}(
size_t mr,
size_t nc,
Expand Down Expand Up @@ -76,7 +74,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c

XNN_ALIGN(64) struct __tile_config tile_data = {0};
XNN_ALIGN(64) int32_t res[${NR // 16}][${MR} * 16];
$if DATATYPE in ["QC4_F32", "QC4_F16"]:
$if DATATYPE in ["QC4_F32", "QC4_F16", "QS8_QC4"]:
XNN_ALIGN(64) int8_t weight_buffer[16 * 64];

kc = round_up_po2(kc, 4 * sizeof(int8_t));
Expand Down Expand Up @@ -126,18 +124,18 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
const __m512 voutput_max = _mm512_set1_ps(params->${PARAMS_STRUCT}.max);
// XNN_FORCE_REALIZATION(voutput_min);
// XNN_FORCE_REALIZATION(voutput_max);
$if DATATYPE in ["QC4_F32", "QC4_F16"]:
const __m512i vmask = _mm512_set1_epi8(0xF0);
const __m512i vshl4 = _mm512_set1_epi64(0x01020408);
XNN_FORCE_REALIZATION(vmask);
XNN_FORCE_REALIZATION(vshl4);
$else:
const __m512 voutput_max_less_zero_point = _mm512_set1_ps((int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point);
const __m512i voutput_zero_point = _mm512_set1_epi32(params->${PARAMS_STRUCT}.output_zero_point);
const __m128i voutput_min = _mm_set1_epi8(params->${PARAMS_STRUCT}.output_min);
// XNN_FORCE_REALIZATION(voutput_max_less_zero_point);
// XNN_FORCE_REALIZATION(voutput_zero_point);
// XNN_FORCE_REALIZATION(voutput_min);
$if DATATYPE in ["QC4_F32", "QC4_F16", "QS8_QC4"]:
const __m512i vmask = _mm512_set1_epi8(0xF0);
const __m512i vshl4 = _mm512_set1_epi64(0x01020408);
XNN_FORCE_REALIZATION(vmask);
XNN_FORCE_REALIZATION(vshl4);

do {
$for N in range(0, NR, 16):
Expand All @@ -153,7 +151,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
_tile_loadd(4, a, a_stride);
a += 64;
$for N in range(0, NR, 16):
$if DATATYPE in ["QC4_F32", "QC4_F16"]:
$if DATATYPE in ["QC4_F32", "QC4_F16", "QS8_QC4"]:
$for K in range(8):
const __m512i vb${K}x${N//16} = _mm512_load_epi32((const int8_t*) w + ${N * 4 + NR * K * 4});
$for K in range(8):
Expand All @@ -170,7 +168,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
$for P in range(4096, 5120 + NR // 16 * 1024, 64):
xnn_prefetch_to_l1((const int8_t*) w + ${P});

$if DATATYPE in ["QC4_F32", "QC4_F16"]:
$if DATATYPE in ["QC4_F32", "QC4_F16", "QS8_QC4"]:
w = (const int8_t*) w + ${NR // 16 * 512};
$else:
w = (const int8_t*) w + ${NR // 16 * 1024};
Expand All @@ -181,7 +179,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
_tile_loadd(6, a, a_stride);
a += kremainder;
$for N in range(0, NR, 16):
$if DATATYPE in ["QC4_F32", "QC4_F16"]:
$if DATATYPE in ["QC4_F32", "QC4_F16", "QS8_QC4"]:
for (size_t k = 0; k < ((kremainder + 7) >> 3); ++k) {
const __m512i vb${N//16} = _mm512_load_epi32((const int8_t*) w + ${N * 4} + ${NR * 4} * k);
const __m512i vl${N//16} = _mm512_gf2p8affine_epi64_epi8(vb${N//16}, vshl4, 0);
Expand All @@ -194,7 +192,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
_tile_loadd(7, (const int8_t*) w + ${N * 4}, ${NR * 4});
_tile_dpbssd(${N // 16}, 6, 7);

$if DATATYPE in ["QC4_F32", "QC4_F16"]:
$if DATATYPE in ["QC4_F32", "QC4_F16", "QS8_QC4"]:
w = (const int8_t*) w + ((kremainder + 7) >> 3) * ${NR * 4};
$else:
w = (const int8_t*) w + kremainder * ${NR};
Expand Down Expand Up @@ -229,7 +227,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
$for N in range(0, NR, 16):
__m512i vacc${M}x${N//16} = _mm512_add_epi32(vksum${N//16}, _mm512_load_epi32(&res[${N // 16}][0] + ${M * 16}));

$if DATATYPE in ["QC4_F32", "QC4_F16"]:
$if DATATYPE in ["QC4_F32", "QC4_F16", "QS8_QC4"]:
$for M in range(MR):
$for N in range(0, NR, 16):
vacc${M}x${N//16} = _mm512_srai_epi32(vacc${M}x${N//16}, 4);
Expand Down Expand Up @@ -300,7 +298,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
_mm512_mask_storeu_ps(c${M} + ${N}, vmask${N // 16}, vscaled${M}x${N//16});
nc = 0;
}
$elif DATATYPE == "QC8":
$elif DATATYPE in ["QC8", "QS8_QC4"]:
$for N in range(0, NR, 16):
const __m512 vscale${N//16} = _mm512_load_ps((const float*) w + ${N});
w = (const int32_t*) w + ${NR};
Expand All @@ -323,11 +321,11 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c

$for M in range(MR):
$for N in range(0, NR, 16):
__m128i vout${M}x${N//16} = ${_MM512_CVTXEPI32_EPI8}(vacc${M}x${N//16});
__m128i vout${M}x${N//16} = _mm512_cvtsepi32_epi8(vacc${M}x${N//16});

$for M in range(MR):
$for N in range(0, NR, 16):
vout${M}x${N//16} = ${_MM_MAX_EPX8}(vout${M}x${N//16}, voutput_min);
vout${M}x${N//16} = _mm_max_epi8(vout${M}x${N//16}, voutput_min);

if XNN_LIKELY(nc >= ${NR}) {
$for M in reversed(range(MR)):
Expand Down
Loading
Loading