diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..689da6f7e --- /dev/null +++ b/.clang-format @@ -0,0 +1,5 @@ +# Start with a built-in style and modify it +BasedOnStyle: Google + +# Overrides +ColumnLimit: 120 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..2b11178bf --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,14 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..b6d0a3b91 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v17.0.6 # Or pin to your preferred clang-format version + hooks: + - id: clang-format + files: \.(c|h|cpp|hpp|proto|cu|cuh)$ + exclude: ^(apex/contrib/csrc/multihead_attn/cutlass|apex/contrib/csrc/cudnn-frontend)/ + +# TODO: Enable Ruff +# - repo: https://github.com/astral-sh/ruff-pre-commit +# rev: v0.14.0 +# hooks: +# - id: ruff-check +# args: ["--fix"] +# - id: ruff-format +# types_or: [python] +# exclude: "examples" diff --git a/apex/contrib/csrc/bottleneck/bottleneck.cpp b/apex/contrib/csrc/bottleneck/bottleneck.cpp index e80607c1c..c36f17448 100644 --- a/apex/contrib/csrc/bottleneck/bottleneck.cpp +++ b/apex/contrib/csrc/bottleneck/bottleneck.cpp @@ -1,51 +1,60 @@ #include #include // for getcudnnhandle +#include #include #include -#include -#include #include +#include #ifdef DEBUG -#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false ) +#define DEBUG_MSG(str) \ + do { \ + std::cout << str << std::endl; \ + } while (false) #else -#define DEBUG_MSG(str) do { } while ( false ) +#define DEBUG_MSG(str) \ + do { \ + } while (false) #endif #ifdef DEBUG_CUDNN -#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false ) +#define DEBUG_CUDNN_MSG(buf, str) \ + do { \ + buf << str << std::endl; \ + } while (false) #else -#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false ) +#define DEBUG_CUDNN_MSG(buf, str) \ + do { \ + } while (false) #endif -#define checkCudnnErr(...) \ - do { \ - int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ - if (err) { \ - return; \ - } \ - } while (0) - +#define checkCudnnErr(...) 
\ + do { \ + int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ + if (err) { \ + return; \ + } \ + } while (0) int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { - if (code) { - printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); - return 1; - } - return 0; + if (code) { + printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); + return 1; + } + return 0; } -void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true); -#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function +void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort = true); +#define checkCUDAError(val) \ + { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function -void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) -{ - if (code != cudaSuccess) - { - const char * errorMessage = cudaGetErrorString(code); - fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage); - if (abort){ +void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort) { + if (code != cudaSuccess) { + const char* errorMessage = cudaGetErrorString(code); + fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, + errorMessage); + if (abort) { cudaDeviceReset(); exit(code); } @@ -53,352 +62,309 @@ void checkError(cudaError_t code, char const * func, const char *file, const int } void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) { - // For INT8x4 and INT8x32 we still compute standard strides here to input - // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. - if (filterFormat == CUDNN_TENSOR_NCHW) { - strideA[nbDims - 1] = 1; - for (int64_t d = nbDims - 2; d >= 0; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - } else { - // Here we assume that the format is CUDNN_TENSOR_NHWC - strideA[1] = 1; - strideA[nbDims - 1] = strideA[1] * dimA[1]; - for (int64_t d = nbDims - 2; d >= 2; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - strideA[0] = strideA[2] * dimA[2]; + // For INT8x4 and INT8x32 we still compute standard strides here to input + // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. 
+ if (filterFormat == CUDNN_TENSOR_NCHW) { + strideA[nbDims - 1] = 1; + for (int64_t d = nbDims - 2; d >= 0; d--) { + strideA[d] = strideA[d + 1] * dimA[d + 1]; + } + } else { + // Here we assume that the format is CUDNN_TENSOR_NHWC + strideA[1] = 1; + strideA[nbDims - 1] = strideA[1] * dimA[1]; + for (int64_t d = nbDims - 2; d >= 2; d--) { + strideA[d] = strideA[d + 1] * dimA[d + 1]; } + strideA[0] = strideA[2] * dimA[2]; + } } +int getFwdConvDilatedFilterDim(int filterDim, int dilation) { return ((filterDim - 1) * dilation) + 1; } -int getFwdConvDilatedFilterDim(int filterDim, int dilation) { - return ((filterDim - 1) * dilation) + 1; -} - -int getFwdConvPaddedImageDim(int tensorDim, int pad) { - return tensorDim + (2 * pad); -} +int getFwdConvPaddedImageDim(int tensorDim, int pad) { return tensorDim + (2 * pad); } -int getFwdConvOutputDim( - int tensorDim, - int pad, - int filterDim, - int stride, - int dilation) -{ - int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1; - return (p); +int getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int stride, int dilation) { + int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1; + return (p); } enum { - X_TENSOR, - Y_TENSOR, - W_TENSOR, - Z_TENSOR, - B_TENSOR, - AFTERADD_TENSOR, - AFTERBIAS_TENSOR, - AFTERCONV_TENSOR, - OPTIONAL, - AFTEROPT_TENSOR, + X_TENSOR, + Y_TENSOR, + W_TENSOR, + Z_TENSOR, + B_TENSOR, + AFTERADD_TENSOR, + AFTERBIAS_TENSOR, + AFTERCONV_TENSOR, + OPTIONAL, + AFTEROPT_TENSOR, }; using common_conv_descriptors = std::tuple; - -common_conv_descriptors -create_common_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - cudnnConvolutionMode_t mode) { - const int convDim = 2; - - int64_t strideA_padded[4]; - int64_t outstrideA_padded[4]; - int64_t filterstrideA_padded[4]; - - generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC); - - return common_conv_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, strideA_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, outstrideA_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, filterstrideA_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(mode) - .setNDims(convDim) - .setStrides(convDim, convstrideA) - .setPrePadding(convDim, padA) - .setPostPadding(convDim, padA) - .setDilation(convDim, dilationA) - .build()); +common_conv_descriptors create_common_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, + int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded, + cudnnDataType_t dataType, cudnnConvolutionMode_t mode) { + const int convDim = 2; + + int64_t strideA_padded[4]; + int64_t outstrideA_padded[4]; + int64_t filterstrideA_padded[4]; + + generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC); + 
generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC); + + return common_conv_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, strideA_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, outstrideA_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, filterstrideA_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(mode) + .setNDims(convDim) + .setStrides(convDim, convstrideA) + .setPrePadding(convDim, padA) + .setPostPadding(convDim, padA) + .setDilation(convDim, dilationA) + .build()); } -using common_convbias_descriptors = std::tuple; - -common_convbias_descriptors -create_conv_bias_add_act_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = y_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - - return common_convbias_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('z') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setVirtual() - .setId('A') // after add - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setVirtual() - .setId('B') // after bias - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('C') // after conv - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('i') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('D') // after optional add - .setAlignment(16) - .setVirtual() - 
.setDataType(CUDNN_DATA_FLOAT) - .build()); +using common_convbias_descriptors = + std::tuple; + +common_convbias_descriptors create_conv_bias_add_act_descriptors(int64_t* x_dim_padded, int64_t* padA, + int64_t* convstrideA, int64_t* dilationA, + int64_t* w_dim_padded, int64_t* y_dim_padded, + cudnnDataType_t dataType) { + const int convDim = 2; + + int64_t b_dim_padded[4]; + b_dim_padded[0] = 1; + b_dim_padded[1] = y_dim_padded[1]; + b_dim_padded[2] = 1; + b_dim_padded[3] = 1; + + int64_t x_stride_padded[4]; + int64_t y_stride_padded[4]; + int64_t w_stride_padded[4]; + int64_t b_stride_padded[4]; + + generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); + + return common_convbias_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, w_stride_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('z') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('b') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setVirtual() + .setId('A') // after add + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setVirtual() + .setId('B') // after bias + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('C') // after conv + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('i') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('D') // after optional add + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_FLOAT) + .build()); } // tensor descriptors used for dgrad enum { - X_OR_DX_TENSOR, - DY_TENSOR, - W_OR_DW_TENSOR, - SCALE_TENSOR, - RELU_TENSOR, - AFTER_DCONV_TENSOR, - AFTER_DRELU_TENSOR, + X_OR_DX_TENSOR, + DY_TENSOR, + W_OR_DW_TENSOR, + SCALE_TENSOR, + RELU_TENSOR, + AFTER_DCONV_TENSOR, + AFTER_DRELU_TENSOR, }; -using dconv_descriptors = std::tuple; - -dconv_descriptors -create_dconv_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = x_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - 
int64_t b_stride_padded[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - - return dconv_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('s') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('A') // after dconv - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('B') // after drelu - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build()); +using dconv_descriptors = + std::tuple; + +dconv_descriptors create_dconv_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, + int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded, + cudnnDataType_t dataType) { + const int convDim = 2; + + int64_t b_dim_padded[4]; + b_dim_padded[0] = 1; + b_dim_padded[1] = x_dim_padded[1]; + b_dim_padded[2] = 1; + b_dim_padded[3] = 1; + + int64_t x_stride_padded[4]; + int64_t y_stride_padded[4]; + int64_t w_stride_padded[4]; + int64_t b_stride_padded[4]; + + generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); + + return dconv_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, w_stride_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('s') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('r') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('A') // after dconv + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, 
x_stride_padded) + .setVirtual() + .setId('B') // after drelu + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build()); } // create a cache for plan std::unordered_map plan_cache; // TODO: better name -std::string getConvFusionString(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - cudnnDataType_t dataType, - std::string fusion_string) { - - for(int i=0;i<4;i++) { +std::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA, + int64_t* w_dim_padded, cudnnDataType_t dataType, std::string fusion_string) { + for (int i = 0; i < 4; i++) { fusion_string += 'X'; fusion_string += std::to_string(x_dim_padded[i]); } - for(int i=0;i<4;i++) { + for (int i = 0; i < 4; i++) { fusion_string += 'W'; fusion_string += std::to_string(w_dim_padded[i]); } - for(int i=0;i<2;i++) { + for (int i = 0; i < 2; i++) { fusion_string += 'P'; fusion_string += std::to_string(padA[i]); } - for(int i=0;i<2;i++) { + for (int i = 0; i < 2; i++) { fusion_string += 'S'; fusion_string += std::to_string(convstrideA[i]); } - for(int i=0;i<2;i++) { + for (int i = 0; i < 2; i++) { fusion_string += 'D'; fusion_string += std::to_string(dilationA[i]); } @@ -407,742 +373,673 @@ std::string getConvFusionString(int64_t* x_dim_padded, return fusion_string; } -cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, - std::stringstream& log_buf, - cudnn_frontend::OperationGraph& opGraph, - std::string cache_string, - bool use_heuristic = true){ +cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::stringstream& log_buf, + cudnn_frontend::OperationGraph& opGraph, std::string cache_string, + bool use_heuristic = true) { auto it = plan_cache.find(cache_string); if (it != plan_cache.end()) { DEBUG_CUDNN_MSG(log_buf, "Found plan in cache"); return it->second; } else { - if (use_heuristic){ + if (use_heuristic) { // TODO: confirm which mode to use auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() - .setOperationGraph(opGraph) - .setHeurMode(CUDNN_HEUR_MODE_INSTANT) - .build(); + .setOperationGraph(opGraph) + .setHeurMode(CUDNN_HEUR_MODE_INSTANT) + .build(); // try 3 times for now as WAR for no heuristic training int max_tries = 3, count = 0; auto& engine_configs = heuristics.getEngineConfig(max_tries); - while(true) { + while (true) { try { plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(engine_configs[count], opGraph.getTag()) - .build())); + .setHandle(handle_) + .setEngineConfig(engine_configs[count], opGraph.getTag()) + .build())); break; } catch (cudnn_frontend::cudnnException e) { if (++count == max_tries) throw e; } } - }else{ - DEBUG_CUDNN_MSG(log_buf, "No plan in cache"); - // How many engines support this operation graph ? 
- auto total_engines = opGraph.getEngineCount(); - DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines."); - // We have to randomly pick one engine from [0, total_engines) - // Selecting "0" by default - auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build(); - DEBUG_CUDNN_MSG(log_buf, engine.describe()); - auto& knobs = engine.getSupportedKnobs(); - for (auto it = std::begin(knobs); it != std::end(knobs); ++it) { - DEBUG_CUDNN_MSG(log_buf, it->describe()); - } - if (knobs.begin() != knobs.end()) { - DEBUG_CUDNN_MSG(log_buf, "Updated knob choice"); - knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1); - DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe()); - } + } else { + DEBUG_CUDNN_MSG(log_buf, "No plan in cache"); + // How many engines support this operation graph ? + auto total_engines = opGraph.getEngineCount(); + DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines."); + // We have to randomly pick one engine from [0, total_engines) + // Selecting "0" by default + auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build(); + DEBUG_CUDNN_MSG(log_buf, engine.describe()); + auto& knobs = engine.getSupportedKnobs(); + for (auto it = std::begin(knobs); it != std::end(knobs); ++it) { + DEBUG_CUDNN_MSG(log_buf, it->describe()); + } + if (knobs.begin() != knobs.end()) { + DEBUG_CUDNN_MSG(log_buf, "Updated knob choice"); + knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1); + DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe()); + } - // Createmplacee the requisite engine config - auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build(); - DEBUG_CUDNN_MSG(log_buf, engine_config.describe()); - plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); + // Createmplacee the requisite engine config + auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build(); + DEBUG_CUDNN_MSG(log_buf, engine_config.describe()); + plan_cache.emplace( + cache_string, + std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); } return plan_cache.find(cache_string)->second; } } -void -run_conv_scale_bias_add_activation(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrB, - at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the add 
operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // optional add - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Add Node with scaling parameters. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create a optional add Node. 
- auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) +void run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, + int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, + at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, + at::Half* devPtrB, at::Half* devPtrI) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation, + w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the add operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the bias operation + auto biasDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // optional add + auto addDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Add Node with scaling parameters. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. 
+ auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create a optional add Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // Create an Activation Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) + .setyDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, + &act_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(devPtrI ? ops.size() : 4, ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; + int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(devPtrI ? 6 : 5, data_ptrs) + .setUids(devPtrI ? 6 : 5, uids) .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - - // Create an Activation Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) - .setyDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, &act_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(devPtrI ? 
ops.size() : 4, ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(devPtrI ? 6 : 5, data_ptrs) - .setUids(devPtrI ? 6 : 5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, + int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, + at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation, + w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the add operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the bias operation + auto addDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the convolution 
problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Add Node with scaling parameters. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) // TODO: change enum to aftermul + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &scale_op, &add_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB}; + int64_t uids[] = {'x', 'y', 'w', 'z', 'b'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(5, data_ptrs) + .setUids(5, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } -void -run_conv_scale_bias(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrB) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // 
Creates the necessary tensor descriptors - common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the add operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the bias operation - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Add Node with scaling parameters. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) // TODO: change enum to aftermul - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - // Create an Operation Graph. 
In this case it is convolution add bias activation - std::array ops = {&conv_op, &scale_op, &add_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, + int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, + at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_descriptors tensors = + create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the scale backward operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + float alpha = 1.0f; + float beta = 
0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // TODO: do we need getOutputTensor(), and what it returns in backward case? + // Create an relu backward Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(std::get(tensors)) + .setxDesc(std::get(tensors)) + .setdxDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create a Scale Node. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &act_op, &scale_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR}; + int64_t uids[] = {'x', 'y', 'w', 's', 'r'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(5, data_ptrs) + .setUids(5, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } - -void -run_dconv_drelu_dscale(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrR) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_descriptors tensors = create_dconv_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - 
DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the scale backward operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdxDesc(std::get(tensors)) +void run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, + int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, + cudnnBackendDescriptorType_t mode) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_descriptors tensors = + create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + // mode should be one of following + // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR + // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR + auto conv_op_builder = cudnn_frontend::OperationBuilder(mode); + if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { + conv_op_builder.setdxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setdyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // TODO: do we need getOutputTensor(), and what it returns in backward case? - // Create an relu backward Node. 
- auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(std::get(tensors)) - .setxDesc(std::get(tensors)) - .setdxDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create a Scale Node. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &act_op, &scale_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR}; - int64_t uids[] = {'x', 'y', 'w', 's', 'r'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + .setBeta(beta); + } else { + conv_op_builder.setxDesc(std::get(tensors)) + .setdwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta); } -} + auto conv_op = conv_op_builder.build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); -void -run_dconv(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - cudnnBackendDescriptorType_t mode) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_descriptors tensors = create_dconv_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, 
std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - // mode should be one of following - // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR - // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR - auto conv_op_builder = cudnn_frontend::OperationBuilder(mode); - if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { - conv_op_builder.setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta); - } - else { - conv_op_builder.setxDesc(std::get(tensors)) - .setdwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta); - } - auto conv_op = conv_op_builder.build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op}; - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op}; + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW}; - int64_t uids[] = {'x', 'y', 'w'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(3, data_ptrs) - .setUids(3, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch 
(cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW}; + int64_t uids[] = {'x', 'y', 'w'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(3, data_ptrs) + .setUids(3, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } -void -run_dconv_add(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrR) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_descriptors tensors = create_dconv_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the add backward operation - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // TODO: do we need getOutputTensor(), and what it returns in backward case? - // Create add Node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - // Create an Operation Graph. 
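One idiom that recurs in every one of these plan-execution blocks is how the workspace is sized: the plan reports a byte count, and a CUDA float tensor of ceil(bytes / 4) elements is allocated to back it, since a float occupies 4 bytes. A minimal stand-alone version of that pattern, with an assumed helper name, looks like this:

#include <ATen/ATen.h>
#include <cstdint>

// Round the requested byte count up to whole floats and hand back a raw pointer;
// 'holder' keeps the allocation alive for the caller. Returns nullptr when the
// plan needs no workspace.
void* alloc_plan_workspace(int64_t workspace_bytes, at::Tensor& holder) {
  holder = at::empty({(workspace_bytes + 3) / 4},
                     at::TensorOptions(at::kCUDA).dtype(at::kFloat));
  return workspace_bytes > 0 ? holder.data_ptr() : nullptr;
}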
In this case it is convolution add bias activation - std::array ops = {&conv_op, &add_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR}; - int64_t uids[] = {'x', 'y', 'w', 'r'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_dconv_add(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, + int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, + at::Half* devPtrY, at::Half* devPtrR) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_descriptors tensors = + create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the add backward operation + auto addDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + 
.setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // TODO: do we need getOutputTensor(), and what it returns in backward case? + // Create add Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &add_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR}; + int64_t uids[] = {'x', 'y', 'w', 'r'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(4, data_ptrs) + .setUids(4, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } - // inputs contains x,w,z,b,(i) std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // setup dimensions - int64_t dimA[] = {0, 0, 0, 0}; - int64_t filterdimA1[] = {0, 0, 0, 0}; - int64_t filterdimA2[] = {0, 0, 0, 0}; - int64_t filterdimA3[] = {0, 0, 0, 0}; - int64_t filterdimA4[] = {0, 0, 0, 0}; + int64_t dimA[] = {0, 0, 0, 0}; + int64_t filterdimA1[] = {0, 0, 0, 0}; + int64_t filterdimA2[] = {0, 0, 0, 0}; + int64_t filterdimA3[] = {0, 0, 0, 0}; + int64_t filterdimA4[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w - int axis[] {0,1,2,3}; + int axis[]{0, 1, 2, 3}; if (explicit_nhwc) { axis[0] = 0; axis[1] = 3; axis[2] = 1; axis[3] = 2; } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { dimA[dim] = inputs[0].size(axis[dim]); filterdimA1[dim] = inputs[1].size(axis[dim]); filterdimA2[dim] = inputs[2].size(axis[dim]); filterdimA3[dim] = inputs[3].size(axis[dim]); } if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { filterdimA4[dim] = inputs[10].size(axis[dim]); } } // output dim in n,c,h,w used by backend - int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below - int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below - int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below // use these fixed value for test run - int64_t padA[] = {0, 0}; - int64_t padA1[] = {1, 1}; + int64_t padA[] = {0, 0}; + int64_t padA1[] = {1, 1}; int64_t dilationA[] = {1, 1}; int64_t convstrideA[] = {1, 1}; int64_t convstride1X1[] = {stride_1X1, stride_1X1}; @@ -1151,32 +1048,35 @@ std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, s outdimA1[0] = dimA[0]; outdimA1[1] = filterdimA1[0]; for (int dim = 0; dim < 2; dim++) { - outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); + outdimA1[dim + 2] = + getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); } outdimA2[0] = outdimA1[0]; outdimA2[1] = filterdimA2[0]; for (int dim = 0; dim < 2; dim++) { - outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); + outdimA2[dim + 2] = + getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); } outdimA3[0] = outdimA2[0]; outdimA3[1] = filterdimA3[0]; for (int dim = 0; dim < 2; dim++) { - outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); + outdimA3[dim + 2] = + getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); } // Create output tensor in the correct shape in pytorch's view - int64_t outdim1[] = {0, 0, 0, 0}; - int64_t outdim2[] = {0, 0, 0, 0}; - int64_t outdim3[] = {0, 0, 0, 0}; + int64_t outdim1[] = {0, 0, 0, 0}; + int64_t outdim2[] = {0, 0, 0, 0}; + int64_t outdim3[] = {0, 0, 0, 0}; if (explicit_nhwc) { axis[0] = 0; axis[1] = 2; axis[2] = 3; axis[3] = 1; } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { outdim1[dim] = outdimA1[axis[dim]]; outdim2[dim] = outdimA2[axis[dim]]; outdim3[dim] = outdimA3[axis[dim]]; @@ -1190,19 +1090,8 @@ std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, s auto out1 = at::empty(outdim1, inputs[0].type(), output_format); at::Half* y1 
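The outdimA* spatial sizes computed above via getFwdConvOutputDim follow the standard convolution output-size formula. A self-contained version of that calculation is sketched below; it is assumed to match the usual cuDNN sample helper rather than copied from it.

#include <cstdint>

// out = 1 + (in + 2*pad - effective_filter) / stride, where the effective
// filter extent accounts for dilation.
int64_t conv_output_dim(int64_t in, int64_t pad, int64_t filter, int64_t stride, int64_t dilation) {
  int64_t effective_filter = dilation * (filter - 1) + 1;
  return 1 + (in + 2 * pad - effective_filter) / stride;
}

// Example: a 56x56 input with a 3x3 filter, pad 1, stride 1, dilation 1 stays 56x56.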
= out1.data_ptr(); - run_conv_scale_bias_add_activation(dimA, - padA, - convstride1X1, - dilationA, - filterdimA1, - outdimA1, - CUDNN_DATA_HALF, - x, - w, - y1, - z, - b, - nullptr); + run_conv_scale_bias_add_activation(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, w, + y1, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item()); @@ -1212,19 +1101,8 @@ std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, s auto out2 = at::empty(outdim2, inputs[0].type(), output_format); at::Half* y2 = out2.data_ptr(); - run_conv_scale_bias_add_activation(outdimA1, - padA1, - convstrideA, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr); + run_conv_scale_bias_add_activation(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, + y1, w, y2, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); // create output of conv3 @@ -1235,26 +1113,13 @@ std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, s auto identity = at::empty_like(out3); at::Half* yi = identity.data_ptr(); - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){ - + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { w = inputs[10].data_ptr(); z = inputs[11].data_ptr(); b = inputs[12].data_ptr(); - run_conv_scale_bias(dimA, - padA, - convstride1X1, - dilationA, - filterdimA4, - outdimA3, - CUDNN_DATA_HALF, - x, - w, - yi, - z, - b); + run_conv_scale_bias(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b); DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item()); - } - else { + } else { yi = x; } @@ -1262,19 +1127,8 @@ std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, s z = inputs[6].data_ptr(); b = inputs[9].data_ptr(); - run_conv_scale_bias_add_activation(outdimA2, - padA, - convstrideA, - dilationA, - filterdimA3, - outdimA3, - CUDNN_DATA_HALF, - y2, - w, - y3, - z, - b, - yi); + run_conv_scale_bias_add_activation(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, y2, + w, y3, z, b, yi); DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item()); outputs.push_back(out1); @@ -1285,7 +1139,6 @@ std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, s } std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -1294,40 +1147,40 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, auto output_format = explicit_nhwc ? 
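Both bottleneck_forward above and bottleneck_backward here rely on the same axis[] trick: when the tensors are explicit NHWC, their sizes are read out in n,c,h,w order for the backend (axis = {0,3,1,2}) and the output shapes are written back in NHWC order (axis = {0,2,3,1}). A small illustration of the first remapping, with assumed names and shapes:

#include <array>
#include <cstdint>

// Read an explicit-NHWC shape into the n,c,h,w order the cuDNN backend expects.
std::array<int64_t, 4> nhwc_sizes_to_nchw(const std::array<int64_t, 4>& nhwc) {
  // Matches axis[] = {0, 3, 1, 2}: N stays first, C comes from the last dim.
  return {nhwc[0], nhwc[3], nhwc[1], nhwc[2]};
}

// Example: {32, 56, 56, 64} (N,H,W,C) -> {32, 64, 56, 56} (N,C,H,W).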
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // setup dimensions - int64_t dimA[] = {0, 0, 0, 0}; - int64_t filterdimA1[] = {0, 0, 0, 0}; - int64_t filterdimA2[] = {0, 0, 0, 0}; - int64_t filterdimA3[] = {0, 0, 0, 0}; - int64_t filterdimA4[] = {0, 0, 0, 0}; + int64_t dimA[] = {0, 0, 0, 0}; + int64_t filterdimA1[] = {0, 0, 0, 0}; + int64_t filterdimA2[] = {0, 0, 0, 0}; + int64_t filterdimA3[] = {0, 0, 0, 0}; + int64_t filterdimA4[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w - int axis[] {0,1,2,3}; + int axis[]{0, 1, 2, 3}; if (explicit_nhwc) { axis[0] = 0; axis[1] = 3; axis[2] = 1; axis[3] = 2; } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { dimA[dim] = inputs[0].size(axis[dim]); filterdimA1[dim] = inputs[1].size(axis[dim]); filterdimA2[dim] = inputs[2].size(axis[dim]); filterdimA3[dim] = inputs[3].size(axis[dim]); } if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { filterdimA4[dim] = inputs[14].size(axis[dim]); } } // output dim in n,c,h,w used by backend - int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below - int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below - int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below // use these fixed value for test run - int64_t padA[] = {0, 0}; - int64_t padA1[] = {1, 1}; + int64_t padA[] = {0, 0}; + int64_t padA1[] = {1, 1}; int64_t dilationA[] = {1, 1}; int64_t convstrideA[] = {1, 1}; int64_t convstride1X1[] = {stride_1X1, stride_1X1}; @@ -1336,32 +1189,35 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, outdimA1[0] = dimA[0]; outdimA1[1] = filterdimA1[0]; for (int dim = 0; dim < 2; dim++) { - outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); + outdimA1[dim + 2] = + getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); } outdimA2[0] = outdimA1[0]; outdimA2[1] = filterdimA2[0]; for (int dim = 0; dim < 2; dim++) { - outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); + outdimA2[dim + 2] = + getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); } outdimA3[0] = outdimA2[0]; outdimA3[1] = filterdimA3[0]; for (int dim = 0; dim < 2; dim++) { - outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); + outdimA3[dim + 2] = + getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); } // Create output tensor in the correct shape in pytorch's view - int64_t outdim1[] = {0, 0, 0, 0}; - int64_t outdim2[] = {0, 0, 0, 0}; - int64_t outdim3[] = {0, 0, 0, 0}; + int64_t outdim1[] = {0, 0, 0, 0}; + int64_t outdim2[] = {0, 0, 0, 0}; + int64_t outdim3[] = {0, 0, 0, 0}; if (explicit_nhwc) { axis[0] = 0; axis[1] = 2; axis[2] = 3; axis[3] = 1; } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { outdim1[dim] = outdimA1[axis[dim]]; outdim2[dim] = outdimA2[axis[dim]]; outdim3[dim] = outdimA3[axis[dim]]; @@ -1376,16 +1232,7 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, // wgrad auto wgrad3 = at::empty_like(inputs[3]); at::Half* dw3 = 
wgrad3.data_ptr(); - run_dconv(outdimA2, - padA, - convstrideA, - dilationA, - filterdimA3, - outdimA3, - CUDNN_DATA_HALF, - conv_in, - dw3, - dy3, + run_dconv(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // dgrad @@ -1396,17 +1243,7 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, at::Half* relu2 = inputs[13].data_ptr(); - run_dconv_drelu_dscale(outdimA2, - padA, - convstrideA, - dilationA, - filterdimA3, - outdimA3, - CUDNN_DATA_HALF, - dy2, - w, - dy3, - z, + run_dconv_drelu_dscale(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, dy2, w, dy3, z, relu2); DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item()); @@ -1417,16 +1254,7 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, // wgrad auto wgrad2 = at::empty_like(inputs[2]); at::Half* dw2 = wgrad2.data_ptr(); - run_dconv(outdimA1, - padA1, - convstrideA, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - conv_in, - dw2, - dy2, + run_dconv(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // dgrad @@ -1437,56 +1265,46 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, at::Half* relu1 = inputs[12].data_ptr(); // fused dgrad - run_dconv_drelu_dscale(outdimA1, - padA1, - convstrideA, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1); - -/* - // backward strided conv cannot be fused - // if stride == 1 but channel changes, we can fuse here - if (stride_1X1 != 1){ - // dgrad - run_dconv(outdimA1, - padA1, - convstride1X1, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - - // mul fused mask - grad_out1.mul_(inputs[15]); - } - else { - at::Half* relu1 = inputs[12].data_ptr(); - // fused dgrad - run_dconv_drelu_dscale(outdimA1, - padA1, - convstride1X1, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1); - } -*/ + run_dconv_drelu_dscale(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, + z, relu1); + + /* + // backward strided conv cannot be fused + // if stride == 1 but channel changes, we can fuse here + if (stride_1X1 != 1){ + // dgrad + run_dconv(outdimA1, + padA1, + convstride1X1, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + dy1, + w, + dy2, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + + // mul fused mask + grad_out1.mul_(inputs[15]); + } + else { + at::Half* relu1 = inputs[12].data_ptr(); + // fused dgrad + run_dconv_drelu_dscale(outdimA1, + padA1, + convstride1X1, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + dy1, + w, + dy2, + z, + relu1); + } + */ DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item()); // create grads of conv4 that may exist @@ -1497,20 +1315,11 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, // x used for dconv1 and dconv4 wgrad at::Half* x = inputs[0].data_ptr(); - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){ + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { w = inputs[14].data_ptr(); at::Half* dy_conv4 = inputs[11].data_ptr(); if (requires_grad) { - run_dconv(dimA, - padA, - convstride1X1, - dilationA, - filterdimA4, - outdimA3, - 
CUDNN_DATA_HALF, - dx_conv4, - w, - dy_conv4, + run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item()); @@ -1518,19 +1327,9 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, // wgrad wgrad4 = at::empty_like(inputs[14]); at::Half* dw4 = wgrad4.data_ptr(); - run_dconv(dimA, - padA, - convstride1X1, - dilationA, - filterdimA4, - outdimA3, - CUDNN_DATA_HALF, - x, - dw4, - dy_conv4, + run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - } - else { + } else { // if there is no downsample, dx_conv4 is fork of drelu3 dx_conv4 = inputs[11].data_ptr(); } @@ -1539,16 +1338,7 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, // wgrad auto wgrad1 = at::empty_like(inputs[1]); at::Half* dw1 = wgrad1.data_ptr(); - run_dconv(dimA, - padA, - convstride1X1, - dilationA, - filterdimA1, - outdimA1, - CUDNN_DATA_HALF, - x, - dw1, - dy1, + run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, dw1, dy1, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // dgrad @@ -1558,34 +1348,14 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, // backward strided conv cannot be fused // if stride == 1 but channel changes, we can fuse here - if (requires_grad){ - if (stride_1X1 != 1){ - run_dconv(dimA, - padA, - convstride1X1, - dilationA, - filterdimA1, - outdimA1, - CUDNN_DATA_HALF, - dx, - w, - dy1, + if (requires_grad) { + if (stride_1X1 != 1) { + run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // add 2 together grad_x.add_(grad_x_conv4); - } - else { - run_dconv_add(dimA, - padA, - convstride1X1, - dilationA, - filterdimA1, - outdimA1, - CUDNN_DATA_HALF, - dx, - w, - dy1, - dx_conv4); + } else { + run_dconv_add(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1, dx_conv4); } } @@ -1609,1303 +1379,1193 @@ std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, namespace { enum { - X_TENSOR, - Y_TENSOR, - W_TENSOR, - Z_TENSOR, - B_TENSOR, - AFTERADD_TENSOR, - AFTERBIAS_TENSOR, - AFTERCONV_TENSOR, - OPTIONAL, - AFTEROPT_TENSOR, - AFTERACT_TENSOR, - GEN_INDEX_TENSOR, - MASK_TOP_TENSOR, - MASK_BOTTOM_TENSOR, - MASK_TENSOR, - THRESHOLD_TOP_TENSOR, - THRESHOLD_BOTTOM_TENSOR, + X_TENSOR, + Y_TENSOR, + W_TENSOR, + Z_TENSOR, + B_TENSOR, + AFTERADD_TENSOR, + AFTERBIAS_TENSOR, + AFTERCONV_TENSOR, + OPTIONAL, + AFTEROPT_TENSOR, + AFTERACT_TENSOR, + GEN_INDEX_TENSOR, + MASK_TOP_TENSOR, + MASK_BOTTOM_TENSOR, + MASK_TENSOR, + THRESHOLD_TOP_TENSOR, + THRESHOLD_BOTTOM_TENSOR, }; -using masked_convbias_descriptors = std::tuple; - -masked_convbias_descriptors -create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - int64_t* threshold_dim, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = y_dim_padded[1]; - b_dim_padded[2] = 1; - 
b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - int64_t threshold_stride[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC); - - return masked_convbias_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('z') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setVirtual() - .setId('A') // after add - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setVirtual() - .setId('B') // after bias - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('C') // after conv - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('i') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('D') // after optional add - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('E') // after act for masked - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('I') // output of the gen index operation - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_INT32) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('m') // top half of the mask created after the less than - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('n') // bottom half of the mask - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('M') // OR of the top and bottom masks - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - 
cudnn_frontend::TensorBuilder() - .setDim(4, threshold_dim) - .setStrides(4, threshold_stride) - .setId('t') // threshold for creating the top mask - .setAlignment(16) - .setDataType(CUDNN_DATA_INT32) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, threshold_dim) - .setStrides(4, threshold_stride) - .setId('u') // threshold for creating the bottom mask - .setAlignment(16) - .setDataType(CUDNN_DATA_INT32) - .build()); +using masked_convbias_descriptors = + std::tuple; + +masked_convbias_descriptors create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded, int64_t* padA, + int64_t* convstrideA, int64_t* dilationA, + int64_t* w_dim_padded, int64_t* y_dim_padded, + int64_t* threshold_dim, + cudnnDataType_t dataType) { + const int convDim = 2; + + int64_t b_dim_padded[4]; + b_dim_padded[0] = 1; + b_dim_padded[1] = y_dim_padded[1]; + b_dim_padded[2] = 1; + b_dim_padded[3] = 1; + + int64_t x_stride_padded[4]; + int64_t y_stride_padded[4]; + int64_t w_stride_padded[4]; + int64_t b_stride_padded[4]; + int64_t threshold_stride[4]; + + generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC); + + return masked_convbias_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, w_stride_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('z') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('b') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setVirtual() + .setId('A') // after add + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setVirtual() + .setId('B') // after bias + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('C') // after conv + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('i') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('D') // after optional add + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('E') // after act for masked + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + 
.setStrides(4, y_stride_padded) + .setId('I') // output of the gen index operation + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_INT32) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('m') // top half of the mask created after the less than + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('n') // bottom half of the mask + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('M') // OR of the top and bottom masks + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, threshold_dim) + .setStrides(4, threshold_stride) + .setId('t') // threshold for creating the top mask + .setAlignment(16) + .setDataType(CUDNN_DATA_INT32) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, threshold_dim) + .setStrides(4, threshold_stride) + .setId('u') // threshold for creating the bottom mask + .setAlignment(16) + .setDataType(CUDNN_DATA_INT32) + .build()); } // tensor descriptors used for dgrad enum { - X_OR_DX_TENSOR, - DY_TENSOR, - W_OR_DW_TENSOR, - SCALE_TENSOR, - RELU_TENSOR, - AFTER_DCONV_TENSOR, - AFTER_DRELU_TENSOR, - DGRAD_INPUT_TENSOR, - DGRAD_OPTIONAL_TENSOR, - DGRAD_GEN_INDEX_TENSOR, - DGRAD_MASK_TOP_TENSOR, - DGRAD_MASK_BOTTOM_TENSOR, - DGRAD_MASK_TENSOR, - DGRAD_THRESHOLD_TOP_TENSOR, - DGRAD_THRESHOLD_BOTTOM_TENSOR, + X_OR_DX_TENSOR, + DY_TENSOR, + W_OR_DW_TENSOR, + SCALE_TENSOR, + RELU_TENSOR, + AFTER_DCONV_TENSOR, + AFTER_DRELU_TENSOR, + DGRAD_INPUT_TENSOR, + DGRAD_OPTIONAL_TENSOR, + DGRAD_GEN_INDEX_TENSOR, + DGRAD_MASK_TOP_TENSOR, + DGRAD_MASK_BOTTOM_TENSOR, + DGRAD_MASK_TENSOR, + DGRAD_THRESHOLD_TOP_TENSOR, + DGRAD_THRESHOLD_BOTTOM_TENSOR, }; -using dconv_add_descriptors = std::tuple; - -dconv_add_descriptors -create_dconv_add_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = x_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - - return dconv_add_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('s') - .setAlignment(16) - .setDataType(dataType) - .build(), - 
cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('A') // after dconv - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('B') // after drelu - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('i') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('D') // after optional add - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build()); +using dconv_add_descriptors = std::tuple; + +dconv_add_descriptors create_dconv_add_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, + int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded, + cudnnDataType_t dataType) { + const int convDim = 2; + + int64_t b_dim_padded[4]; + b_dim_padded[0] = 1; + b_dim_padded[1] = x_dim_padded[1]; + b_dim_padded[2] = 1; + b_dim_padded[3] = 1; + + int64_t x_stride_padded[4]; + int64_t y_stride_padded[4]; + int64_t w_stride_padded[4]; + int64_t b_stride_padded[4]; + + generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); + + return dconv_add_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, w_stride_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('s') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('r') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('A') // after dconv + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('B') // after drelu + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('i') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('D') // after optional add + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_FLOAT) + .build()); } -using dconv_mask_descriptors = std::tuple; - -dconv_mask_descriptors -create_dconv_mask_descriptors(int64_t* 
x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - int64_t* threshold_dim, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = x_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - int64_t threshold_stride[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC); - - return dconv_mask_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('s') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('A') // after dconv - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('B') // after drelu - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('i') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('D') // after optional add - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('I') // output of the gen index operation - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_INT32) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('m') // top half of the mask created after the less than - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('n') // bottom half of the mask - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('M') // OR of the top and bottom masks - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, threshold_dim) - .setStrides(4, threshold_stride) - .setId('t') // threshold for 
creating the top mask - .setAlignment(16) - .setDataType(CUDNN_DATA_INT32) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, threshold_dim) - .setStrides(4, threshold_stride) - .setId('u') // threshold for creating the bottom mask - .setAlignment(16) - .setDataType(CUDNN_DATA_INT32) - .build()); +using dconv_mask_descriptors = + std::tuple; + +dconv_mask_descriptors create_dconv_mask_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, + int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded, + int64_t* threshold_dim, cudnnDataType_t dataType) { + const int convDim = 2; + + int64_t b_dim_padded[4]; + b_dim_padded[0] = 1; + b_dim_padded[1] = x_dim_padded[1]; + b_dim_padded[2] = 1; + b_dim_padded[3] = 1; + + int64_t x_stride_padded[4]; + int64_t y_stride_padded[4]; + int64_t w_stride_padded[4]; + int64_t b_stride_padded[4]; + int64_t threshold_stride[4]; + + generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC); + + return dconv_mask_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, w_stride_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('s') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('r') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('A') // after dconv + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('B') // after drelu + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('i') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('D') // after optional add + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('I') // output of the gen index operation + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_INT32) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('m') // top half of the mask created after the less than + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('n') // bottom half of the mask + .setAlignment(16) + 
.setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('M') // OR of the top and bottom masks + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, threshold_dim) + .setStrides(4, threshold_stride) + .setId('t') // threshold for creating the top mask + .setAlignment(16) + .setDataType(CUDNN_DATA_INT32) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, threshold_dim) + .setStrides(4, threshold_stride) + .setId('u') // threshold for creating the bottom mask + .setAlignment(16) + .setDataType(CUDNN_DATA_INT32) + .build()); } -void -run_conv_add_scale_bias_activation(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrB, - at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the add operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) +void run_conv_add_scale_bias_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, + int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, + at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, + at::Half* devPtrB, at::Half* devPtrI) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation, + w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the add operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // 
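run_conv_add_scale_bias_activation assembles the graph conv -> residual add -> per-channel scale -> per-channel bias -> ReLU out of the pointwise descriptors defined here. As a plain eager-mode cross-check of that ordering (the names and the use of at::conv2d are assumptions, not part of this file):

#include <ATen/ATen.h>
#include <cstdint>

// y = relu((conv(x, w) + residual) * scale + bias), matching the node order
// conv_op -> add_op -> scale_op -> bias_op -> act_op in the fused graph.
at::Tensor conv_add_scale_bias_relu_reference(const at::Tensor& x, const at::Tensor& w,
                                              const at::Tensor& residual,
                                              const at::Tensor& scale,  // shape {1, C, 1, 1}
                                              const at::Tensor& bias,   // shape {1, C, 1, 1}
                                              int64_t stride, int64_t pad, int64_t dilation) {
  auto y = at::conv2d(x, w, /*bias=*/{}, stride, pad, dilation);
  y = (y + residual) * scale + bias;
  return at::relu(y);
}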
Define the bias operation + auto biasDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // optional add + auto addDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // create an add node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // Create a Add Node with scaling parameters. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(add_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create an Activation Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setyDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create an Operation Graph. 
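Every one of these run_* helpers funnels through the same caching step: a string is built from the convolution parameters plus the operation-graph tag, and getOrCreatePlan returns a cached execution plan for that key. That helper is defined elsewhere in this translation unit; purely as an illustration of the idea, a string-keyed cache might be structured like the hypothetical sketch below.

#include <mutex>
#include <string>
#include <unordered_map>

// Build-once, reuse-afterwards cache keyed by the textual encoding of the
// convolution problem. Plan is whatever plan type the caller stores, and
// build_plan is only invoked on a cache miss.
template <typename Plan, typename BuildFn>
Plan& get_or_create_cached(std::unordered_map<std::string, Plan>& cache, std::mutex& mtx,
                           const std::string& key, BuildFn build_plan) {
  std::lock_guard<std::mutex> guard(mtx);
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache.emplace(key, build_plan()).first;
  }
  return it->second;
}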
In this case it is convolution add bias activation + std::array ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; + int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(6, data_ptrs) + .setUids(6, uids) .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) +void run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, + int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, + int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX, + at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, + at::Half* devPtrI, int* devPtrT, int* devPtrU, int axis) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // 
Define the add operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the bias operation + auto biasDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // optional add + auto addDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the genIndex descriptor + auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) .setMathPrecision(CUDNN_DATA_FLOAT) + .setAxis(axis) .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); - // optional add - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) + // Define the lessThan descriptor + auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_CMP_LT) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // create an add node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - // Create a Add Node with scaling parameters. 
- auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(add_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) + // Define the greaterThan descriptor + auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_CMP_GT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); + + // Define the logical_or descriptor + auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_LOGICAL_OR) + .setMathPrecision(CUDNN_DATA_BOOLEAN) + .build(); + DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); + + // Define the binary_selection descriptor + auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_BINARY_SELECT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Add Node with scaling parameters. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create a optional add Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // Create an Activation Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) + .setyDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create a Gen_Index Node. + auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(genIndexDesc) .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(biasDesc) + DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); + + // Create a LessThan Node. 
+ auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(lessThanDesc) .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); + + // Create a GreaterThan Node. + auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(greaterThanDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); + + // Create a LogicalOr Node. + auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(logicalOrDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); + + // Create a Binary_Selection Node. + auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .settDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(selectionDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); - // Create an Activation Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setyDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + // Create an Operation Graph. In this case it is convolution add bias activation + if (devPtrI) { + std::array ops = { + &conv_op, &scale_op, &bias_op, &add_op, &act_op, + &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; - // Create an Operation Graph. 
In this case it is convolution add bias activation - std::array ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op}; + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU}; + int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(8, data_ptrs) + .setUids(8, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } else { + std::array ops = {&conv_op, &scale_op, &bias_op, + &act_op, &genIndex_op, &lessThan_op, + &greaterThan_op, &logicalOr_op, &selection_op}; - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(6, data_ptrs) - .setUids(6, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + 
getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU}; + int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(7, data_ptrs) + .setUids(7, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } -void -run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - int64_t* threshold_dim, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrB, - at::Half* devPtrI, - int* devPtrT, - int* devPtrU, - int axis) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the add operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) +void run_dconv_add_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, + int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, + at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, + at::Half* devPtrR, at::Half* 
devPtrI) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_add_descriptors tensors = + create_dconv_add_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // optional add + auto addDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the scale backward operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create add Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // TODO: do we need getOutputTensor(), and what it returns in backward case? + // Create an relu backward Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(std::get(tensors)) + .setxDesc(std::get(tensors)) + .setdxDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create a Scale Node. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create an Operation Graph. 
In this case it is convolution add bias activation + std::array ops = {&conv_op, &add_op, &act_op, &scale_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI}; + int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(6, data_ptrs) + .setUids(6, uids) .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) +void run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, + int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim, + cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, + at::Half* devPtrZ, at::Half* devPtrR, int* devPtrT, int* devPtrU, int axis) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_mask_descriptors tensors = create_dconv_mask_descriptors(x_dim_padded, pad, convstride, dilation, + w_dim_padded, y_dim_padded, threshold_dim, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + 
.setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the scale backward operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the genIndex descriptor + auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) .setMathPrecision(CUDNN_DATA_FLOAT) + .setAxis(axis) .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); - // optional add - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) + // Define the lessThan descriptor + auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_CMP_LT) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the genIndex descriptor - auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setMathPrecision(CUDNN_DATA_FLOAT) - .setAxis(axis) - .build(); - DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); - - // Define the lessThan descriptor - auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_CMP_LT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); + DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); - // Define the greaterThan descriptor - auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_CMP_GT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); - - // Define the logical_or descriptor - auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_LOGICAL_OR) - .setMathPrecision(CUDNN_DATA_BOOLEAN) - .build(); - DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); - - // Define the binary_selection descriptor - auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_BINARY_SELECT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - 
DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Add Node with scaling parameters. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(biasDesc) + // Define the greaterThan descriptor + auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_CMP_GT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); + + // Define the logical_or descriptor + auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_LOGICAL_OR) + .setMathPrecision(CUDNN_DATA_BOOLEAN) + .build(); + DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); + + // Define the binary_selection descriptor + auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_BINARY_SELECT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // TODO: do we need getOutputTensor(), and what it returns in backward case? + // Create an relu backward Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(std::get(tensors)) + .setxDesc(std::get(tensors)) + .setdxDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create a Scale Node. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Gen_Index Node. + auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(genIndexDesc) .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create a optional add Node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) + DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); + + // Create a LessThan Node. + auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(lessThanDesc) .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - - // Create an Activation Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(devPtrI ? 
add_op.getOutputTensor() : bias_op.getOutputTensor()) - .setyDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create a Gen_Index Node. - auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(genIndexDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); - - // Create a LessThan Node. - auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(lessThanDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); - - // Create a GreaterThan Node. - auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(greaterThanDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); - - // Create a LogicalOr Node. - auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(logicalOrDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); - - // Create a Binary_Selection Node. - auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .settDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(selectionDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); - - // Create an Operation Graph. 
In this case it is convolution add bias activation - if (devPtrI) { - - std::array ops = {&conv_op, &scale_op, &bias_op, &add_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(8, data_ptrs) - .setUids(8, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } else { - - std::array ops = {&conv_op, &scale_op, &bias_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(7, data_ptrs) - .setUids(7, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - -void -run_dconv_add_drelu_dscale(int64_t* x_dim_padded, - 
int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrR, - at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_add_descriptors tensors = create_dconv_add_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) + DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); + + // Create a GreaterThan Node. + auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(greaterThanDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); + + // Create a LogicalOr Node. + auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(logicalOrDesc) .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // optional add - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) + DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); + + // Create a Binary_Selection Node. + auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .settDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(selectionDesc) .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the scale backward operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create add Node. 
- auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - // TODO: do we need getOutputTensor(), and what it returns in backward case? - // Create an relu backward Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(std::get(tensors)) - .setxDesc(std::get(tensors)) - .setdxDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create a Scale Node. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &add_op, &act_op, &scale_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI}; - int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(6, data_ptrs) - .setUids(6, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} + DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); -void -run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - int64_t* threshold_dim, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrR, - int* devPtrT, - int* devPtrU, - int axis) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_mask_descriptors tensors = create_dconv_mask_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, 
std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the scale backward operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the genIndex descriptor - auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setMathPrecision(CUDNN_DATA_FLOAT) - .setAxis(axis) - .build(); - DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); + // Create an Operation Graph. 
In this case it is convolution add bias activation + std::array ops = {&conv_op, &act_op, &scale_op, &genIndex_op, + &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; - // Define the lessThan descriptor - auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_CMP_LT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); - // Define the greaterThan descriptor - auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_CMP_GT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - // Define the logical_or descriptor - auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_LOGICAL_OR) - .setMathPrecision(CUDNN_DATA_BOOLEAN) - .build(); - DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); - - // Define the binary_selection descriptor - auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_BINARY_SELECT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - float alpha = 1.0f; - float beta = 0.0f; + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // TODO: do we need getOutputTensor(), and what it returns in backward case? - // Create an relu backward Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(std::get(tensors)) - .setxDesc(std::get(tensors)) - .setdxDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create a Scale Node. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Gen_Index Node. - auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(genIndexDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); - - // Create a LessThan Node. - auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(lessThanDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); - - // Create a GreaterThan Node. 
- auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(greaterThanDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); - - // Create a LogicalOr Node. - auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(logicalOrDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); - - // Create a Binary_Selection Node. - auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .settDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(selectionDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU}; - int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(7, data_ptrs) - .setUids(7, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU}; + int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(7, data_ptrs) + .setUids(7, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + 
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } struct bottleneck_forward_status { - int64_t dimA[4]; int64_t filterdimA1[4]; int64_t filterdimA2[4]; @@ -2919,7 +2579,7 @@ struct bottleneck_forward_status { int64_t outdimA0[4]; int64_t outdimA1[4]; - int64_t outdimA1b[4]; // out1_pad + int64_t outdimA1b[4]; // out1_pad int64_t outdimA2[4]; int64_t outdimA3[4]; int64_t outdimA4[4]; @@ -2931,12 +2591,12 @@ struct bottleneck_forward_status { int64_t convstrideA[2]; int64_t convstride1X1[2]; - int64_t outdim0[4]; // halo input shape + int64_t outdim0[4]; // halo input shape int64_t outdim1[4]; int64_t outdim1b[4]; int64_t outdim2[4]; int64_t outdim3[4]; - int64_t outdim4[4]; // halo output shape + int64_t outdim4[4]; // halo output shape void init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0; @@ -2960,22 +2620,22 @@ struct bottleneck_forward_status { axis[3] = 3; } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { dimA[dim] = inputs[0].size(axis[dim]); filterdimA1[dim] = inputs[1].size(axis[dim]); filterdimA2[dim] = inputs[2].size(axis[dim]); filterdimA3[dim] = inputs[3].size(axis[dim]); } if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - for (int dim=0;dim<4;dim++) { - filterdimA4[dim] = inputs[10].size(axis[dim]); + for (int dim = 0; dim < 4; dim++) { + filterdimA4[dim] = inputs[10].size(axis[dim]); } } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { if (dim == 2) { - filterdimA2hh[dim] = 1; + filterdimA2hh[dim] = 1; } else { - filterdimA2hh[dim] = filterdimA2[dim]; + filterdimA2hh[dim] = filterdimA2[dim]; } } @@ -2988,18 +2648,25 @@ struct bottleneck_forward_status { outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0; // use these fixed value for test run - padA[0] = 0; padA[1] = 0; - padA1[0] = 1; padA1[1] = 1; - padA2[0] = 0; padA2[1] = 1; - dilationA[0] = 1; dilationA[1] = 1; - convstrideA[0] = 1; convstrideA[1] = 1; - convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1; + padA[0] = 0; + padA[1] = 0; + padA1[0] = 1; + padA1[1] = 1; + padA2[0] = 0; + padA2[1] = 1; + dilationA[0] = 1; + dilationA[1] = 1; + convstrideA[0] = 1; + convstrideA[1] = 1; + convstride1X1[0] = stride_1X1; + convstride1X1[1] = stride_1X1; // compute output from pad/stride/dilation outdimA1[0] = dimA[0]; outdimA1[1] = filterdimA1[0]; for (int dim = 0; dim < 2; dim++) { - outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); + outdimA1[dim + 2] = + getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { @@ -3012,23 +2679,25 @@ struct bottleneck_forward_status { outdimA2[0] = outdimA1[0]; outdimA2[1] = filterdimA2[0]; for (int dim = 0; dim < 2; dim++) { - outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); + outdimA2[dim + 2] = + getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { - outdimA0[dim] = 3; - outdimA4[dim] = 1; + outdimA0[dim] = 3; + outdimA4[dim] = 1; } else { outdimA0[dim] = outdimA1[dim]; - outdimA4[dim] = outdimA2[dim]; + outdimA4[dim] = 
outdimA2[dim]; } } outdimA3[0] = outdimA2[0]; outdimA3[1] = filterdimA3[0]; for (int dim = 0; dim < 2; dim++) { - outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); + outdimA3[dim + 2] = + getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); } // Create output tensor in the correct shape in pytorch's view @@ -3042,7 +2711,7 @@ struct bottleneck_forward_status { axis[2] = 3; axis[3] = 1; } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { outdim0[dim] = outdimA0[axis[dim]]; outdim1[dim] = outdimA1[axis[dim]]; outdim1b[dim] = outdimA1b[axis[dim]]; @@ -3055,7 +2724,7 @@ struct bottleneck_forward_status { bottleneck_forward_status forward_state; -} // end of anonymous namespace +} // end of anonymous namespace std::vector bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { // NB! Bottleneck_forward and bottleneck_backward are NOT thread safe method. @@ -3066,7 +2735,8 @@ std::vector bottleneck_forward_init(bool explicit_nhwc, int stride_1 std::vector outputs; auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - //printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]); + // printf("outdim1 = + // (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]); auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format); auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format); auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format); @@ -3079,8 +2749,8 @@ std::vector bottleneck_forward_init(bool explicit_nhwc, int stride_1 } // inputs contains x,w,z,b,(i) -void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - +void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs) { std::cout << std::fixed; // run @@ -3091,19 +2761,9 @@ void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector(); - run_conv_scale_bias_add_activation(forward_state.dimA, - forward_state.padA, - forward_state.convstride1X1, - forward_state.dilationA, - forward_state.filterdimA1, - forward_state.outdimA1, - CUDNN_DATA_HALF, - x, - w, - y1, - z, - b, - nullptr); + run_conv_scale_bias_add_activation(forward_state.dimA, forward_state.padA, forward_state.convstride1X1, + forward_state.dilationA, forward_state.filterdimA1, forward_state.outdimA1, + CUDNN_DATA_HALF, x, w, y1, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item()); } @@ -3111,40 +2771,30 @@ void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // run at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[5].data_ptr(); at::Half* b = inputs[8].data_ptr(); - + at::Half* y1 = fat_halo_y1.data_ptr(); auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format); at::Half* y2 = halo_y2.data_ptr(); - run_conv_scale_bias_add_activation(forward_state.outdimA0, - forward_state.padA2, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2, - forward_state.outdimA4, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr); + run_conv_scale_bias_add_activation(forward_state.outdimA0, forward_state.padA2, forward_state.convstrideA, + forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA4, + CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr); return halo_y2; } // compute halo correction term (top or bottom) from slim halo input (N,C,1,W). // slim halo input is 1 pixel wide in H. -at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector inputs, at::Tensor w1by3, at::Tensor out2_part_halo) { - +at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, + std::vector inputs, at::Tensor w1by3, + at::Tensor out2_part_halo) { auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // run @@ -3159,25 +2809,15 @@ at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format); at::Half* y2 = halo_y2.data_ptr(); - run_conv_add_scale_bias_activation(forward_state.outdimA4, - forward_state.padA2, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2hh, - forward_state.outdimA4, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - prev_out2); + run_conv_add_scale_bias_activation(forward_state.outdimA4, forward_state.padA2, forward_state.convstrideA, + forward_state.dilationA, forward_state.filterdimA2hh, forward_state.outdimA4, + CUDNN_DATA_HALF, y1, w, y2, z, b, prev_out2); return halo_y2; } -void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - +void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs) { std::cout << std::fixed; // from _out1 method @@ -3192,30 +2832,24 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector(); - //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); - //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); - //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); - //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); - //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); - //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); - run_conv_scale_bias_add_activation(forward_state.outdimA1, - forward_state.padA1, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2, - forward_state.outdimA2, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr); 
+ // printf("forward_state.outdimA1 = + // {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); + // printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); + // printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); + // printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); + // printf("forward_state.filterdimA2 = + // {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); + // printf("forward_state.outdimA2 = + // {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); + run_conv_scale_bias_add_activation(forward_state.outdimA1, forward_state.padA1, forward_state.convstrideA, + forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2, + CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); } -void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) { - +void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor thresholdTop, + at::Tensor thresholdBottom) { std::cout << std::fixed; // from _out1 method @@ -3230,34 +2864,25 @@ void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vecto auto out2 = outputs[1]; at::Half* y2 = out2.data_ptr(); - //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); - //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); - //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); - //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); - //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); - //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); - run_conv_scale_bias_add_activation_mask(forward_state.outdimA1, - forward_state.padA1, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2, - forward_state.outdimA2, - forward_state.threshdim, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr, - thresholdTop.data_ptr(), - thresholdBottom.data_ptr(), - 2); // axis == 1 -> Does this assume explicit NHWC? 
+ // printf("forward_state.outdimA1 = + // {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); + // printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); + // printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); + // printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); + // printf("forward_state.filterdimA2 = + // {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); + // printf("forward_state.outdimA2 = + // {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); + run_conv_scale_bias_add_activation_mask(forward_state.outdimA1, forward_state.padA1, forward_state.convstrideA, + forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2, + forward_state.threshdim, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr, + thresholdTop.data_ptr(), thresholdBottom.data_ptr(), + 2); // axis == 1 -> Does this assume explicit NHWC? DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); } -void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor out1_pad) { - +void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor out1_pad) { std::cout << std::fixed; // from _out1 method @@ -3272,30 +2897,23 @@ void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector auto out2 = outputs[1]; at::Half* y2 = out2.data_ptr(); - //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); - //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); - //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); - //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); - //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); - //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); - run_conv_scale_bias_add_activation(forward_state.outdimA1b, - forward_state.padA2, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2, - forward_state.outdimA2, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr); + // printf("forward_state.outdimA1 = + // {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); + // printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); + // printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); + // printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); + // printf("forward_state.filterdimA2 = + // {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); + // printf("forward_state.outdimA2 = + // 
{%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); + run_conv_scale_bias_add_activation(forward_state.outdimA1b, forward_state.padA2, forward_state.convstrideA, + forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2, + CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); } -void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - +void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs) { std::cout << std::fixed; // from _out1 method @@ -3311,26 +2929,14 @@ void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector(); z = inputs[11].data_ptr(); b = inputs[12].data_ptr(); - run_conv_scale_bias(forward_state.dimA, - forward_state.padA, - forward_state.convstride1X1, - forward_state.dilationA, - forward_state.filterdimA4, - forward_state.outdimA3, - CUDNN_DATA_HALF, - x, - w, - yi, - z, - b); + run_conv_scale_bias(forward_state.dimA, forward_state.padA, forward_state.convstride1X1, forward_state.dilationA, + forward_state.filterdimA4, forward_state.outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b); DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item()); - } - else { + } else { yi = x; } @@ -3341,44 +2947,33 @@ void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector(); b = inputs[9].data_ptr(); - run_conv_scale_bias_add_activation(forward_state.outdimA2, - forward_state.padA, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA3, - forward_state.outdimA3, - CUDNN_DATA_HALF, - y2, - w, - y3, - z, - b, - yi); + run_conv_scale_bias_add_activation(forward_state.outdimA2, forward_state.padA, forward_state.convstrideA, + forward_state.dilationA, forward_state.filterdimA3, forward_state.outdimA3, + CUDNN_DATA_HALF, y2, w, y3, z, b, yi); DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item()); } namespace { struct bottleneck_backward_state { - int64_t dimA[4]; int64_t filterdimA1[4]; int64_t filterdimA2[4]; int64_t filterdimA3[4]; int64_t filterdimA4[4]; - int64_t filterdimA2hh[4]; // Cin,Cout,1,3 + int64_t filterdimA2hh[4]; // Cin,Cout,1,3 int64_t threshdim[4]; int axis[4]; - int64_t outdimA1[4]; // grad_out1 - int64_t outdimA1b[4]; // out1_pad - int64_t outdimA2[4]; // grad_out2 + int64_t outdimA1[4]; // grad_out1 + int64_t outdimA1b[4]; // out1_pad + int64_t outdimA2[4]; // grad_out2 int64_t outdimA3[4]; - int64_t outdimA1h[4]; // output: grad_out1 halo (H=3) - int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3) - int64_t outdimA1hh[4]; // input: grad_out2 halo (H=1) - int64_t outdimA2hh[4]; // input: out1 halo (H=1) + int64_t outdimA1h[4]; // output: grad_out1 halo (H=3) + int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3) + int64_t outdimA1hh[4]; // input: grad_out2 halo (H=1) + int64_t outdimA2hh[4]; // input: out1 halo (H=1) int64_t padA[2]; int64_t padA1[2]; @@ -3387,7 +2982,7 @@ struct bottleneck_backward_state { int64_t convstrideA[2]; int64_t convstride1X1[2]; - int64_t filterdim2hh[4]; // Cin,1,3,Cout + int64_t filterdim2hh[4]; // Cin,1,3,Cout int64_t outdim1[4]; int64_t outdim1b[4]; @@ -3419,19 +3014,19 @@ struct bottleneck_backward_state { axis[3] = 3; } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { dimA[dim] = inputs[0].size(axis[dim]); filterdimA1[dim] = inputs[1].size(axis[dim]); 
filterdimA2[dim] = inputs[2].size(axis[dim]); filterdimA3[dim] = inputs[3].size(axis[dim]); } if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { filterdimA4[dim] = inputs[14].size(axis[dim]); } } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { if (dim == 2) { filterdimA2hh[dim] = 1; } else { @@ -3450,48 +3045,57 @@ struct bottleneck_backward_state { outdimA2hh[0] = outdimA2hh[1] = outdimA2hh[2] = outdimA2hh[3] = 0; // use these fixed value for test run - padA[0] = 0; padA[1] = 0; - padA1[0] = 1; padA1[1] = 1; - padA2[0] = 0; padA2[1] = 1; - dilationA[0] = 1; dilationA[1] = 1; - convstrideA[0] = 1; convstrideA[1] = 1; - convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1; + padA[0] = 0; + padA[1] = 0; + padA1[0] = 1; + padA1[1] = 1; + padA2[0] = 0; + padA2[1] = 1; + dilationA[0] = 1; + dilationA[1] = 1; + convstrideA[0] = 1; + convstrideA[1] = 1; + convstride1X1[0] = stride_1X1; + convstride1X1[1] = stride_1X1; // compute output from pad/stride/dilation outdimA1[0] = dimA[0]; outdimA1[1] = filterdimA1[0]; for (int dim = 0; dim < 2; dim++) { - outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); + outdimA1[dim + 2] = + getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { - outdimA1b[dim] = outdimA1[dim] + 2; + outdimA1b[dim] = outdimA1[dim] + 2; } else { - outdimA1b[dim] = outdimA1[dim]; + outdimA1b[dim] = outdimA1[dim]; } } outdimA2[0] = outdimA1[0]; outdimA2[1] = filterdimA2[0]; for (int dim = 0; dim < 2; dim++) { - outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); + outdimA2[dim + 2] = + getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); } outdimA3[0] = outdimA2[0]; outdimA3[1] = filterdimA3[0]; for (int dim = 0; dim < 2; dim++) { - outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); + outdimA3[dim + 2] = + getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { - outdimA1h[dim] = 3; - outdimA2h[dim] = 3; + outdimA1h[dim] = 3; + outdimA2h[dim] = 3; outdimA1hh[dim] = 1; outdimA2hh[dim] = 1; } else { - outdimA1h[dim] = outdimA1[dim]; - outdimA2h[dim] = outdimA2[dim]; + outdimA1h[dim] = outdimA1[dim]; + outdimA2h[dim] = outdimA2[dim]; outdimA1hh[dim] = outdimA1[dim]; outdimA2hh[dim] = outdimA2[dim]; } @@ -3511,7 +3115,7 @@ struct bottleneck_backward_state { axis[2] = 3; axis[3] = 1; } - for (int dim=0;dim<4;dim++) { + for (int dim = 0; dim < 4; dim++) { outdim1[dim] = outdimA1[axis[dim]]; outdim1b[dim] = outdimA1b[axis[dim]]; outdim2[dim] = outdimA2[axis[dim]]; @@ -3525,10 +3129,9 @@ struct bottleneck_backward_state { bottleneck_backward_state backward_state; -} +} // namespace std::vector bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - std::cout << std::fixed; backward_state.init(explicit_nhwc, stride_1X1, inputs); @@ -3554,8 +3157,8 @@ std::vector bottleneck_backward_init(bool explicit_nhwc, int stride_ return outputs; } -void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - +void bottleneck_backward_wgrad3(bool 
explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs) { // dconv3+drelu2+dscale2 at::Half* conv_in = inputs[13].data_ptr(); at::Half* dy3 = inputs[10].data_ptr(); @@ -3563,23 +3166,14 @@ void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector< // wgrad auto wgrad3 = outputs[3]; at::Half* dw3 = wgrad3.data_ptr(); - run_dconv(backward_state.outdimA2, - backward_state.padA, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA3, - backward_state.outdimA3, - CUDNN_DATA_HALF, - conv_in, - dw3, - dy3, + run_dconv(backward_state.outdimA2, backward_state.padA, backward_state.convstrideA, backward_state.dilationA, + backward_state.filterdimA3, backward_state.outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item()); - } -at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - +at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -3599,18 +3193,9 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std at::Half* relu2 = inputs[13].data_ptr(); - run_dconv_drelu_dscale(backward_state.outdimA2, - backward_state.padA, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA3, - backward_state.outdimA3, - CUDNN_DATA_HALF, - dy2, - w, - dy3, - z, - relu2); + run_dconv_drelu_dscale(backward_state.outdimA2, backward_state.padA, backward_state.convstrideA, + backward_state.dilationA, backward_state.filterdimA3, backward_state.outdimA3, CUDNN_DATA_HALF, + dy2, w, dy3, z, relu2); // do halo exchange of dy2 here @@ -3619,8 +3204,8 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std return grad_out2; } -at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2) { - +at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -3636,28 +3221,21 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std at::Half* z = inputs[4].data_ptr(); at::Half* relu1 = inputs[12].data_ptr(); - //printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3)); + // printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3)); // fused dgrad - //printf("backward_state.outdim1 = {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]); - run_dconv_drelu_dscale(backward_state.outdimA1, - backward_state.padA1, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, - backward_state.outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1); + // printf("backward_state.outdim1 = + // {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]); + run_dconv_drelu_dscale(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, + backward_state.dilationA, backward_state.filterdimA2, 
backward_state.outdimA2, CUDNN_DATA_HALF, + dy1, w, dy2, z, relu1); return grad_out1; } -at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) { - +at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2, + at::Tensor thresholdTop, at::Tensor thresholdBottom) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -3673,32 +3251,23 @@ at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1 at::Half* z = inputs[4].data_ptr(); at::Half* relu1 = inputs[12].data_ptr(); - //printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3)); + // printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3)); // fused dgrad - run_dconv_drelu_dscale_mask(backward_state.outdimA1, - backward_state.padA1, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, - backward_state.outdimA2, - backward_state.threshdim, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1, - thresholdTop.data_ptr(), - thresholdBottom.data_ptr(), - 2); + run_dconv_drelu_dscale_mask(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, + backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2, + backward_state.threshdim, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1, + thresholdTop.data_ptr(), thresholdBottom.data_ptr(), 2); return grad_out1; } -// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) to produce output of shape [N,1,W,C] -at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector inputs, at::Tensor w1by3, std::vector outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) { - +// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) +// to produce output of shape [N,1,W,C] +at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector inputs, + at::Tensor w1by3, std::vector outputs, + at::Tensor grad_out2_halo, at::Tensor relu1_halo, + at::Tensor part_grad_out1) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -3710,37 +3279,34 @@ at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int strid // dgrad auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format); at::Half* dy1h = grad_out1_halo.data_ptr(); - //at::Half* w = inputs[2].data_ptr(); // use w1by3 instead, which is a sliced version of inputs[2] + // at::Half* w = inputs[2].data_ptr(); // use w1by3 instead, which is a sliced version of inputs[2] at::Half* w = w1by3.data_ptr(); at::Half* z = inputs[4].data_ptr(); at::Half* relu1h = relu1_halo.data_ptr(); at::Half* pdy1h = part_grad_out1.data_ptr(); - //printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3)); - // fused dgrad - //printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]); - //printf("backward_state.outdimA2h = 
{%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]); - //printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]); + // printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3)); + // fused dgrad + // printf("backward_state.outdimA1h = + // {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]); + // printf("backward_state.outdimA2h = + // {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]); + // printf("backward_state.filterdimA2 = + // {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]); run_dconv_add_drelu_dscale(backward_state.outdimA1hh, - backward_state.padA2, // 0,1 - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2hh, // C,1,3,C - backward_state.outdimA2hh, - CUDNN_DATA_HALF, - dy1h, - w, - dy2h, - z, - relu1h, - pdy1h); + backward_state.padA2, // 0,1 + backward_state.convstrideA, backward_state.dilationA, + backward_state.filterdimA2hh, // C,1,3,C + backward_state.outdimA2hh, CUDNN_DATA_HALF, dy1h, w, dy2h, z, relu1h, pdy1h); return grad_out1_halo; } -// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C] -at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) { - +// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) +// to produce output of shape [N,3,W,C] +at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2_halo, + at::Tensor relu1_halo) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -3756,29 +3322,23 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1 at::Half* z = inputs[4].data_ptr(); at::Half* relu1h = relu1_halo.data_ptr(); - //printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3)); - // fused dgrad - //printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]); - //printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]); - //printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]); - run_dconv_drelu_dscale(backward_state.outdimA1h, - backward_state.padA1, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, - backward_state.outdimA2h, - CUDNN_DATA_HALF, - dy1h, - w, - dy2h, - z, - relu1h); + // printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3)); + // fused dgrad + // printf("backward_state.outdimA1h = + // 
{%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]); + // printf("backward_state.outdimA2h = + // {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]); + // printf("backward_state.filterdimA2 = + // {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]); + run_dconv_drelu_dscale(backward_state.outdimA1h, backward_state.padA1, backward_state.convstrideA, + backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2h, + CUDNN_DATA_HALF, dy1h, w, dy2h, z, relu1h); return grad_out1_halo; } -void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor input, at::Tensor grad_out2) { - +void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor input, at::Tensor grad_out2) { std::cout << std::fixed; auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; @@ -3792,24 +3352,20 @@ void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vec auto wgrad2 = outputs[2]; at::Half* dw2 = wgrad2.data_ptr(); - //printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]); - //printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]); - run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos) - backward_state.padA2, // 0, 1 - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, // dw2.shape - backward_state.outdimA2, // dy2.shape - CUDNN_DATA_HALF, - conv_in, - dw2, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + // printf("outdimA1b = + // (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]); + // printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]); + run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos) + backward_state.padA2, // 0, 1 + backward_state.convstrideA, backward_state.dilationA, + backward_state.filterdimA2, // dw2.shape + backward_state.outdimA2, // dy2.shape + CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); } -void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2) { - +void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -3824,27 +3380,21 @@ void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector< // wgrad auto wgrad2 = outputs[2]; at::Half* dw2 = wgrad2.data_ptr(); - - //printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]); - run_dconv(backward_state.outdimA1, - backward_state.padA1, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, - backward_state.outdimA2, - CUDNN_DATA_HALF, - conv_in, - dw2, - dy2, + + // 
printf("outdimA1 = + // (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]); + run_dconv(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, backward_state.dilationA, + backward_state.filterdimA2, backward_state.outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); } -// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C] -// input and grad_out2_halo tensors are all of same shape -// output tensor is of shape [Cin,1,3,Cout] (regular filter dims are [Cin,3,3,Cout] -at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor input, at::Tensor grad_out2_halo) { - +// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension +// [N,1,W,C] input and grad_out2_halo tensors are all of same shape output tensor is of shape [Cin,1,3,Cout] (regular +// filter dims are [Cin,3,3,Cout] +at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor input, + at::Tensor grad_out2_halo) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -3860,28 +3410,27 @@ at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, s auto wgrad2_halo = at::empty(backward_state.filterdim2hh, input.type(), output_format); at::Half* dw2 = wgrad2_halo.data_ptr(); - //printf("backward_state.outdimA1hh = {%d,%d,%d,%d}\n",backward_state.outdimA1hh[0],backward_state.outdimA1hh[1],backward_state.outdimA1hh[2],backward_state.outdimA1hh[3]); - //printf("backward_state.outdimA2hh = {%d,%d,%d,%d}\n",backward_state.outdimA2hh[0],backward_state.outdimA2hh[1],backward_state.outdimA2hh[2],backward_state.outdimA2hh[3]); - //printf("backward_state.filterdim2hh = {%d,%d,%d,%d}\n",backward_state.filterdim2hh[0],backward_state.filterdim2hh[1],backward_state.filterdim2hh[2],backward_state.filterdim2hh[3]); - //printf("backward_state.filterdimA2hh = {%d,%d,%d,%d}\n",backward_state.filterdimA2hh[0],backward_state.filterdimA2hh[1],backward_state.filterdimA2hh[2],backward_state.filterdimA2hh[3]); - //printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]); + // printf("backward_state.outdimA1hh = + // {%d,%d,%d,%d}\n",backward_state.outdimA1hh[0],backward_state.outdimA1hh[1],backward_state.outdimA1hh[2],backward_state.outdimA1hh[3]); + // printf("backward_state.outdimA2hh = + // {%d,%d,%d,%d}\n",backward_state.outdimA2hh[0],backward_state.outdimA2hh[1],backward_state.outdimA2hh[2],backward_state.outdimA2hh[3]); + // printf("backward_state.filterdim2hh = + // {%d,%d,%d,%d}\n",backward_state.filterdim2hh[0],backward_state.filterdim2hh[1],backward_state.filterdim2hh[2],backward_state.filterdim2hh[3]); + // printf("backward_state.filterdimA2hh = + // {%d,%d,%d,%d}\n",backward_state.filterdimA2hh[0],backward_state.filterdimA2hh[1],backward_state.filterdimA2hh[2],backward_state.filterdimA2hh[3]); + // printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]); run_dconv(backward_state.outdimA1hh, // N,C,1,W - backward_state.padA2, // 0, 1 - backward_state.convstrideA, - backward_state.dilationA, + backward_state.padA2, // 0, 1 + 
backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA2hh, // Cin,Cout,1,3 - backward_state.outdimA2hh, // N,C,1,W - CUDNN_DATA_HALF, - conv_in, - dw2, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + backward_state.outdimA2hh, // N,C,1,W + CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); return wgrad2_halo; } -void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out1) { - +void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out1) { at::Half* x = inputs[0].data_ptr(); at::Half* dy1 = grad_out1.data_ptr(); @@ -3889,22 +3438,13 @@ void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector< // wgrad auto wgrad1 = outputs[1]; at::Half* dw1 = wgrad1.data_ptr(); - run_dconv(backward_state.dimA, - backward_state.padA, - backward_state.convstride1X1, - backward_state.dilationA, - backward_state.filterdimA1, - backward_state.outdimA1, - CUDNN_DATA_HALF, - x, - dw1, - dy1, + run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA, + backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, x, dw1, dy1, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - } -void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2, at::Tensor grad_out1) { - +void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2, at::Tensor grad_out1) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; @@ -3914,43 +3454,43 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector(); at::Half* dy1 = grad_out1.data_ptr(); -/* - // backward strided conv cannot be fused - // if stride == 1 but channel changes, we can fuse here - if (stride_1X1 != 1){ - // dgrad - run_dconv(outdimA1, - padA1, - convstride1X1, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - - // mul fused mask - grad_out1.mul_(inputs[15]); - } - else { - at::Half* relu1 = inputs[12].data_ptr(); - // fused dgrad - run_dconv_drelu_dscale(outdimA1, - padA1, - convstride1X1, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1); - } -*/ + /* + // backward strided conv cannot be fused + // if stride == 1 but channel changes, we can fuse here + if (stride_1X1 != 1){ + // dgrad + run_dconv(outdimA1, + padA1, + convstride1X1, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + dy1, + w, + dy2, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + + // mul fused mask + grad_out1.mul_(inputs[15]); + } + else { + at::Half* relu1 = inputs[12].data_ptr(); + // fused dgrad + run_dconv_drelu_dscale(outdimA1, + padA1, + convstride1X1, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + dy1, + w, + dy2, + z, + relu1); + } + */ DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item()); // create grads of conv4 that may exist @@ -3963,20 +3503,12 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector(); at::Half* dy_conv4 = inputs[11].data_ptr(); if (requires_grad) { - run_dconv(backward_state.dimA, - backward_state.padA, - backward_state.convstride1X1, - 
backward_state.dilationA, - backward_state.filterdimA4, - backward_state.outdimA3, - CUDNN_DATA_HALF, - dx_conv4, - w, - dy_conv4, + run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA, + backward_state.filterdimA4, backward_state.outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item()); @@ -3984,19 +3516,10 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector(); - run_dconv(backward_state.dimA, - backward_state.padA, - backward_state.convstride1X1, - backward_state.dilationA, - backward_state.filterdimA4, - backward_state.outdimA3, - CUDNN_DATA_HALF, - x, - dw4, - dy_conv4, + run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA, + backward_state.filterdimA4, backward_state.outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - } - else { + } else { // if there is no downsample, dx_conv4 is fork of drelu3 dx_conv4 = inputs[11].data_ptr(); } @@ -4008,34 +3531,16 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector()); m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward", py::call_guard()); m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward", py::call_guard()); - m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward", py::call_guard()); - m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward", py::call_guard()); - m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward", py::call_guard()); - m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward", py::call_guard()); + m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward", + py::call_guard()); + m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward", + py::call_guard()); + m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward", + py::call_guard()); + m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward", + py::call_guard()); m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward", py::call_guard()); - m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init", py::call_guard()); - m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward", py::call_guard()); - m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward", py::call_guard()); - m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward", py::call_guard()); - m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward", py::call_guard()); - m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward", py::call_guard()); - m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward", py::call_guard()); - m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward", py::call_guard()); - m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, 
"Bottleneck block backward", py::call_guard()); - m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward", py::call_guard()); - m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward", py::call_guard()); - m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward", py::call_guard()); + m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init", + py::call_guard()); + m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward", + py::call_guard()); + m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward", + py::call_guard()); + m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward", + py::call_guard()); + m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward", + py::call_guard()); + m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward", + py::call_guard()); + m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward", + py::call_guard()); + m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward", + py::call_guard()); + m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward", + py::call_guard()); + m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward", + py::call_guard()); + m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward", + py::call_guard()); + m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward", + py::call_guard()); } diff --git a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp index 5fbe53793..b0d0b4442 100644 --- a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp +++ b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp @@ -1,54 +1,66 @@ #include #include // for getcudnnhandle +#include #include #include -#include -#include #include +#include #ifdef DEBUG -#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false ) +#define DEBUG_MSG(str) \ + do { \ + std::cout << str << std::endl; \ + } while (false) #else -#define DEBUG_MSG(str) do { } while ( false ) +#define DEBUG_MSG(str) \ + do { \ + } while (false) #endif #ifdef DEBUG_CUDNN -#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false ) +#define DEBUG_CUDNN_MSG(buf, str) \ + do { \ + buf << str << std::endl; \ + } while (false) #else -#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false ) +#define DEBUG_CUDNN_MSG(buf, str) \ + do { \ + } while (false) #endif #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -#define checkCudnnErr(...) \ - do { \ - int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ - if (err) { \ - return; \ - } \ - } while (0) - +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define checkCudnnErr(...) 
\ + do { \ + int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ + if (err) { \ + return; \ + } \ + } while (0) int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { - if (code) { - printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); - return 1; - } - return 0; + if (code) { + printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); + return 1; + } + return 0; } -void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true); -#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function +void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort = true); +#define checkCUDAError(val) \ + { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function -void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) { - if (code != cudaSuccess) - { - const char * errorMessage = cudaGetErrorString(code); - fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage); - if (abort){ +void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort) { + if (code != cudaSuccess) { + const char* errorMessage = cudaGetErrorString(code); + fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, + errorMessage); + if (abort) { cudaDeviceReset(); exit(code); } @@ -56,74 +68,55 @@ void checkError(cudaError_t code, char const * func, const char *file, const int } void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) { - // For INT8x4 and INT8x32 we still compute standard strides here to input - // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. - if (filterFormat == CUDNN_TENSOR_NCHW) { - strideA[nbDims - 1] = 1; - for (int64_t d = nbDims - 2; d >= 0; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - } else { - // Here we assume that the format is CUDNN_TENSOR_NHWC - strideA[1] = 1; - strideA[nbDims - 1] = strideA[1] * dimA[1]; - for (int64_t d = nbDims - 2; d >= 2; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - strideA[0] = strideA[2] * dimA[2]; + // For INT8x4 and INT8x32 we still compute standard strides here to input + // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. 
+ if (filterFormat == CUDNN_TENSOR_NCHW) { + strideA[nbDims - 1] = 1; + for (int64_t d = nbDims - 2; d >= 0; d--) { + strideA[d] = strideA[d + 1] * dimA[d + 1]; } + } else { + // Here we assume that the format is CUDNN_TENSOR_NHWC + strideA[1] = 1; + strideA[nbDims - 1] = strideA[1] * dimA[1]; + for (int64_t d = nbDims - 2; d >= 2; d--) { + strideA[d] = strideA[d + 1] * dimA[d + 1]; + } + strideA[0] = strideA[2] * dimA[2]; + } } +int getFwdConvDilatedFilterDim(int filterDim, int dilation) { return ((filterDim - 1) * dilation) + 1; } -int getFwdConvDilatedFilterDim(int filterDim, int dilation) { - return ((filterDim - 1) * dilation) + 1; -} - - -int getFwdConvPaddedImageDim(int tensorDim, int pad) { - return tensorDim + (2 * pad); -} - +int getFwdConvPaddedImageDim(int tensorDim, int pad) { return tensorDim + (2 * pad); } -int getFwdConvOutputDim(int tensorDim, - int pad, - int filterDim, - int stride, - int dilation) { - int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1; - return (p); +int getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int stride, int dilation) { + int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1; + return (p); } - // create a cache for plan std::unordered_map plan_cache; - -std::string getConvFusionString(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - cudnnDataType_t dataType, - std::string fusion_string) { - - for(int i=0;i<4;i++) { +std::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA, + int64_t* w_dim_padded, cudnnDataType_t dataType, std::string fusion_string) { + for (int i = 0; i < 4; i++) { fusion_string += 'X'; fusion_string += std::to_string(x_dim_padded[i]); } - for(int i=0;i<4;i++) { + for (int i = 0; i < 4; i++) { fusion_string += 'W'; fusion_string += std::to_string(w_dim_padded[i]); } - for(int i=0;i<2;i++) { + for (int i = 0; i < 2; i++) { fusion_string += 'P'; fusion_string += std::to_string(padA[i]); } - for(int i=0;i<2;i++) { + for (int i = 0; i < 2; i++) { fusion_string += 'S'; fusion_string += std::to_string(convstrideA[i]); } - for(int i=0;i<2;i++) { + for (int i = 0; i < 2; i++) { fusion_string += 'D'; fusion_string += std::to_string(dilationA[i]); } @@ -132,12 +125,9 @@ std::string getConvFusionString(int64_t* x_dim_padded, return fusion_string; } - -cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, - std::stringstream& log_buf, - cudnn_frontend::OperationGraph& opGraph, - std::string cache_string, - bool use_heuristic = true){ +cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::stringstream& log_buf, + cudnn_frontend::OperationGraph& opGraph, std::string cache_string, + bool use_heuristic = true) { auto it = plan_cache.find(cache_string); if (it != plan_cache.end()) { DEBUG_CUDNN_MSG(log_buf, "Found plan in cache"); @@ -147,17 +137,17 @@ cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, if (use_heuristic) { // TODO: confirm which mode to use auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() - .setOperationGraph(opGraph) - .setHeurMode(CUDNN_HEUR_MODE_INSTANT) - .build(); + .setOperationGraph(opGraph) + .setHeurMode(CUDNN_HEUR_MODE_INSTANT) + .build(); auto engine_config_count = heuristics.getEngineConfigCount(); auto& engine_configs = heuristics.getEngineConfig(engine_config_count); for (int64_t count = 0; count < 
engine_config_count; count++) { try { plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(engine_configs[count], opGraph.getTag()) - .build())); + .setHandle(handle_) + .setEngineConfig(engine_configs[count], opGraph.getTag()) + .build())); break; } catch (cudnn_frontend::cudnnException e) { // Throw exception if all engines failed @@ -189,1462 +179,1334 @@ cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, // Createmplacee the requisite engine config auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build(); DEBUG_CUDNN_MSG(log_buf, engine_config.describe()); - plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); + plan_cache.emplace( + cache_string, + std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); } return plan_cache.find(cache_string)->second; } } - -void -run_conv_bias(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* convstride, - int64_t* dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrB, - at::Half* devPtrY) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int convDim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterConvTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('c') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto bTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterBiasTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, conv_pad) - .setPostPadding(convDim, conv_pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, 
convDesc.describe()); - - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(xTensor) - .setwDesc(wTensor) - .setyDesc(afterConvTensor) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(bTensor) - .setyDesc(afterBiasTensor) - .setpwDesc(biasDesc) +void run_conv_bias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* convstride, + int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, + at::Half* devPtrY) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + + try { + int convDim = 2; + float alpha = 1.0f; + float beta = 0.0f; + int64_t b_dim[] = {1, y_dim[1], 1, 1}; + + // Creates the necessary tensor descriptors + int64_t stride[4]; + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto xTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); + + generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto wTensor = cudnn_frontend::TensorBuilder() + .setDim(4, w_dim) + .setStrides(4, stride) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterConvTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('c') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); + + generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto bTensor = cudnn_frontend::TensorBuilder() + .setDim(4, b_dim) + .setStrides(4, stride) + .setId('b') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterBiasTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); + + // Define the bias operation + auto biasDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, conv_pad) + .setPostPadding(convDim, conv_pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(xTensor) + .setwDesc(wTensor) + .setyDesc(afterConvTensor) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Bias Node. 
+ auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(bTensor) + .setyDesc(afterBiasTensor) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create an Operation Graph. In this case it is convolution bias activation + std::array ops = {&conv_op, &bias_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(2, ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY}; + int64_t uids[] = {'x', 'w', 'b', 'y'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(4, data_ptrs) + .setUids(4, uids) .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Operation Graph. In this case it is convolution bias activation - std::array ops = {&conv_op, &bias_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(2, ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY}; - int64_t uids[] = {'x', 'w', 'b', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch 
(cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_conv_bias_mask_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, + int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, + at::Half* devPtrB, int8_t* devPtrM, at::Half* devPtrY) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + + try { + int conv_dim = 2; + float alpha = 1.0f; + float beta = 0.0f; + int64_t b_dim[] = {1, y_dim[1], 1, 1}; + + // Creates the necessary tensor descriptors + int64_t stride[4]; + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto xTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); + + generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto wTensor = cudnn_frontend::TensorBuilder() + .setDim(4, w_dim) + .setStrides(4, stride) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto mTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('m') + .setAlignment(16) + .setDataType(CUDNN_DATA_INT8) + .build(); + DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterConvTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('c') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); + + generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto bTensor = cudnn_frontend::TensorBuilder() + .setDim(4, b_dim) + .setStrides(4, stride) + .setId('b') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterBiasTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('B') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterMaskTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('M') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterReLUTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(conv_dim) + .setStrides(conv_dim, conv_stride) + .setPrePadding(conv_dim, conv_pad) + .setPostPadding(conv_dim, conv_pad) + .setDilation(conv_dim, conv_dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the bias operation + auto biasDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, 
biasDesc.describe()); + + // Define the mask operation + auto maskDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(xTensor) + .setwDesc(wTensor) + .setyDesc(afterConvTensor) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Bias Node + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(bTensor) + .setyDesc(afterBiasTensor) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // create a Mask Node + auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setbDesc(mTensor) + .setyDesc(afterMaskTensor) + .setpwDesc(maskDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, mask_op.describe()); + + // Create an Activation Node + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(mask_op.getOutputTensor()) + .setyDesc(afterReLUTensor) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create an Operation Graph. In this case it is convolution bias activation + std::array ops = {&conv_op, &bias_op, &mask_op, &act_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(4, ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY}; + int64_t uids[] = {'x', 'w', 'b', 'm', 'y'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(5, data_ptrs) + .setUids(5, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } - -void -run_conv_bias_mask_relu(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* conv_stride, - int64_t* conv_dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - 
at::Half* devPtrB, - int8_t* devPtrM, - at::Half* devPtrY) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int conv_dim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto mTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('m') - .setAlignment(16) - .setDataType(CUDNN_DATA_INT8) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterConvTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('c') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto bTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterBiasTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('B') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterMaskTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('M') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterReLUTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(conv_dim) - .setStrides(conv_dim, conv_stride) - .setPrePadding(conv_dim, conv_pad) - .setPostPadding(conv_dim, conv_pad) - .setDilation(conv_dim, conv_dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Define the mask operation - auto maskDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - 
.setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(xTensor) - .setwDesc(wTensor) - .setyDesc(afterConvTensor) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Bias Node - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(bTensor) - .setyDesc(afterBiasTensor) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // create a Mask Node - auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setbDesc(mTensor) - .setyDesc(afterMaskTensor) - .setpwDesc(maskDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, mask_op.describe()); - - // Create an Activation Node - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(mask_op.getOutputTensor()) - .setyDesc(afterReLUTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create an Operation Graph. In this case it is convolution bias activation - std::array ops = {&conv_op, &bias_op, &mask_op, &act_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(4, ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY}; - int64_t uids[] = {'x', 'w', 'b', 'm', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_conv_cscale_cbias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, + int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, + at::Half* devPtrS, at::Half* devPtrB, at::Half* devPtrY) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + + try { + int conv_dim = 2; + float alpha = 1.0f; + float beta = 0.0f; + int64_t s_dim[] = {1, y_dim[1], 1, 1}; + int64_t b_dim[] = {1, y_dim[1], 1, 
1}; + + // Creates the necessary tensor descriptors + int64_t stride[4]; + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto xTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); + + generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto wTensor = cudnn_frontend::TensorBuilder() + .setDim(4, w_dim) + .setStrides(4, stride) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterConvTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('c') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); + + generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto sTensor = cudnn_frontend::TensorBuilder() + .setDim(4, s_dim) + .setStrides(4, stride) + .setId('s') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, sTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterScaleTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('S') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterScaleTensor.describe()); + + generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto bTensor = cudnn_frontend::TensorBuilder() + .setDim(4, b_dim) + .setStrides(4, stride) + .setId('b') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterBiasTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('B') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterReLUTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(conv_dim) + .setStrides(conv_dim, conv_stride) + .setPrePadding(conv_dim, conv_pad) + .setPostPadding(conv_dim, conv_pad) + .setDilation(conv_dim, conv_dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the scale operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the bias operation + auto biasDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Create a convolution Node + auto conv_op = 
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(xTensor) + .setwDesc(wTensor) + .setyDesc(afterConvTensor) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a scale Node. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(sTensor) + .setyDesc(afterScaleTensor) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(bTensor) + .setyDesc(afterBiasTensor) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create an Activation Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setyDesc(afterReLUTensor) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create an Operation Graph. In this case it is convolution bias activation + std::array ops = {&conv_op, &scale_op, &bias_op, &act_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrW, devPtrS, devPtrB, devPtrY}; + int64_t uids[] = {'x', 'w', 's', 'b', 'y'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(5, data_ptrs) + .setUids(5, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } - -void -run_conv_cscale_cbias_relu(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* conv_stride, - int64_t* conv_dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrS, - at::Half* devPtrB, - at::Half* devPtrY) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int conv_dim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t s_dim[] = {1, y_dim[1], 1, 1}; - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto 
xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterConvTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('c') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); - - generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto sTensor = cudnn_frontend::TensorBuilder() - .setDim(4, s_dim) - .setStrides(4, stride) - .setId('s') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, sTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterScaleTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('S') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterScaleTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto bTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterBiasTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('B') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterReLUTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(conv_dim) - .setStrides(conv_dim, conv_stride) - .setPrePadding(conv_dim, conv_pad) - .setPostPadding(conv_dim, conv_pad) - .setDilation(conv_dim, conv_dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the scale operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(xTensor) - .setwDesc(wTensor) - .setyDesc(afterConvTensor) - .setcDesc(convDesc) - 
.setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a scale Node. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(sTensor) - .setyDesc(afterScaleTensor) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(bTensor) - .setyDesc(afterBiasTensor) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Activation Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setyDesc(afterReLUTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create an Operation Graph. In this case it is convolution bias activation - std::array ops = {&conv_op, &scale_op, &bias_op, &act_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrS, devPtrB, devPtrY}; - int64_t uids[] = {'x', 'w', 's', 'b', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_conv_bias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, + int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, + at::Half* devPtrB, at::Half* devPtrY) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + + try { + int conv_dim = 2; + float alpha = 1.0f; + float beta = 0.0f; + int64_t b_dim[] = {1, y_dim[1], 1, 1}; + + // Creates the necessary tensor descriptors + int64_t stride[4]; + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto xTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); + + generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto wTensor = 
cudnn_frontend::TensorBuilder() + .setDim(4, w_dim) + .setStrides(4, stride) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterConvTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('c') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); + + generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto bTensor = cudnn_frontend::TensorBuilder() + .setDim(4, b_dim) + .setStrides(4, stride) + .setId('b') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterBiasTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('B') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto afterReLUTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(conv_dim) + .setStrides(conv_dim, conv_stride) + .setPrePadding(conv_dim, conv_pad) + .setPostPadding(conv_dim, conv_pad) + .setDilation(conv_dim, conv_dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the bias operation + auto biasDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(xTensor) + .setwDesc(wTensor) + .setyDesc(afterConvTensor) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Bias Node. + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(bTensor) + .setyDesc(afterBiasTensor) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create an Activation Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setyDesc(afterReLUTensor) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create an Operation Graph. 
In this case it is convolution bias activation + std::array ops = {&conv_op, &bias_op, &act_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(3, ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY}; + int64_t uids[] = {'x', 'w', 'b', 'y'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(4, data_ptrs) + .setUids(4, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } - -void -run_conv_bias_relu(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* conv_stride, - int64_t* conv_dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrB, - at::Half* devPtrY) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int conv_dim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterConvTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('c') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto bTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterBiasTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('B') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - 
.setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterReLUTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(conv_dim) - .setStrides(conv_dim, conv_stride) - .setPrePadding(conv_dim, conv_pad) - .setPostPadding(conv_dim, conv_pad) - .setDilation(conv_dim, conv_dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(xTensor) - .setwDesc(wTensor) - .setyDesc(afterConvTensor) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(bTensor) - .setyDesc(afterBiasTensor) - .setpwDesc(biasDesc) +void run_drelu_dscale(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR, + at::Half* devPtrS, at::Half* devPtrDX) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + + try { + int convDim = 2; + float alpha = 1.0f; + float beta = 0.0f; + int64_t s_dim[] = {1, dy_dim[1], 1, 1}; + + // Creates the necessary tensor descriptors + int64_t stride[4]; + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto dyTensor = cudnn_frontend::TensorBuilder() + .setDim(4, dy_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, dyTensor.describe()); + + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto rTensor = cudnn_frontend::TensorBuilder() + .setDim(4, dy_dim) + .setStrides(4, stride) + .setId('r') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); + + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto inActGradTensor = cudnn_frontend::TensorBuilder() + .setDim(4, dy_dim) + .setStrides(4, stride) + .setId('R') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe()); + + generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto scaleTensor = cudnn_frontend::TensorBuilder() + .setDim(4, s_dim) + .setStrides(4, stride) + .setId('s') + .setAlignment(16) + .setDataType(dataType) .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Activation Node. 
- auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setyDesc(afterReLUTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create an Operation Graph. In this case it is convolution bias activation - std::array ops = {&conv_op, &bias_op, &act_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(3, ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY}; - int64_t uids[] = {'x', 'w', 'b', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + DEBUG_CUDNN_MSG(log_buf, scaleTensor.describe()); + + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto dxTensor = cudnn_frontend::TensorBuilder() + .setDim(4, dy_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, dxTensor.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the bias backward operation + auto scaleDesc = + cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Create an relu backward Node + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(dyTensor) + .setxDesc(rTensor) + .setdxDesc(inActGradTensor) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create bias node + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(inActGradTensor) + .setbDesc(scaleTensor) + .setyDesc(dxTensor) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create an Operation Graph. 
In this case it is bias only + std::array ops = {&act_op, &scale_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + // creating unique dummy values + int64_t pad_dummy[] = {40, 40}; + int64_t stride_dummy[] = {40, 40}; + int64_t dilation_dummy[] = {40, 40}; + auto cache_string = + getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, s_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrDY, devPtrR, devPtrS, devPtrDX}; + int64_t uids[] = {'y', 'r', 's', 'x'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(4, data_ptrs) + .setUids(4, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } - -void -run_drelu_dscale(int64_t* dy_dim, - cudnnDataType_t dataType, - at::Half* devPtrDY, - at::Half* devPtrR, - at::Half* devPtrS, - at::Half* devPtrDX) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int convDim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t s_dim[] = {1, dy_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto dyTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, dyTensor.describe()); - - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto rTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); - - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inActGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('R') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe()); - - generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto scaleTensor = cudnn_frontend::TensorBuilder() - .setDim(4, s_dim) - .setStrides(4, stride) - .setId('s') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleTensor.describe()); - - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto dxTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - 
.setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, dxTensor.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the bias backward operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Create an relu backward Node - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(dyTensor) - .setxDesc(rTensor) - .setdxDesc(inActGradTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create bias node - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(inActGradTensor) - .setbDesc(scaleTensor) - .setyDesc(dxTensor) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create an Operation Graph. In this case it is bias only - std::array ops = {&act_op, &scale_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - // creating unique dummy values - int64_t pad_dummy[] = {40, 40}; - int64_t stride_dummy[] = {40, 40}; - int64_t dilation_dummy[] = {40, 40}; - auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, s_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrDY, devPtrR, devPtrS, devPtrDX}; - int64_t uids[] = {'y', 'r', 's', 'x'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_drelu_dbias(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR, + at::Half* devPtrDR, float* devPtrDB) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + + try { + int convDim = 2; + float alpha = 1.0f; + float beta = 0.0f; + int64_t b_dim[] = {1, dy_dim[1], 1, 1}; + + // Creates the necessary tensor descriptors + int64_t stride[4]; + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto dyTensor = cudnn_frontend::TensorBuilder() + .setDim(4, dy_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + 
.setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, dyTensor.describe()); + + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto rTensor = cudnn_frontend::TensorBuilder() + .setDim(4, dy_dim) + .setStrides(4, stride) + .setId('r') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); + + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto inActGradTensor = cudnn_frontend::TensorBuilder() + .setDim(4, dy_dim) + .setStrides(4, stride) + .setId('R') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe()); + + generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto biasGradTensor = cudnn_frontend::TensorBuilder() + .setDim(4, b_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the bias backward operation + auto biasDesc = cudnn_frontend::ReductionDescBuilder() + .setMathPrecision(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // Create an relu backward Node + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(dyTensor) + .setxDesc(rTensor) + .setdxDesc(inActGradTensor) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create bias node + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(inActGradTensor) + .setyDesc(biasGradTensor) + .setreductionDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create an Operation Graph. 
In this case it is bias only + std::array ops = {&act_op, &bias_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + // creating unique dummy values + int64_t pad_dummy[] = {20, 20}; + int64_t stride_dummy[] = {20, 20}; + int64_t dilation_dummy[] = {20, 20}; + auto cache_string = + getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB}; + int64_t uids[] = {'x', 'r', 'R', 'y'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(4, data_ptrs) + .setUids(4, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } - -void -run_drelu_dbias(int64_t* dy_dim, - cudnnDataType_t dataType, - at::Half* devPtrDY, - at::Half* devPtrR, - at::Half* devPtrDR, - float* devPtrDB) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int convDim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, dy_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto dyTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, dyTensor.describe()); - - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto rTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); - - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inActGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('R') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto biasGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - 
DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the bias backward operation - auto biasDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Create an relu backward Node - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(dyTensor) - .setxDesc(rTensor) - .setdxDesc(inActGradTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create bias node - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(inActGradTensor) - .setyDesc(biasGradTensor) - .setreductionDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Operation Graph. In this case it is bias only - std::array ops = {&act_op, &bias_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - // creating unique dummy values - int64_t pad_dummy[] = {20, 20}; - int64_t stride_dummy[] = {20, 20}; - int64_t dilation_dummy[] = {20, 20}; - auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB}; - int64_t uids[] = {'x', 'r', 'R', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_dconv_drelu_dbias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* pad, int64_t* convstride, + int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, + at::Half* devPtrR, at::Half* devPtrRg, float* devPtrY) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + float alpha = 1.0f; + float beta = 0.0f; + int64_t b_dim[] = {1, x_dim[1], 1, 1}; + + int64_t stride[4]; + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto outConvGradTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe()); + + generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto wTensor = cudnn_frontend::TensorBuilder() + .setDim(4, w_dim) 
+ .setStrides(4, stride) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); + + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto inConvGradTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('A') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual() + .build(); + DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe()); + + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto rTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('r') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); + + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto inReLUGradTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('R') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, inReLUGradTensor.describe()); + + generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto inBiasGradTensor = cudnn_frontend::TensorBuilder() + .setDim(4, b_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the bias backward operation + auto biasDesc = cudnn_frontend::ReductionDescBuilder() + .setMathPrecision(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdyDesc(outConvGradTensor) + .setwDesc(wTensor) + .setdxDesc(inConvGradTensor) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create an relu backward Node + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(inConvGradTensor) + .setxDesc(rTensor) + .setdxDesc(inReLUGradTensor) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create bias node + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(inReLUGradTensor) + .setyDesc(inBiasGradTensor) + .setreductionDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create an Operation Graph. 
In this case it is bias only + std::array ops = {&conv_op, &act_op, &bias_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); } + void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY}; + int64_t uids[] = {'x', 'w', 'r', 'R', 'y'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(5, data_ptrs) + .setUids(5, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } - -void -run_dconv_drelu_dbias(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrR, - at::Half* devPtrRg, - float* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, x_dim[1], 1, 1}; - - int64_t stride[4]; - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto outConvGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inConvGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('A') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe()); - - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto rTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); - - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inReLUGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('R') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, 
inReLUGradTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inBiasGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the bias backward operation - auto biasDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdyDesc(outConvGradTensor) - .setwDesc(wTensor) - .setdxDesc(inConvGradTensor) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create an relu backward Node - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(inConvGradTensor) - .setxDesc(rTensor) - .setdxDesc(inReLUGradTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create bias node - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(inReLUGradTensor) - .setyDesc(inBiasGradTensor) - .setreductionDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Operation Graph. 
In this case it is bias only - std::array ops = {&conv_op, &act_op, &bias_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY}; - int64_t uids[] = {'x', 'w', 'r', 'R', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; +void run_dconv(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, + int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, + at::Half* devPtrY, cudnnBackendDescriptorType_t mode) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + + try { + int conv_dim = 2; + float alpha = 1.0f; + float beta = 0.0f; + + // Define the convolution problem + int64_t stride[4]; + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto xTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); + + generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto wTensor = cudnn_frontend::TensorBuilder() + .setDim(4, w_dim) + .setStrides(4, stride) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); + + generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto yTensor = cudnn_frontend::TensorBuilder() + .setDim(4, y_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, yTensor.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(conv_dim) + .setStrides(conv_dim, conv_stride) + .setPrePadding(conv_dim, conv_pad) + .setPostPadding(conv_dim, conv_pad) + .setDilation(conv_dim, conv_dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Create a convolution node + // mode should be one of following + // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR + // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR + auto 
conv_op_builder = cudnn_frontend::OperationBuilder(mode); + if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { + conv_op_builder.setdxDesc(xTensor).setwDesc(wTensor).setdyDesc(yTensor).setcDesc(convDesc); + } else { + conv_op_builder.setxDesc(xTensor).setdwDesc(wTensor).setdyDesc(yTensor).setcDesc(convDesc); } + auto conv_op = conv_op_builder.setAlpha(alpha).setBeta(beta).build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); -} + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op}; + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); -void -run_dconv(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* conv_stride, - int64_t* conv_dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - cudnnBackendDescriptorType_t mode) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int conv_dim = 2; - float alpha = 1.0f; - float beta = 0.0f; - - // Define the convolution problem - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto yTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, yTensor.describe()); - - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(conv_dim) - .setStrides(conv_dim, conv_stride) - .setPrePadding(conv_dim, conv_pad) - .setPostPadding(conv_dim, conv_pad) - .setDilation(conv_dim, conv_dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Create a convolution node - // mode should be one of following - // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR - // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR - auto conv_op_builder = cudnn_frontend::OperationBuilder(mode); - if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { - conv_op_builder.setdxDesc(xTensor) - .setwDesc(wTensor) - .setdyDesc(yTensor) - .setcDesc(convDesc); - } - else { - conv_op_builder.setxDesc(xTensor) - .setdwDesc(wTensor) - .setdyDesc(yTensor) - .setcDesc(convDesc); - } - auto conv_op = conv_op_builder - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create an Operation Graph. 
In this case it is convolution add bias activation - std::array ops = {&conv_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrY}; - int64_t uids[] = {'x', 'w', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(3, data_ptrs) - .setUids(3, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} + // Create string encoding for plan caching + auto cache_string = + getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); -void -run_dbias(int64_t* x_dim, - cudnnDataType_t dataType, - at::Half* devPtrX, - float* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - int64_t b_dim[] = {1, x_dim[1], 1, 1}; - - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto yTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, yTensor.describe()); - - // Define the bias backward operation - auto biasDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Create bias node - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(xTensor) - .setyDesc(yTensor) - .setreductionDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Operation Graph. 
In this case it is bias only - std::array ops = {&bias_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - int64_t pad_dummy[] = {10, 10}; - int64_t stride_dummy[] = {10, 10}; - int64_t dilation_dummy[] = {10, 10}; - auto cache_string = getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY}; - int64_t uids[] = {'x', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(2, data_ptrs) - .setUids(2, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrW, devPtrY}; + int64_t uids[] = {'x', 'w', 'y'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(3, data_ptrs) + .setUids(3, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } } +void run_dbias(int64_t* x_dim, cudnnDataType_t dataType, at::Half* devPtrX, float* devPtrY) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + int64_t b_dim[] = {1, x_dim[1], 1, 1}; + + int64_t stride[4]; + generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto xTensor = cudnn_frontend::TensorBuilder() + .setDim(4, x_dim) + .setStrides(4, stride) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(); + DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); + + generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); + auto yTensor = cudnn_frontend::TensorBuilder() + .setDim(4, b_dim) + .setStrides(4, stride) + .setId('y') + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(); + 
DEBUG_CUDNN_MSG(log_buf, yTensor.describe()); + + // Define the bias backward operation + auto biasDesc = cudnn_frontend::ReductionDescBuilder() + .setMathPrecision(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // Create bias node + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(xTensor) + .setyDesc(yTensor) + .setreductionDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create an Operation Graph. In this case it is bias only + std::array ops = {&bias_op}; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); + + // Create string encoding for plan caching + int64_t pad_dummy[] = {10, 10}; + int64_t stride_dummy[] = {10, 10}; + int64_t dilation_dummy[] = {10, 10}; + auto cache_string = + getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY}; + int64_t uids[] = {'x', 'y'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(2, data_ptrs) + .setUids(2, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} std::vector conv_bias_mask_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { std::cout << std::fixed; @@ -1654,8 +1516,8 @@ std::vector conv_bias_mask_relu_forward(std::vector inpu auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; + int64_t x_dim[] = {0, 0, 0, 0}; + int64_t w_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; @@ -1665,18 +1527,19 @@ std::vector conv_bias_mask_relu_forward(std::vector inpu } // output dim in n,c,h,w used by backend - int64_t y_dim[] = {0, 0, 0, 0}; + int64_t y_dim[] = {0, 0, 0, 0}; // use these fixed values - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; + int64_t conv_pad[] = {padding, padding}; + int64_t conv_stride[] = {stride, stride}; + int64_t conv_dilation[] = {1, 1}; // compute output from pad/stride/dilation y_dim[0] = x_dim[0]; y_dim[1] = w_dim[0]; for (int dim = 0; dim < 2; dim++) { - y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); + y_dim[dim + 2] = + getFwdConvOutputDim(x_dim[dim + 2], 
conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); } // run @@ -1687,18 +1550,7 @@ std::vector conv_bias_mask_relu_forward(std::vector inpu auto out = at::empty(y_dim, inputs[0].type(), output_format); at::Half* y = out.data_ptr(); - run_conv_bias_mask_relu(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - w, - b, - m, - y); + run_conv_bias_mask_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, m, y); DEBUG_MSG("[DEBUG] conv-bias-mask-relu : " << y.to(at::kFloat).sum().item()); @@ -1707,13 +1559,12 @@ std::vector conv_bias_mask_relu_forward(std::vector inpu return outputs; } - at::Tensor conv_cscale_cbias_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { std::cout << std::fixed; // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; + int64_t x_dim[] = {0, 0, 0, 0}; + int64_t w_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; @@ -1723,18 +1574,19 @@ at::Tensor conv_cscale_cbias_relu_forward(std::vector inputs, int64_ } // output dim in n,c,h,w used by backend - int64_t y_dim[] = {0, 0, 0, 0}; + int64_t y_dim[] = {0, 0, 0, 0}; // use these fixed values - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; + int64_t conv_pad[] = {padding, padding}; + int64_t conv_stride[] = {stride, stride}; + int64_t conv_dilation[] = {1, 1}; // compute output from pad/stride/dilation y_dim[0] = x_dim[0]; y_dim[1] = w_dim[0]; for (int dim = 0; dim < 2; dim++) { - y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); + y_dim[dim + 2] = + getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); } // run @@ -1745,26 +1597,15 @@ at::Tensor conv_cscale_cbias_relu_forward(std::vector inputs, int64_ auto out = at::empty(y_dim, inputs[0].type(), at::MemoryFormat::ChannelsLast); at::Half* y = out.data_ptr(); - run_conv_cscale_cbias_relu(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - w, - s, - b, - y); + run_conv_cscale_cbias_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, s, b, y); DEBUG_MSG("[DEBUG] conv-cscale-cbias-relu : " << y.to(at::kFloat).sum().item()); return out; } - -std::vector conv_cscale_cbias_relu_backward(std::vector inputs, int64_t padding, int64_t stride) { +std::vector conv_cscale_cbias_relu_backward(std::vector inputs, int64_t padding, + int64_t stride) { bool requires_grad = inputs[0].requires_grad(); for (int i = 0; i <= 4; i++) { @@ -1778,9 +1619,9 @@ std::vector conv_cscale_cbias_relu_backward(std::vector auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; - int64_t y_dim[] = {0, 0, 0, 0}; + int64_t x_dim[] = {0, 0, 0, 0}; + int64_t w_dim[] = {0, 0, 0, 0}; + int64_t y_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; @@ -1790,11 +1631,11 @@ std::vector conv_cscale_cbias_relu_backward(std::vector y_dim[dim] = inputs[3].size(axis[dim]); } - int64_t b_dim[] = {1, y_dim[1], 1, 1}; + int64_t b_dim[] = {1, y_dim[1], 1, 1}; - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; + int64_t conv_pad[] = {padding, padding}; + 
int64_t conv_stride[] = {stride, stride}; + int64_t conv_dilation[] = {1, 1}; // run // drelu-dbias @@ -1804,45 +1645,23 @@ std::vector conv_cscale_cbias_relu_backward(std::vector auto dscale = at::empty_like(inputs[4]); at::Half* ds = dscale.data_ptr(); - auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); - run_drelu_dscale(y_dim, - CUDNN_DATA_HALF, - dy, - r, - s, - ds); + auto options = + at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); + run_drelu_dscale(y_dim, CUDNN_DATA_HALF, dy, r, s, ds); // conv wgrad at::Half* x = inputs[0].data_ptr(); auto wgrad = at::empty_like(inputs[1]); at::Half* dw = wgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - dw, - ds, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, dw, ds, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // conv dgrad at::Half* w = inputs[1].data_ptr(); auto dgrad = at::empty_like(inputs[0]); at::Half* dx = dgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - dx, - w, - ds, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, ds, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); outputs.push_back(dgrad); outputs.push_back(wgrad); @@ -1850,7 +1669,6 @@ std::vector conv_cscale_cbias_relu_backward(std::vector return outputs; } - std::vector conv_bias_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { std::cout << std::fixed; @@ -1859,8 +1677,8 @@ std::vector conv_bias_relu_forward(std::vector inputs, i auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; + int64_t x_dim[] = {0, 0, 0, 0}; + int64_t w_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; @@ -1870,18 +1688,19 @@ std::vector conv_bias_relu_forward(std::vector inputs, i } // output dim in n,c,h,w used by backend - int64_t y_dim[] = {0, 0, 0, 0}; + int64_t y_dim[] = {0, 0, 0, 0}; // use these fixed values - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; + int64_t conv_pad[] = {padding, padding}; + int64_t conv_stride[] = {stride, stride}; + int64_t conv_dilation[] = {1, 1}; // compute output from pad/stride/dilation y_dim[0] = x_dim[0]; y_dim[1] = w_dim[0]; for (int dim = 0; dim < 2; dim++) { - y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); + y_dim[dim + 2] = + getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); } // run @@ -1891,17 +1710,7 @@ std::vector conv_bias_relu_forward(std::vector inputs, i auto out = at::empty(y_dim, inputs[0].type(), output_format); at::Half* y = out.data_ptr(); - run_conv_bias_relu(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - w, - b, - y); + run_conv_bias_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, y); DEBUG_MSG("[DEBUG] conv-bias-relu : " << y.to(at::kFloat).sum().item()); @@ 
-1910,7 +1719,6 @@ std::vector conv_bias_relu_forward(std::vector inputs, i return outputs; } - std::vector conv_bias_relu_backward(std::vector inputs, int64_t padding, int64_t stride) { bool requires_grad = inputs[0].requires_grad(); @@ -1925,9 +1733,9 @@ std::vector conv_bias_relu_backward(std::vector inputs, auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; - int64_t y_dim[] = {0, 0, 0, 0}; + int64_t x_dim[] = {0, 0, 0, 0}; + int64_t w_dim[] = {0, 0, 0, 0}; + int64_t y_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; @@ -1936,12 +1744,12 @@ std::vector conv_bias_relu_backward(std::vector inputs, w_dim[dim] = inputs[1].size(axis[dim]); y_dim[dim] = inputs[3].size(axis[dim]); } - - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; + int64_t b_dim[] = {1, y_dim[1], 1, 1}; + + int64_t conv_pad[] = {padding, padding}; + int64_t conv_stride[] = {stride, stride}; + int64_t conv_dilation[] = {1, 1}; // run // drelu-dbias @@ -1949,54 +1757,31 @@ std::vector conv_bias_relu_backward(std::vector inputs, at::Half* r = inputs[2].data_ptr(); auto drelu = at::empty_like(inputs[2]); at::Half* dr = drelu.data_ptr(); - auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); + auto options = + at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); auto bgrad = at::empty(b_dim, options, output_format); float* db = bgrad.data_ptr(); - run_drelu_dbias(y_dim, - CUDNN_DATA_HALF, - dy, - r, - dr, - db); + run_drelu_dbias(y_dim, CUDNN_DATA_HALF, dy, r, dr, db); // conv wgrad at::Half* x = inputs[0].data_ptr(); auto wgrad = at::empty_like(inputs[1]); - at::Half* dw = wgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - dw, - dr, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + at::Half* dw = wgrad.data_ptr(); + run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, dw, dr, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // conv dgrad at::Half* w = inputs[1].data_ptr(); auto dgrad = at::empty_like(inputs[0]); at::Half* dx = dgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - dx, - w, - dr, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, dr, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); outputs.push_back(dgrad); outputs.push_back(wgrad); outputs.push_back(bgrad); return outputs; - } std::vector conv_bias_forward(std::vector inputs, int64_t padding, int64_t stride) { @@ -2007,8 +1792,8 @@ std::vector conv_bias_forward(std::vector inputs, int64_ auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; + int64_t x_dim[] = {0, 0, 0, 0}; + int64_t w_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; @@ -2018,18 +1803,19 @@ std::vector conv_bias_forward(std::vector inputs, int64_ } // output dim in n,c,h,w used by backend - int64_t y_dim[] = {0, 0, 0, 0}; + int64_t y_dim[] = 
{0, 0, 0, 0}; // use these fixed values - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; + int64_t conv_pad[] = {padding, padding}; + int64_t conv_stride[] = {stride, stride}; + int64_t conv_dilation[] = {1, 1}; // compute output from pad/stride/dilation y_dim[0] = x_dim[0]; y_dim[1] = w_dim[0]; for (int dim = 0; dim < 2; dim++) { - y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); + y_dim[dim + 2] = + getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); } // run @@ -2039,17 +1825,7 @@ std::vector conv_bias_forward(std::vector inputs, int64_ auto out = at::empty(y_dim, inputs[0].type(), output_format); at::Half* y = out.data_ptr(); - run_conv_bias(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - w, - b, - y); + run_conv_bias(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, y); DEBUG_MSG("[DEBUG] conv-bias : " << y.to(at::kFloat).sum().item()); @@ -2058,7 +1834,6 @@ std::vector conv_bias_forward(std::vector inputs, int64_ return outputs; } - std::vector conv_bias_backward(std::vector inputs, int64_t padding, int64_t stride) { bool requires_grad = inputs[0].requires_grad(); @@ -2073,9 +1848,9 @@ std::vector conv_bias_backward(std::vector inputs, int64 auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; - int64_t y_dim[] = {0, 0, 0, 0}; + int64_t x_dim[] = {0, 0, 0, 0}; + int64_t w_dim[] = {0, 0, 0, 0}; + int64_t y_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; @@ -2084,55 +1859,35 @@ std::vector conv_bias_backward(std::vector inputs, int64 w_dim[dim] = inputs[1].size(axis[dim]); y_dim[dim] = inputs[2].size(axis[dim]); } - - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; + int64_t b_dim[] = {1, y_dim[1], 1, 1}; + + int64_t conv_pad[] = {padding, padding}; + int64_t conv_stride[] = {stride, stride}; + int64_t conv_dilation[] = {1, 1}; // run // dbias at::Half* dy = inputs[2].data_ptr(); - auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); + auto options = + at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); auto bgrad = at::empty(b_dim, options, output_format); float* db = bgrad.data_ptr(); - run_dbias(y_dim, - CUDNN_DATA_HALF, - dy, - db); - + run_dbias(y_dim, CUDNN_DATA_HALF, dy, db); + // conv wgrad at::Half* x = inputs[0].data_ptr(); auto wgrad = at::empty_like(inputs[1]); - at::Half* dw = wgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - dw, - dy, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + at::Half* dw = wgrad.data_ptr(); + run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, dw, dy, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // conv dgrad at::Half* w = inputs[1].data_ptr(); auto dgrad = at::empty_like(inputs[0]); at::Half* dx = dgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - dx, - w, - dy, - 
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, dy, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); outputs.push_back(dgrad); outputs.push_back(wgrad); @@ -2143,11 +1898,14 @@ std::vector conv_bias_backward(std::vector inputs, int64 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward", py::call_guard()); - m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward", py::call_guard()); + m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward", + py::call_guard()); m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward", py::call_guard()); m.def("backward_no_relu", &conv_bias_backward, "Fused Conv-Bias backward", py::call_guard()); - m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward", py::call_guard()); - m.def("forward_cscale_cbias_relu", &conv_cscale_cbias_relu_forward, "Fused Conv-(const)Scale-(const)Bias-ReLU", py::call_guard()); - m.def("backward_cscale_cbias_relu", &conv_cscale_cbias_relu_backward, "Fused Conv-(const)Scale-(const)Bias-ReLU backward", py::call_guard()); + m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward", + py::call_guard()); + m.def("forward_cscale_cbias_relu", &conv_cscale_cbias_relu_forward, "Fused Conv-(const)Scale-(const)Bias-ReLU", + py::call_guard()); + m.def("backward_cscale_cbias_relu", &conv_cscale_cbias_relu_backward, + "Fused Conv-(const)Scale-(const)Bias-ReLU backward", py::call_guard()); } - diff --git a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp index 0aad4cfc3..858d429d0 100644 --- a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp +++ b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp @@ -1,9 +1,9 @@ #include #include #include -#include #include +#include #include "norm_sample.h" @@ -13,28 +13,19 @@ enum bn_type { BN_FWD, BN_BWD }; // this is a global variable static std::map, cudnn_frontend::ExecutionPlan> gbn_plan_cache; -at::Tensor gbn_forward(const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const float momentum, - const float epsilon, - const int64_t bn_group, - const int rank_id, - const std::vector &peer_buffers) { - +at::Tensor gbn_forward(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const float momentum, const float epsilon, + const int64_t bn_group, const int rank_id, const std::vector& peer_buffers) { int64_t N = x.size(0); int64_t C = x.size(1); int64_t H = x.size(2); int64_t W = x.size(3); - int64_t tensorDims[] = {N, C, H, W}; - int64_t peerDims[] = {bn_group, 4*C, 1, 1}; + int64_t tensorDims[] = {N, C, H, W}; + int64_t peerDims[] = {bn_group, 4 * C, 1, 1}; int64_t perChannelDims[] = {1, C, 1, 1}; - int64_t epsilonDims[] = {1, 1, 1, 1}; + int64_t epsilonDims[] = {1, 1, 1, 1}; // Allocate output tensor at::Tensor y = at::empty_like(x); @@ -46,7 +37,7 @@ at::Tensor gbn_forward(const at::Tensor& x, // we need the peer size for the buffer reset size_t peer_size = 1; - for (size_t i = 0; i < 4; ++i){ + for (size_t i = 0; i < 4; ++i) { peer_size *= peerDims[i]; } @@ -55,7 +46,7 @@ 
at::Tensor gbn_forward(const at::Tensor& x, // check if plan already exists std::vector fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF}; - if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) { + if (gbn_plan_cache.find(fv) == gbn_plan_cache.end()) { auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF); gbn_plan_cache.emplace(fv, std::move(plan)); } @@ -64,46 +55,28 @@ at::Tensor gbn_forward(const at::Tensor& x, auto plan = gbn_plan_cache.find(fv)->second; // execute - execute_batch_norm_forward(plan, - x.data_ptr(), - y.data_ptr(), - scale.data_ptr(), - bias.data_ptr(), - running_mean.data_ptr(), - running_var.data_ptr(), - running_mean.data_ptr(), - running_var.data_ptr(), - minibatch_mean.data_ptr(), - minibatch_inv_var.data_ptr(), - void_peer_buffers, - static_cast(epsilon), - static_cast(momentum), - peer_size, - rank_id); - + execute_batch_norm_forward(plan, x.data_ptr(), y.data_ptr(), scale.data_ptr(), bias.data_ptr(), + running_mean.data_ptr(), running_var.data_ptr(), running_mean.data_ptr(), + running_var.data_ptr(), minibatch_mean.data_ptr(), minibatch_inv_var.data_ptr(), + void_peer_buffers, static_cast(epsilon), static_cast(momentum), peer_size, + rank_id); + return y; } -std::vector gbn_backward( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const float epsilon, - const int64_t bn_group, - const int rank_id, - const std::vector &peer_buffers) { - +std::vector gbn_backward(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + const float epsilon, const int64_t bn_group, const int rank_id, + const std::vector& peer_buffers) { int64_t N = x.size(0); int64_t C = x.size(1); int64_t H = x.size(2); int64_t W = x.size(3); - int64_t tensorDims[] = {N, C, H, W}; - int64_t peerDims[] = {bn_group, 4*C, 1, 1}; + int64_t tensorDims[] = {N, C, H, W}; + int64_t peerDims[] = {bn_group, 4 * C, 1, 1}; int64_t perChannelDims[] = {1, C, 1, 1}; - int64_t epsilonDims[] = {1, 1, 1, 1}; + int64_t epsilonDims[] = {1, 1, 1, 1}; // Allocate output tensor // outputs @@ -121,42 +94,29 @@ std::vector gbn_backward( // we need the peer size for the buffer reset size_t peer_size = 1; - for (size_t i = 0; i < 4; ++i){ + for (size_t i = 0; i < 4; ++i) { peer_size *= peerDims[i]; } - + assert(bn_group == void_peer_buffers.size()); std::vector fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF}; - if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) { + if (gbn_plan_cache.find(fv) == gbn_plan_cache.end()) { auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF); gbn_plan_cache.emplace(fv, std::move(plan)); } - + // get plan and handle auto plan = gbn_plan_cache.find(fv)->second; - + // execute - execute_batch_norm_backward(plan, - x.data_ptr(), - dy.data_ptr(), - scale.data_ptr(), - minibatch_mean.data_ptr(), - minibatch_inv_var.data_ptr(), - void_peer_buffers, - x_grad.data_ptr(), - scale_grad.data_ptr(), - bias_grad.data_ptr(), - static_cast(epsilon), - peer_size, - rank_id); + execute_batch_norm_backward(plan, x.data_ptr(), dy.data_ptr(), scale.data_ptr(), minibatch_mean.data_ptr(), + minibatch_inv_var.data_ptr(), void_peer_buffers, x_grad.data_ptr(), scale_grad.data_ptr(), + bias_grad.data_ptr(), static_cast(epsilon), peer_size, rank_id); return std::vector{x_grad, scale_grad, 
bias_grad}; } - - - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &gbn_forward, "Group batch norm forward", py::call_guard()); m.def("backward", &gbn_backward, "Group batch backward", py::call_guard()); diff --git a/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp b/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp index e14502109..7d871256e 100644 --- a/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp +++ b/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp @@ -1,56 +1,57 @@ /* -* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a -* copy of this software and associated documentation files (the "Software"), -* to deal in the Software without restriction, including without limitation -* the rights to use, copy, modify, merge, publish, distribute, sublicense, -* and/or sell copies of the Software, and to permit persons to whom the -* Software is furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -* DEALINGS IN THE SOFTWARE. -*/ + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ #include "norm_sample.h" -#include -#include "cudnn_backend.h" + #include // for getcudnnhandle +#include #include #include +#include "cudnn_backend.h" + // some helpers -int64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line) { - if (code) { - printf("CUDA error at %s:%d, code=%d (%s) in '%s'", file, line, (int)code, cudaGetErrorString(code), expr); - return 1; - } - return 0; +int64_t checkCudaError(cudaError_t code, const char *expr, const char *file, int line) { + if (code) { + printf("CUDA error at %s:%d, code=%d (%s) in '%s'", file, line, (int)code, cudaGetErrorString(code), expr); + return 1; + } + return 0; } -int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { - if (code) { - printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); - return 1; - } - return 0; +int64_t checkCudnnError(cudnnStatus_t code, const char *expr, const char *file, int line) { + if (code) { + printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); + return 1; + } + return 0; } -bool -AllowAll(cudnnBackendDescriptor_t engine_config) { +bool AllowAll(cudnnBackendDescriptor_t engine_config) { (void)engine_config; return false; } -void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat) { +void generateStrides(const int64_t *dimA, int64_t *strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat) { // For INT8x4 and INT8x32 we still compute standard strides here to input // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. if (filterFormat == CUDNN_TENSOR_NCHW) { @@ -60,7 +61,7 @@ void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudn } } else { // Here we assume that the format is CUDNN_TENSOR_NHWC - strideA[1] = 1; + strideA[1] = 1; strideA[nbDims - 1] = strideA[1] * dimA[1]; for (int64_t d = nbDims - 2; d >= 2; d--) { strideA[d] = strideA[d + 1] * dimA[d + 1]; @@ -69,14 +70,9 @@ void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudn } } - // runtime -cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t *tensorDims, - int64_t *perChannelSum, - int64_t *epsilon, - int64_t *peerDims, - cudnnDataType_t data_type) { - +cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t *tensorDims, int64_t *perChannelSum, int64_t *epsilon, + int64_t *peerDims, cudnnDataType_t data_type) { // get the cudnn handle cudnnHandle_t handle = torch::native::getCudnnHandle(); @@ -89,68 +85,64 @@ cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t *tensorDims, generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC); generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC); - auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, - int64_t id) { + auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() - .setDim(4, tensorDims) + .setDim(4, tensorDims) .setStrides(4, tensor_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); }; - auto peer_tensor_create = [&peer_stride, &tensorDims](cudnnDataType_t type, - int64_t id) { + auto peer_tensor_create = [&peer_stride, &tensorDims](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() - .setDim(4, tensorDims) + .setDim(4, 
tensorDims) .setStrides(4, peer_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); }; - generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() - .setDim(4, perChannelSum) + .setDim(4, perChannelSum) .setStrides(4, stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); }; - auto xTensor = tensor_create(data_type, 100); - auto yTensor = tensor_create(data_type, 101); - auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); - auto biasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); - auto inMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); - auto inVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 105); - auto outMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); - auto outVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); - auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 108); - auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 109); - - + auto xTensor = tensor_create(data_type, 100); + auto yTensor = tensor_create(data_type, 101); + auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); + auto biasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); + auto inMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); + auto inVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 105); + auto outMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); + auto outVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); + auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 108); + auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 109); + int64_t epsilon_stride[4]; generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() - .setDim(4, epsilon) + .setDim(4, epsilon) .setStrides(4, epsilon_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .setByValue(true) - .build(); + .setId(id) + .setAlignment(16) + .setDataType(type) + .setByValue(true) + .build(); }; - auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 110); - auto expDecayTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 111); + auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 110); + auto expDecayTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 111); // Create the two peer stat tensors. 
Jump IDs in case we need to add more tensors with UIDs std::vector peerStatTensors; @@ -165,119 +157,111 @@ cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t *tensorDims, // Forward training cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_TRAINING; - //Create a Finalize node + // Create a Finalize node auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) - .setNormalizationMode(normalizationMode) - .setNormFwdPhase(phase) - .setxDesc(xTensor) - .setScaleAndBias(scaleTensor, biasTensor) - .setPrevRunningMeanAndVar(inMeanTensor, inVarTensor) - .setNextRunningMeanAndVar(outMeanTensor, outVarTensor) - .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) - .setEpsilonTensor(epsilonTensor) - .setExpDecayFactorTensor(expDecayTensor) - .setPeerStatTensor(peerStatTensors) - .setyDesc(yTensor) - .build(); - - std::array ops = {&batch_norm_op}; + .setNormalizationMode(normalizationMode) + .setNormFwdPhase(phase) + .setxDesc(xTensor) + .setScaleAndBias(scaleTensor, biasTensor) + .setPrevRunningMeanAndVar(inMeanTensor, inVarTensor) + .setNextRunningMeanAndVar(outMeanTensor, outVarTensor) + .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) + .setEpsilonTensor(epsilonTensor) + .setExpDecayFactorTensor(expDecayTensor) + .setPeerStatTensor(peerStatTensors) + .setyDesc(yTensor) + .build(); + + std::array ops = {&batch_norm_op}; #else - std::array ops = {}; + std::array ops = {}; #endif - auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build(); - //std::cout << opGraph.describe() << std::endl; + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build(); + // std::cout << opGraph.describe() << std::endl; cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = - cudnn_frontend::get_heuristics_list<2>({"heuristics_instant" - , "heuristics_fallback" - }, opGraph,::AllowAll, filtered_configs, true); - - //std::cout << "get_heuristics_list Statuses: "; - //for (auto i = 0u ; i < statuses.size(); i++) { - // std::cout << cudnn_frontend::to_string(statuses[i]) << " "; - //} - //std::cout << std::endl; - //std::cout << "Filter config list has " << filtered_configs.size() << " configurations " << std::endl; + auto statuses = cudnn_frontend::get_heuristics_list<2>({"heuristics_instant", "heuristics_fallback"}, opGraph, + ::AllowAll, filtered_configs, true); + + // std::cout << "get_heuristics_list Statuses: "; + // for (auto i = 0u ; i < statuses.size(); i++) { + // std::cout << cudnn_frontend::to_string(statuses[i]) << " "; + // } + // std::cout << std::endl; + // std::cout << "Filter config list has " << filtered_configs.size() << " configurations " << std::endl; // some verbose printing: - //std::cout << "Tensor shape: (" << tensorDims[0] << ", " << tensorDims[1] << ", " << tensorDims[2] << ", " << tensorDims[3] << ")" << std::endl; - + // std::cout << "Tensor shape: (" << tensorDims[0] << ", " << tensorDims[1] << ", " << tensorDims[2] << ", " << + // tensorDims[3] << ")" << std::endl; + auto plan_builder = [&filtered_configs, &opGraph, &handle]() { for (auto i = 0u; i < filtered_configs.size(); i++) { try { - auto plan = cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[i], opGraph.getTag()).build(); + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(filtered_configs[i], opGraph.getTag()) + .build(); return plan; 
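Both run_batch_norm_forward and run_batch_norm_backward select their execution plan the same way: cudnn_frontend::get_heuristics_list<2>() fills filtered_configs from the "heuristics_instant" and "heuristics_fallback" lists, and the plan_builder lambda returns the first config the backend accepts, treating a cudnn_frontend::cudnnException as "unsupported here, try the next candidate" before falling back to filtered_configs[0]. A stripped-down sketch of that first-success pattern, assuming handle, opGraph and the config list are already built as above:

  // Sketch only: build a plan from the first engine config the backend accepts.
  auto build_first_supported = [&](cudnn_frontend::EngineConfigList &configs) {
    for (auto i = 0u; i < configs.size(); i++) {
      try {
        return cudnn_frontend::ExecutionPlanBuilder()
            .setHandle(handle)
            .setEngineConfig(configs[i], opGraph.getTag())
            .build();
      } catch (cudnn_frontend::cudnnException &) {
        // Rejected by the backend; try the next heuristic suggestion.
      }
    }
    // Nothing was accepted: mirror the code above and let the first config throw on build.
    return cudnn_frontend::ExecutionPlanBuilder()
        .setHandle(handle)
        .setEngineConfig(configs[0], opGraph.getTag())
        .build();
  };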
} catch (cudnn_frontend::cudnnException &e) { continue; } } - return cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[0], opGraph.getTag()).build(); + return cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); }; assert(filtered_configs.size() > 0); auto plan = plan_builder(); return plan; - } -void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, - void *xDevPtr, - void *yDevPtr, - void *scaledevPtr, - void *biasdevPtr, - void *in_meandevPtr, - void *in_vardevPtr, - void *out_meandevPtr, - void *out_vardevPtr, - void *saved_meandevPtr, - void *saved_inv_vardevPtr, - const std::vector &peer_devPtrs, - double epsilon_val, - double exponential_decay_factor, - size_t peer_size, - int rank_id) { - +void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void *xDevPtr, void *yDevPtr, void *scaledevPtr, + void *biasdevPtr, void *in_meandevPtr, void *in_vardevPtr, void *out_meandevPtr, + void *out_vardevPtr, void *saved_meandevPtr, void *saved_inv_vardevPtr, + const std::vector &peer_devPtrs, double epsilon_val, + double exponential_decay_factor, size_t peer_size, int rank_id) { // get handle cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - - // get stream + + // get stream cudaStream_t stream; cudnnGetStream(handle_, &stream); - + try { // allocate workspace auto workspace_size = plan.getWorkspaceSize(); - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - void* workPtr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + void *workPtr = nullptr; if (workspace_size > 0) { workPtr = workspace_tensor.data_ptr(); } - + // first the data pointers - std::vector data_ptrs {xDevPtr, yDevPtr, scaledevPtr, biasdevPtr, - in_meandevPtr, in_vardevPtr, out_meandevPtr, out_vardevPtr, - saved_meandevPtr, saved_inv_vardevPtr, - &epsilon_val, &exponential_decay_factor}; + std::vector data_ptrs{ + xDevPtr, yDevPtr, scaledevPtr, biasdevPtr, in_meandevPtr, in_vardevPtr, + out_meandevPtr, out_vardevPtr, saved_meandevPtr, saved_inv_vardevPtr, &epsilon_val, &exponential_decay_factor}; data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end()); // then the uids std::vector uids; for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) { uids.push_back(i); } - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workPtr) - .setDataPointers(data_ptrs.size(), data_ptrs.data()) - .setUids(uids.size(), uids.data()) - .build(); - //std::cout << "variantPack " << variantPack.describe() << std::endl; - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workPtr) + .setDataPointers(data_ptrs.size(), data_ptrs.data()) + .setUids(uids.size(), uids.data()) + .build(); + // std::cout << "variantPack " << variantPack.describe() << std::endl; + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); // Reset local communication buffer - cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size*4, stream); - + cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size * 4, stream); + } catch (cudnn_frontend::cudnnException &e) { struct 
cudaDeviceProp prop; checkCudaErr(cudaGetDeviceProperties(&prop, 0)); @@ -288,12 +272,8 @@ void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, } } -cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, - int64_t *perChannelSum, - int64_t *epsilon, - int64_t *peerDims, - cudnnDataType_t data_type) { - +cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, int64_t *perChannelSum, int64_t *epsilon, + int64_t *peerDims, cudnnDataType_t data_type) { // get cudnn handle cudnnHandle_t handle = torch::native::getCudnnHandle(); @@ -307,60 +287,60 @@ cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, int64_t id) { - return cudnn_frontend::TensorBuilder() + return cudnn_frontend::TensorBuilder() .setDim(4, tensorDims) - .setStrides(4, tensor_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); + .setStrides(4, tensor_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); }; auto peer_tensor_create = [&peer_stride, &peerDims](cudnnDataType_t type, int64_t id) { - return cudnn_frontend::TensorBuilder() + return cudnn_frontend::TensorBuilder() .setDim(4, peerDims) - .setStrides(4, peer_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); + .setStrides(4, peer_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); }; generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) { - return cudnn_frontend::TensorBuilder() + return cudnn_frontend::TensorBuilder() .setDim(4, perChannelSum) - .setStrides(4, stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .build(); + .setStrides(4, stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .build(); }; - auto xTensor = tensor_create(data_type, 100); - auto dyTensor = tensor_create(data_type, 101); - auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); - auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); - auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); - auto dxTensor = tensor_create(data_type, 105); - auto dScaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); - auto dBiasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); + auto xTensor = tensor_create(data_type, 100); + auto dyTensor = tensor_create(data_type, 101); + auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); + auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); + auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); + auto dxTensor = tensor_create(data_type, 105); + auto dScaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); + auto dBiasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); int64_t epsilon_stride[4]; generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) { - return cudnn_frontend::TensorBuilder() + return cudnn_frontend::TensorBuilder() .setDim(4, epsilon) - .setStrides(4, epsilon_stride) - .setId(id) - .setAlignment(16) - .setDataType(type) - .setByValue(true) - .build(); + .setStrides(4, epsilon_stride) + .setId(id) + .setAlignment(16) + .setDataType(type) + .setByValue(true) + .build(); }; - auto 
epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 108); + auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 108); std::vector peerStatTensors; for (size_t i = 109; i < 109 + peerDims[0]; ++i) { @@ -371,43 +351,48 @@ cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, // Batch normalization cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM; - //Create a Finalize node + // Create a Finalize node auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR) - .setNormalizationMode(normalizationMode) - .setxDesc(xTensor) - .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) - .setdyDesc(dyTensor) - .setScale(scaleTensor) - .setEpsilonTensor(epsilonTensor) - .setDScaleAndDBias(dScaleTensor, dBiasTensor) - .setdxDesc(dxTensor) - .setPeerStatTensor(peerStatTensors) - .build(); - - std::array ops = {&batch_norm_op}; + .setNormalizationMode(normalizationMode) + .setxDesc(xTensor) + .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) + .setdyDesc(dyTensor) + .setScale(scaleTensor) + .setEpsilonTensor(epsilonTensor) + .setDScaleAndDBias(dScaleTensor, dBiasTensor) + .setdxDesc(dxTensor) + .setPeerStatTensor(peerStatTensors) + .build(); + + std::array ops = {&batch_norm_op}; #else - std::array ops = {}; + std::array ops = {}; #endif - - auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build(); - //std::cout << opGraph.describe() << std::endl; + + auto opGraph = + cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build(); + // std::cout << opGraph.describe() << std::endl; cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = - cudnn_frontend::get_heuristics_list<2>({"heuristics_instant" - , "heuristics_fallback" - }, opGraph,::AllowAll, filtered_configs, true); - + auto statuses = cudnn_frontend::get_heuristics_list<2>({"heuristics_instant", "heuristics_fallback"}, opGraph, + ::AllowAll, filtered_configs, true); + auto plan_builder = [&filtered_configs, &opGraph, &handle]() { for (auto i = 0u; i < filtered_configs.size(); i++) { try { - auto plan = cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[i], opGraph.getTag()).build(); + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(filtered_configs[i], opGraph.getTag()) + .build(); return plan; } catch (cudnn_frontend::cudnnException &e) { continue; } } - return cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[0], opGraph.getTag()).build(); + return cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); }; assert(filtered_configs.size() > 0); @@ -416,58 +401,47 @@ cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, return plan; } -void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, - void *xDevPtr, - void *dyDevPtr, - void *scaledevPtr, - void *saved_meandevPtr, - void *saved_inv_vardevPtr, - const std::vector &peer_devPtrs, - void *dxDevPtr, - void *dscaledevPtr, - void *dbiasdevPtr, - double epsilon_val, - size_t peer_size, - int rank_id) { - +void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, void *xDevPtr, void *dyDevPtr, void *scaledevPtr, + void *saved_meandevPtr, void *saved_inv_vardevPtr, + const std::vector &peer_devPtrs, void *dxDevPtr, void *dscaledevPtr, + void 
*dbiasdevPtr, double epsilon_val, size_t peer_size, int rank_id) { // get handle cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - + // get stream cudaStream_t stream; cudnnGetStream(handle_, &stream); - + try { // allocate workspace auto workspace_size = plan.getWorkspaceSize(); - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - void* workPtr = nullptr; + auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + void *workPtr = nullptr; if (workspace_size > 0) { workPtr = workspace_tensor.data_ptr(); } - + // create helper arrays - std::vector data_ptrs {xDevPtr, dyDevPtr, scaledevPtr, - saved_meandevPtr, saved_inv_vardevPtr, - dxDevPtr, dscaledevPtr, dbiasdevPtr, &epsilon_val}; + std::vector data_ptrs{xDevPtr, dyDevPtr, scaledevPtr, saved_meandevPtr, saved_inv_vardevPtr, + dxDevPtr, dscaledevPtr, dbiasdevPtr, &epsilon_val}; data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end()); std::vector uids; for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) { uids.push_back(i); } - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workPtr) - .setDataPointers(data_ptrs.size(), data_ptrs.data()) - .setUids(uids.size(), uids.data()) - .build(); + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workPtr) + .setDataPointers(data_ptrs.size(), data_ptrs.data()) + .setUids(uids.size(), uids.data()) + .build(); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); // Reset local communication buffer - cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size*4, stream); - + cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size * 4, stream); + } catch (cudnn_frontend::cudnnException &e) { struct cudaDeviceProp prop; checkCudaErr(cudaGetDeviceProperties(&prop, 0)); diff --git a/apex/contrib/csrc/cudnn_gbn/norm_sample.h b/apex/contrib/csrc/cudnn_gbn/norm_sample.h index 0706416b5..3a0ebb3b0 100644 --- a/apex/contrib/csrc/cudnn_gbn/norm_sample.h +++ b/apex/contrib/csrc/cudnn_gbn/norm_sample.h @@ -24,53 +24,51 @@ #pragma once -#include +#include +#include +#include +#include #include #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include /* some helpers */ -void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat); +void generateStrides(const int64_t *dimA, int64_t *strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat); -int64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line); -int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line); +int64_t checkCudaError(cudaError_t code, const char *expr, const char *file, int line); +int64_t checkCudnnError(cudnnStatus_t code, const char *expr, const char *file, int line); #define checkCudaErr(...) \ - do { \ - int64_t err = checkCudaError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ - assert(err == 0); \ - } while (0) + do { \ + int64_t err = checkCudaError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ + assert(err == 0); \ + } while (0) #define checkCudnnErr(...) 
\ - do { \ - int64_t err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ - assert(err == 0); \ - } while (0) + do { \ + int64_t err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ + assert(err == 0); \ + } while (0) /** * @brief Run a Group BN forward sample with 2 peer stat tensors. * - * @param tensorDims an array with shape (N, C, H, W) for input tensor dims. Stride in NHWC or NCHW will take care of memory format + * @param tensorDims an array with shape (N, C, H, W) for input tensor dims. Stride in NHWC or NCHW will take care of + memory format * @param perChannelSum an array with shape (1, C, 1, 1) to denote the sum values for each channel in the input tensor * @param epsilon a scalar array with shape (1, 1, 1, 1) to represent the epsilon value for the BN - * @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in GBN + * @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in + GBN * */ -cudnn_frontend::ExecutionPlan run_batch_norm_forward( - int64_t *tensorDims, - int64_t *perChannelSum, - int64_t *epsilon, - int64_t *peerDims, - cudnnDataType_t in_out_data_type); +cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t *tensorDims, int64_t *perChannelSum, int64_t *epsilon, + int64_t *peerDims, cudnnDataType_t in_out_data_type); /** * @param xDevPtr input tensor device pointer * @param yDevPtr output tensor device pointer @@ -87,38 +85,26 @@ cudnn_frontend::ExecutionPlan run_batch_norm_forward( * @param epsilon_val episilon value as a double * @param exponential_decay_factor exponential_decay_factor as a value * -**/ -void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, - void *xDevPtr, - void *yDevPtr, - void *scaledevPtr, - void *biasdevPtr, - void *in_meandevPtr, - void *in_vardevPtr, - void *out_meandevPtr, - void *out_vardevPtr, - void *saved_meandevPtr, - void *saved_inv_vardevPtr, - const std::vector &peer_devPtrs, - double epsilon_val, - double exponential_decay_factor, - size_t peer_size, - int rank_id); + **/ +void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void *xDevPtr, void *yDevPtr, void *scaledevPtr, + void *biasdevPtr, void *in_meandevPtr, void *in_vardevPtr, void *out_meandevPtr, + void *out_vardevPtr, void *saved_meandevPtr, void *saved_inv_vardevPtr, + const std::vector &peer_devPtrs, double epsilon_val, + double exponential_decay_factor, size_t peer_size, int rank_id); /** * @brief Run a Group BN backward sample with 2 peer stat tensors. * - * @param tensorDims an array with shape (N, C, H, W) for input tensor dims. Stride in NHWC or NCHW will take care of memory format + * @param tensorDims an array with shape (N, C, H, W) for input tensor dims. 
Stride in NHWC or NCHW will take care of + * memory format * @param perChannelSum an array with shape (1, C, 1, 1) to denote the sum values for each channel in the input tensor * @param epsilon a scalar array with shape (1, 1, 1, 1) to represent the epsilon value for the BN - * @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in GBN - * -*/ -cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, - int64_t *perChannelSum, - int64_t *epsilon, - int64_t *peerDims, - cudnnDataType_t data_type); + * @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in + * GBN + * + */ +cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, int64_t *perChannelSum, int64_t *epsilon, + int64_t *peerDims, cudnnDataType_t data_type); /** * @brief Run a Group BN backward sample with 2 peer stat tensors. @@ -138,16 +124,7 @@ cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims, * @param epsilon_val episilon value as a double * */ -void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, - void *xDevPtr, - void *dyDevPtr, - void *scaledevPtr, - void *saved_meandevPtr, - void *saved_inv_vardevPtr, - const std::vector &peer_devPtrs, - void *dxDevPtr, - void *dscaledevPtr, - void *dbiasdevPtr, - double epsilon_val, - size_t peer_size, - int rank_id); +void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, void *xDevPtr, void *dyDevPtr, void *scaledevPtr, + void *saved_meandevPtr, void *saved_inv_vardevPtr, + const std::vector &peer_devPtrs, void *dxDevPtr, void *dscaledevPtr, + void *dbiasdevPtr, double epsilon_val, size_t peer_size, int rank_id); diff --git a/apex/contrib/csrc/fmha/fmha_api.cpp b/apex/contrib/csrc/fmha/fmha_api.cpp index e468b9ed0..e27ad5ad0 100644 --- a/apex/contrib/csrc/fmha/fmha_api.cpp +++ b/apex/contrib/csrc/fmha/fmha_api.cpp @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
- * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -25,345 +25,303 @@ * ******************************************************************************/ -#include #include +#include #include "fmha.h" -extern at::Tensor & mha_fill(at::Tensor &self, const at::Tensor &start_index); +extern at::Tensor &mha_fill(at::Tensor &self, const at::Tensor &start_index); void set_params(Fused_multihead_attention_fprop_params ¶ms, // sizes - const size_t b, - const size_t s, - const size_t h, - const size_t d, + const size_t b, const size_t s, const size_t h, const size_t d, // device pointers - void *qkv_packed_d, - void *cu_seqlens_d, - void *o_packed_d, - void *s_d, - float p_dropout) { - - Data_type acc_type = DATA_TYPE_FP32; - Data_type data_type = DATA_TYPE_FP16; - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - // Set the pointers and strides. - params.qkv_ptr = qkv_packed_d; - params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type); - params.o_ptr = o_packed_d; - params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type); - - params.cu_seqlens = static_cast(cu_seqlens_d); - - // S = softmax(P) - params.s_ptr = s_d; - params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type); - - // Set the dimensions. - params.b = b; - params.h = h; - params.s = s; - params.d = d; - - // Set the different scale values. - const float scale_bmm1 = 1.f / sqrtf(d); - constexpr float scale_softmax = 1.f; - constexpr float scale_bmm2 = 1.f; - - set_alpha(params.scale_bmm1, scale_bmm1, data_type); - set_alpha(params.scale_softmax, scale_softmax, acc_type); - set_alpha(params.scale_bmm2, scale_bmm2, data_type); - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f - p_dropout; - params.rp_dropout = 1.f / params.p_dropout; - TORCH_CHECK(p_dropout < 1.f); - set_alpha(params.scale_dropout, params.rp_dropout, data_type); + void *qkv_packed_d, void *cu_seqlens_d, void *o_packed_d, void *s_d, float p_dropout) { + Data_type acc_type = DATA_TYPE_FP32; + Data_type data_type = DATA_TYPE_FP16; + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + // Set the pointers and strides. + params.qkv_ptr = qkv_packed_d; + params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type); + params.o_ptr = o_packed_d; + params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type); + + params.cu_seqlens = static_cast(cu_seqlens_d); + + // S = softmax(P) + params.s_ptr = s_d; + params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type); + + // Set the dimensions. + params.b = b; + params.h = h; + params.s = s; + params.d = d; + + // Set the different scale values. + const float scale_bmm1 = 1.f / sqrtf(d); + constexpr float scale_softmax = 1.f; + constexpr float scale_bmm2 = 1.f; + + set_alpha(params.scale_bmm1, scale_bmm1, data_type); + set_alpha(params.scale_softmax, scale_softmax, acc_type); + set_alpha(params.scale_bmm2, scale_bmm2, data_type); + + // Set this to probability of keeping an element to simplify things. 
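Concretely, set_params() bakes the standard 1/sqrt(d) attention scale into scale_bmm1 and stores dropout as a keep probability: with head size d = 64 and p_dropout = 0.1, the kernels see scale_bmm1 = 0.125, params.p_dropout = 0.9 and rp_dropout = 1/0.9, the factor used to rescale the surviving activations. A purely illustrative numeric check of those formulas:

  // Illustrative values only (d = 64, p_dropout = 0.1); mirrors the formulas used by set_params().
  const float d = 64.f;
  const float p_drop = 0.1f;
  const float scale_bmm1 = 1.f / sqrtf(d);   // 0.125, applied to the Q*K^T scores
  const float keep_prob = 1.f - p_drop;      // stored as params.p_dropout = 0.9
  const float rp_dropout = 1.f / keep_prob;  // ~1.111, stored as params.rp_dropout
  // set_alpha(params.scale_dropout, rp_dropout, DATA_TYPE_FP16) packs this value for half2 arithmetic.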
+ params.p_dropout = 1.f - p_dropout; + params.rp_dropout = 1.f / params.p_dropout; + TORCH_CHECK(p_dropout < 1.f); + set_alpha(params.scale_dropout, params.rp_dropout, data_type); } -std::vector -mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens, // b+1 - const float p_dropout, - const int max_seq_len, - const bool is_training, - const bool is_nl, - const bool zero_tensors, - c10::optional gen_) { - - using namespace torch::indexing; - auto dprops = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || - (dprops->major == 9 && dprops->minor == 0) || - (dprops->major == 10 && dprops->minor == 0) || - (dprops->major == 12 && dprops->minor == 0)); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - Launch_params launch_params(dprops, stream, is_training, is_nl); - - int seq_len = 512; - auto launch = &run_fmha_fp16_512_64_sm80; - if( max_seq_len <= 128 ) { - seq_len = 128; - launch = &run_fmha_fp16_128_64_sm80; - } else if( max_seq_len <= 256 ) { - seq_len = 256; - launch = &run_fmha_fp16_256_64_sm80; - } else if( max_seq_len <= 384 ) { - seq_len = 384; - launch = &run_fmha_fp16_384_64_sm80; - } else if( max_seq_len <= 512 ) { - seq_len = 512; - launch = &run_fmha_fp16_512_64_sm80; - } else { - TORCH_CHECK(false); - } - - TORCH_CHECK(qkv.is_cuda()) - TORCH_CHECK(cu_seqlens.is_cuda()) - - TORCH_CHECK(qkv.is_contiguous()) - TORCH_CHECK(cu_seqlens.is_contiguous()) - - TORCH_CHECK(cu_seqlens.dim() == 1); - TORCH_CHECK(qkv.dim() == 4); - - const auto sizes = qkv.sizes(); - - TORCH_CHECK(sizes[THREE_DIM] == 3); - - const int batch_size = cu_seqlens.numel() - 1; - const int total = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - TORCH_CHECK(batch_size > 0); - TORCH_CHECK(head_size == 64); - auto opts = qkv.options(); - - auto ctx = torch::empty({ total, num_heads, head_size }, opts); - - auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts); - - if( zero_tensors ) { - mha_fill(ctx, cu_seqlens.index({Slice(-1,None)})); - } - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - - set_params(launch_params.params, - batch_size, - seq_len, - num_heads, - head_size, - qkv.data_ptr(), - cu_seqlens.data_ptr(), - ctx.data_ptr(), - s.data_ptr(), - p_dropout); - - launch(launch_params, /*configure=*/ true); - // number of times random will be generated per thread, to offset philox counter in thc random - // state - int64_t counter_offset = launch_params.elts_per_thread; - at::PhiloxCudaState rng_engine_inputs; - - if( is_training ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - launch_params.params.philox_args = gen->philox_cuda_state(counter_offset); - } - - launch(launch_params, /*configure=*/ false); - - return { ctx, s }; +std::vector mha_fwd( + const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens, // b+1 + const float p_dropout, const int max_seq_len, const bool is_training, const bool is_nl, const bool zero_tensors, + c10::optional gen_) { + using namespace torch::indexing; + auto dprops = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) || + (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0)); + auto stream = 
at::cuda::getCurrentCUDAStream().stream(); + Launch_params launch_params(dprops, stream, is_training, is_nl); + + int seq_len = 512; + auto launch = &run_fmha_fp16_512_64_sm80; + if (max_seq_len <= 128) { + seq_len = 128; + launch = &run_fmha_fp16_128_64_sm80; + } else if (max_seq_len <= 256) { + seq_len = 256; + launch = &run_fmha_fp16_256_64_sm80; + } else if (max_seq_len <= 384) { + seq_len = 384; + launch = &run_fmha_fp16_384_64_sm80; + } else if (max_seq_len <= 512) { + seq_len = 512; + launch = &run_fmha_fp16_512_64_sm80; + } else { + TORCH_CHECK(false); + } + + TORCH_CHECK(qkv.is_cuda()) + TORCH_CHECK(cu_seqlens.is_cuda()) + + TORCH_CHECK(qkv.is_contiguous()) + TORCH_CHECK(cu_seqlens.is_contiguous()) + + TORCH_CHECK(cu_seqlens.dim() == 1); + TORCH_CHECK(qkv.dim() == 4); + + const auto sizes = qkv.sizes(); + + TORCH_CHECK(sizes[THREE_DIM] == 3); + + const int batch_size = cu_seqlens.numel() - 1; + const int total = sizes[TOTAL_DIM]; + const int num_heads = sizes[H_DIM]; + const int head_size = sizes[D_DIM]; + TORCH_CHECK(batch_size > 0); + TORCH_CHECK(head_size == 64); + auto opts = qkv.options(); + + auto ctx = torch::empty({total, num_heads, head_size}, opts); + + auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts); + + if (zero_tensors) { + mha_fill(ctx, cu_seqlens.index({Slice(-1, None)})); + } + + auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + set_params(launch_params.params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(), + ctx.data_ptr(), s.data_ptr(), p_dropout); + + launch(launch_params, /*configure=*/true); + // number of times random will be generated per thread, to offset philox counter in thc random + // state + int64_t counter_offset = launch_params.elts_per_thread; + at::PhiloxCudaState rng_engine_inputs; + + if (is_training) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + launch_params.params.philox_args = gen->philox_cuda_state(counter_offset); + } + + launch(launch_params, /*configure=*/false); + + return {ctx, s}; } - -std::vector -mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size - const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i - at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP - const at::Tensor &cu_seqlens, // b+1 - const float p_dropout, // probability to drop - const int max_seq_len, // max sequence length to choose the kernel - const bool zero_tensors -) { - using namespace torch::indexing; - auto dprops = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || - (dprops->major == 9 && dprops->minor == 0) || - (dprops->major == 10 && dprops->minor == 0) || - (dprops->major == 12 && dprops->minor == 0)); - int seq_len = 512; - auto launch = &run_fmha_dgrad_fp16_512_64_sm80; - if( max_seq_len <= 128 ) { - seq_len = 128; - launch = &run_fmha_dgrad_fp16_128_64_sm80; - } else if( max_seq_len <= 256 ) { - seq_len = 256; - launch = &run_fmha_dgrad_fp16_256_64_sm80; - } else if( max_seq_len <= 384 ) { - seq_len = 384; - launch = &run_fmha_dgrad_fp16_384_64_sm80; - } else if( max_seq_len <= 512 ) { - seq_len = 512; - launch = &run_fmha_dgrad_fp16_512_64_sm80; - } else { - TORCH_CHECK(false); - } - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - TORCH_CHECK(qkv.dtype() == torch::kFloat16); - TORCH_CHECK(dout.dtype() == torch::kFloat16); - TORCH_CHECK(softmax.dtype() 
== torch::kFloat16); - TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); - - TORCH_CHECK(qkv.is_cuda()); - TORCH_CHECK(cu_seqlens.is_cuda()); - - TORCH_CHECK(qkv.is_contiguous()); - TORCH_CHECK(cu_seqlens.is_contiguous()); - - TORCH_CHECK(cu_seqlens.dim() == 1); - TORCH_CHECK(qkv.dim() == 4); - - const auto sizes = qkv.sizes(); - - TORCH_CHECK(sizes[THREE_DIM] == 3); - - const int batch_size = cu_seqlens.numel() - 1; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - TORCH_CHECK(batch_size > 0); - TORCH_CHECK(head_size == 64); - - auto dqkv = torch::empty_like(qkv); - - if( zero_tensors ) { - mha_fill(dqkv, cu_seqlens.index({Slice(-1,None)})); - } - - Fused_multihead_attention_fprop_params params; - - set_params(params, - batch_size, - seq_len, - num_heads, - head_size, - qkv.data_ptr(), - cu_seqlens.data_ptr(), - dout.data_ptr(), // we set o_ptr to dout - softmax.data_ptr(), // softmax gets overwritten by dP! - p_dropout); - - // we're re-using these scales - Data_type acc_type = DATA_TYPE_FP32; - set_alpha(params.scale_bmm1, 1.f, acc_type); - set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); - set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16); - params.dqkv_ptr = dqkv.data_ptr(); - - launch(params, stream); - return { dqkv, softmax }; +std::vector mha_bwd( + const at::Tensor &dout, // total x num_heads, x head_size + const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i + at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP + const at::Tensor &cu_seqlens, // b+1 + const float p_dropout, // probability to drop + const int max_seq_len, // max sequence length to choose the kernel + const bool zero_tensors) { + using namespace torch::indexing; + auto dprops = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) || + (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0)); + int seq_len = 512; + auto launch = &run_fmha_dgrad_fp16_512_64_sm80; + if (max_seq_len <= 128) { + seq_len = 128; + launch = &run_fmha_dgrad_fp16_128_64_sm80; + } else if (max_seq_len <= 256) { + seq_len = 256; + launch = &run_fmha_dgrad_fp16_256_64_sm80; + } else if (max_seq_len <= 384) { + seq_len = 384; + launch = &run_fmha_dgrad_fp16_384_64_sm80; + } else if (max_seq_len <= 512) { + seq_len = 512; + launch = &run_fmha_dgrad_fp16_512_64_sm80; + } else { + TORCH_CHECK(false); + } + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK(qkv.dtype() == torch::kFloat16); + TORCH_CHECK(dout.dtype() == torch::kFloat16); + TORCH_CHECK(softmax.dtype() == torch::kFloat16); + TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); + + TORCH_CHECK(qkv.is_cuda()); + TORCH_CHECK(cu_seqlens.is_cuda()); + + TORCH_CHECK(qkv.is_contiguous()); + TORCH_CHECK(cu_seqlens.is_contiguous()); + + TORCH_CHECK(cu_seqlens.dim() == 1); + TORCH_CHECK(qkv.dim() == 4); + + const auto sizes = qkv.sizes(); + + TORCH_CHECK(sizes[THREE_DIM] == 3); + + const int batch_size = cu_seqlens.numel() - 1; + const int num_heads = sizes[H_DIM]; + const int head_size = sizes[D_DIM]; + TORCH_CHECK(batch_size > 0); + TORCH_CHECK(head_size == 64); + + auto dqkv = torch::empty_like(qkv); + + if (zero_tensors) { + mha_fill(dqkv, cu_seqlens.index({Slice(-1, None)})); + } + + Fused_multihead_attention_fprop_params params; + + set_params(params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(), 
+ dout.data_ptr(), // we set o_ptr to dout + softmax.data_ptr(), // softmax gets overwritten by dP! + p_dropout); + + // we're re-using these scales + Data_type acc_type = DATA_TYPE_FP32; + set_alpha(params.scale_bmm1, 1.f, acc_type); + set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); + set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16); + params.dqkv_ptr = dqkv.data_ptr(); + + launch(params, stream); + return {dqkv, softmax}; } -std::vector mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size - const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i - at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP - const at::Tensor &cu_seqlens, // b+1 - const float p_dropout, // probability to drop - const int max_seq_len, // max sequence length to choose the kernel - const bool zero_tensors -) { +std::vector mha_bwd_nl( + const at::Tensor &dout, // total x num_heads, x head_size + const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i + at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP + const at::Tensor &cu_seqlens, // b+1 + const float p_dropout, // probability to drop + const int max_seq_len, // max sequence length to choose the kernel + const bool zero_tensors) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK(qkv.is_cuda()) + TORCH_CHECK(cu_seqlens.is_cuda()) - auto stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(qkv.is_contiguous()) + TORCH_CHECK(cu_seqlens.is_contiguous()) - TORCH_CHECK(qkv.is_cuda()) - TORCH_CHECK(cu_seqlens.is_cuda()) + TORCH_CHECK(cu_seqlens.dim() == 1); - TORCH_CHECK(qkv.is_contiguous()) - TORCH_CHECK(cu_seqlens.is_contiguous()) + TORCH_CHECK(qkv.dim() == 4); - TORCH_CHECK(cu_seqlens.dim() == 1); + const auto sizes = qkv.sizes(); - TORCH_CHECK(qkv.dim() == 4); + TORCH_CHECK(sizes[THREE_DIM] == 3); - const auto sizes = qkv.sizes(); + const int batch_size = cu_seqlens.numel() - 1; - TORCH_CHECK(sizes[THREE_DIM] == 3); + const int total = sizes[TOTAL_DIM]; + const int num_heads = sizes[H_DIM]; + const int head_size = sizes[D_DIM]; + TORCH_CHECK(batch_size > 0); + TORCH_CHECK(head_size == 64); - const int batch_size = cu_seqlens.numel() - 1; - - const int total = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - TORCH_CHECK(batch_size > 0); - TORCH_CHECK(head_size == 64); + int seq_len = 512; + auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl; - int seq_len = 512; - auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl; + auto opts = qkv.options(); - auto opts = qkv.options(); + auto dqkv = torch::empty_like(qkv); - auto dqkv = torch::empty_like(qkv); + if (zero_tensors) { + dqkv.zero_(); + } - if( zero_tensors ) { - dqkv.zero_(); - } - - int num_chunks = 2; - if( batch_size == 1 ) { - num_chunks = 4; - }else if( batch_size == 2 ) { - num_chunks = 3; - } - auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts); + int num_chunks = 2; + if (batch_size == 1) { + num_chunks = 4; + } else if (batch_size == 2) { + num_chunks = 3; + } + auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts); - Fused_multihead_attention_fprop_params params; + Fused_multihead_attention_fprop_params params; - set_params(params, - batch_size, - seq_len, - num_heads, - head_size, - qkv.data_ptr(), - cu_seqlens.data_ptr(), - dout.data_ptr(), // o_ptr = dout - softmax.data_ptr(), // softmax gets overwritten by 
dP! - p_dropout); + set_params(params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(), + dout.data_ptr(), // o_ptr = dout + softmax.data_ptr(), // softmax gets overwritten by dP! + p_dropout); - params.dkv_ptr = dkv.data_ptr(); + params.dkv_ptr = dkv.data_ptr(); - Data_type acc_type = DATA_TYPE_FP32; - set_alpha(params.scale_bmm1, 1.f, acc_type); - set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); - set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16); - params.dqkv_ptr = dqkv.data_ptr(); + Data_type acc_type = DATA_TYPE_FP32; + set_alpha(params.scale_bmm1, 1.f, acc_type); + set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); + set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16); + params.dqkv_ptr = dqkv.data_ptr(); - launch(params, num_chunks, stream); + launch(params, num_chunks, stream); - //SPLIT-K reduction of num_chunks dK, dV parts + // SPLIT-K reduction of num_chunks dK, dV parts - // The equivalent of the following Pytorch code: - // using namespace torch::indexing; - // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)}); - // torch::sum_out(view_out, dkv, 1); + // The equivalent of the following Pytorch code: + // using namespace torch::indexing; + // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)}); + // torch::sum_out(view_out, dkv, 1); - const int hidden_size = num_heads * head_size; - fmha_run_noloop_reduce( - dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr(), hidden_size, batch_size, total, num_chunks, stream); + const int hidden_size = num_heads * head_size; + fmha_run_noloop_reduce(dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr(), hidden_size, batch_size, total, + num_chunks, stream); - return { dqkv, softmax, dkv }; + return {dqkv, softmax, dkv}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "Fused Multi-head Self-attention for BERT"; - m.def("fwd", &mha_fwd, "Forward pass", py::call_guard()); - m.def("bwd", &mha_bwd, "Backward pass", py::call_guard()); - m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)", py::call_guard()); + m.doc() = "Fused Multi-head Self-attention for BERT"; + m.def("fwd", &mha_fwd, "Forward pass", py::call_guard()); + m.def("bwd", &mha_bwd, "Backward pass", py::call_guard()); + m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)", py::call_guard()); } diff --git a/apex/contrib/csrc/fmha/src/fmha.h b/apex/contrib/csrc/fmha/src/fmha.h index d01a91505..772156100 100644 --- a/apex/contrib/csrc/fmha/src/fmha.h +++ b/apex/contrib/csrc/fmha/src/fmha.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
- * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -28,6 +28,7 @@ #pragma once #include + #include #ifdef OLD_GENERATOR_PATH @@ -36,10 +37,9 @@ #include #endif -#include - #include +#include constexpr int TOTAL_DIM = 0; constexpr int THREE_DIM = 1; @@ -49,115 +49,103 @@ constexpr int D_DIM = 3; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { - // The QKV matrices. - void * __restrict__ qkv_ptr; + // The QKV matrices. + void *__restrict__ qkv_ptr; - // The stride between rows of the Q, K and V matrices. - size_t qkv_stride_in_bytes; + // The stride between rows of the Q, K and V matrices. + size_t qkv_stride_in_bytes; - // The number of heads. - int h; + // The number of heads. + int h; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Fused_multihead_attention_fprop_params : public Qkv_params { + // The dQKV matrices. + void *__restrict__ dqkv_ptr; - // The dQKV matrices. - void * __restrict__ dqkv_ptr; - - // Temporary for dKV. - void * __restrict__ dkv_ptr; + // Temporary for dKV. + void *__restrict__ dkv_ptr; - // The O matrix (output). - void * __restrict__ o_ptr; + // The O matrix (output). + void *__restrict__ o_ptr; - // The stride between rows of O. - int64_t o_stride_in_bytes; + // The stride between rows of O. + int64_t o_stride_in_bytes; - // The pointer to the S matrix, overwritten by the dP matrix (bwd). - void * __restrict__ s_ptr; - // The stride between rows of the S matrix. - int64_t s_stride_in_bytes; + // The pointer to the S matrix, overwritten by the dP matrix (bwd). + void *__restrict__ s_ptr; + // The stride between rows of the S matrix. + int64_t s_stride_in_bytes; - // The dimensions. - int b, s, d; + // The dimensions. + int b, s, d; - // The scaling factors for the kernel. - uint32_t scale_bmm1, scale_softmax, scale_bmm2; + // The scaling factors for the kernel. + uint32_t scale_bmm1, scale_softmax, scale_bmm2; - // array of length b+1 holding starting offset of each sequence. - int * __restrict__ cu_seqlens; + // array of length b+1 holding starting offset of each sequence. + int *__restrict__ cu_seqlens; - // The dropout probability (probability of keeping an activation). - float p_dropout; + // The dropout probability (probability of keeping an activation). + float p_dropout; - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; - // Scale factor of 1 / (1 - p_dropout), in half2. - uint32_t scale_dropout; + // Scale factor of 1 / (1 - p_dropout), in half2. + uint32_t scale_dropout; - // Random state. - at::PhiloxCudaState philox_args; + // Random state. 
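The cu_seqlens member is the usual exclusive prefix sum over the per-sequence lengths; it is what lets the packed total x num_heads x 3 x head_size QKV layout recover each sequence, and it is why mha_fwd/mha_bwd derive batch_size as cu_seqlens.numel() - 1. A small illustrative layout:

  // Illustrative only: three sequences of lengths 128, 192 and 64 packed back to back.
  // Sequence i occupies rows [cu_seqlens[i], cu_seqlens[i + 1]) of the packed QKV tensor.
  int cu_seqlens[4] = {0, 128, 320, 384};  // b + 1 entries
  int b = 3;                               // batch size = numel(cu_seqlens) - 1
  int total = cu_seqlens[b];               // 384 rows in total (TOTAL_DIM of qkv)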
+ at::PhiloxCudaState philox_args; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Launch_params{ - Launch_params(cudaDeviceProp * props_, - cudaStream_t stream_, - bool is_training_, - bool is_nl_) - : elts_per_thread(0) - , props(props_) - , stream(stream_) - , is_training(is_training_) - , is_nl(is_nl_) { - } - - size_t elts_per_thread; +template +struct Launch_params { + Launch_params(cudaDeviceProp *props_, cudaStream_t stream_, bool is_training_, bool is_nl_) + : elts_per_thread(0), props(props_), stream(stream_), is_training(is_training_), is_nl(is_nl_) {} - cudaDeviceProp * props; + size_t elts_per_thread; - cudaStream_t stream; + cudaDeviceProp *props; - bool is_training; + cudaStream_t stream; - Kernel_params params; - int num_full_heads; - int num_main_groups; - int heads_last_wave; - int main_steps; - int rest_steps; - bool is_nl; + bool is_training; + Kernel_params params; + int num_full_heads; + int num_main_groups; + int heads_last_wave; + int main_steps; + int rest_steps; + bool is_nl; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_fmha_fp16_128_64_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_fp16_256_64_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_fp16_384_64_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_fp16_512_64_sm80(Launch_params &launch_params, const bool configure); +void run_fmha_fp16_128_64_sm80(Launch_params &launch_params, + const bool configure); +void run_fmha_fp16_256_64_sm80(Launch_params &launch_params, + const bool configure); +void run_fmha_fp16_384_64_sm80(Launch_params &launch_params, + const bool configure); +void run_fmha_fp16_512_64_sm80(Launch_params &launch_params, + const bool configure); void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); -void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const bool is_training, const int num_chunks, cudaStream_t stream); - -void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const int num_chunks, cudaStream_t stream); - -void fmha_run_noloop_reduce(void *out, - const void *in, - const int *cu_seqlens, - const int hidden_size, - const int batch_size, - const int total, - const int num_chunks, - cudaStream_t stream); +void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const bool is_training, + const int num_chunks, cudaStream_t stream); +void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const int num_chunks, + cudaStream_t stream); +void fmha_run_noloop_reduce(void *out, const void *in, const int *cu_seqlens, const int hidden_size, + const int batch_size, const int total, const int num_chunks, cudaStream_t stream); diff --git a/apex/contrib/csrc/fmha/src/fmha/gemm.h b/apex/contrib/csrc/fmha/src/fmha/gemm.h index 62529a2c5..5bd780f29 100644 --- a/apex/contrib/csrc/fmha/src/fmha/gemm.h +++ b/apex/contrib/csrc/fmha/src/fmha/gemm.h @@ -1,6 +1,6 @@ 
/****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -35,34 +35,33 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ > +template struct Fragment_base_ { - - // The data type. - using Data_type = Data_type_; - // default input type - using Input_type_ = Data_type_; - // Does it store the array of elements. - enum { HAS_ELTS = BITS_PER_ELT_ >= 8 }; - // The number of elements. - enum { NUM_ELTS = NUM_ELTS_ }; - // The size of element in bits. - enum { BITS_PER_ELT = BITS_PER_ELT_ }; - // The size of byte of a single register. - enum { BYTES_PER_REG = 4 }; - // The size in bits. - enum { BITS_PER_REG = BYTES_PER_REG * 8 }; - // The number of registers needed to store the fragment. - enum { NUM_REGS = Div_up::VALUE }; - // The size in bytes (as returned by sizeof(Fragment_base<>). - enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG }; - // The alignment. - enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min::VALUE }; + // The data type. + using Data_type = Data_type_; + // default input type + using Input_type_ = Data_type_; + // Does it store the array of elements. + enum { HAS_ELTS = BITS_PER_ELT_ >= 8 }; + // The number of elements. + enum { NUM_ELTS = NUM_ELTS_ }; + // The size of element in bits. + enum { BITS_PER_ELT = BITS_PER_ELT_ }; + // The size of byte of a single register. + enum { BYTES_PER_REG = 4 }; + // The size in bits. + enum { BITS_PER_REG = BYTES_PER_REG * 8 }; + // The number of registers needed to store the fragment. + enum { NUM_REGS = Div_up::VALUE }; + // The size in bytes (as returned by sizeof(Fragment_base<>). + enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG }; + // The alignment. + enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min::VALUE }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The type of the elements. typename Data_type_, // The number of elements. @@ -70,141 +69,125 @@ template< // The alignment if you want to force a value -- use 0 otherwise. int ALIGNMENT_ = 0, // The base class. - typename Base_ = Fragment_base_ -> + typename Base_ = Fragment_base_ > struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { - - // The size of a load/store. - enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) }; - - // Clear the fragment. Using PTX in that code seems to produce better SASS... - inline __device__ void clear() { - #pragma unroll - for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { - asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : ); - } + // The size of a load/store. 
+ enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) }; + + // Clear the fragment. Using PTX in that code seems to produce better SASS... + inline __device__ void clear() { +#pragma unroll + for (int ii = 0; ii < Base_::NUM_REGS; ++ii) { + asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) :); } + } - // Immutable access to a register. - inline __device__ const uint32_t& reg(int ii) const { - return this->regs_[ii]; - } + // Immutable access to a register. + inline __device__ const uint32_t& reg(int ii) const { return this->regs_[ii]; } - // Mutable access to a register. - inline __device__ uint32_t& reg(int ii) { - return this->regs_[ii]; - } + // Mutable access to a register. + inline __device__ uint32_t& reg(int ii) { return this->regs_[ii]; } - uint32_t regs_[Base_::NUM_REGS]; + uint32_t regs_[Base_::NUM_REGS]; - // Immutable access to the elements. - inline __device__ const Data_type_& elt(int ii) const { - return reinterpret_cast(&this->regs_[0])[ii]; - } + // Immutable access to the elements. + inline __device__ const Data_type_& elt(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } - // Mutable access to the elements. - inline __device__ Data_type_& elt(int ii) { - return reinterpret_cast(&this->regs_[0])[ii]; - } + // Mutable access to the elements. + inline __device__ Data_type_& elt(int ii) { return reinterpret_cast(&this->regs_[0])[ii]; } - // Immutable access to the elements with a cast. - template< typename Cast_type > - inline __device__ const Cast_type& elt_as(int ii) const { - return reinterpret_cast(&this->regs_[0])[ii]; - } + // Immutable access to the elements with a cast. + template + inline __device__ const Cast_type& elt_as(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } - // Mutable access to the elements. - template< typename Cast_type > - inline __device__ Cast_type& elt_as(int ii) { - return reinterpret_cast(&this->regs_[0])[ii]; - } + // Mutable access to the elements. + template + inline __device__ Cast_type& elt_as(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } - // Add another fragment. - inline __device__ void add(const Fragment &other) { - #pragma unroll - for( int ii = 0; ii < NUM_ELTS_; ++ii ) { - this->elt(ii) += other.elt(ii); - } + // Add another fragment. + inline __device__ void add(const Fragment& other) { +#pragma unroll + for (int ii = 0; ii < NUM_ELTS_; ++ii) { + this->elt(ii) += other.elt(ii); } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Layout > -struct Fragment_a : public Fragment { -}; +template +struct Fragment_a : public Fragment {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Layout > -struct Fragment_b : public Fragment { -}; +template +struct Fragment_b : public Fragment {}; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Fragment_accumulator : public Fragment { - - // The base class. - using Base = Fragment; - - // Add two fragments. - template< typename Other_fragment_ > - inline __device__ void add(const Other_fragment_ &other) { - for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { - this->elt(ii) = this->elt(ii) + other.elt(ii); - } - } - - // Do the HMMA. 
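For reference, the register bookkeeping in Fragment_base_ boils down to NUM_REGS = Div_up(NUM_ELTS * BITS_PER_ELT, 32), with elt() simply aliasing the packed regs_ array. A small host-side sketch, assuming the usual operand shapes for this MMA (8 fp16 elements per A/B fragment and 8 fp32 elements in the accumulator, which matches the a.reg(0..3), b.reg(0..3) and elt(0..7) indices used by mma() just below):

// Standalone check of the fragment sizing arithmetic (illustrative shapes only).
template <int X, int Y> struct Div_up_ { enum { VALUE = (X + Y - 1) / Y }; };

enum { BITS_PER_REG = 32 };
enum { A_NUM_REGS   = Div_up_<8 * 16, BITS_PER_REG>::VALUE };  // 8 x fp16 -> 4 registers (16 B)
enum { ACC_NUM_REGS = Div_up_<8 * 32, BITS_PER_REG>::VALUE };  // 8 x fp32 -> 8 registers (32 B)
static_assert(A_NUM_REGS == 4 && ACC_NUM_REGS == 8, "");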
- template< typename Layout_a, typename Layout_b > - inline __device__ void mma(const Fragment_a &a, - const Fragment_b &b) { - asm volatile( \ - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \ - " {%0, %1, %2, %3}, \n" \ - " {%4, %5, %6, %7}, \n" \ - " {%8, %9}, \n" \ - " {%0, %1, %2, %3}; \n" \ - : "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3)) - : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)) - , "r"(b.reg(0)), "r"(b.reg(1))); - asm volatile( \ - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \ - " {%0, %1, %2, %3}, \n" \ - " {%4, %5, %6, %7}, \n" \ - " {%8, %9}, \n" \ - " {%0, %1, %2, %3}; \n" \ - : "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7)) - : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)) - , "r"(b.reg(2)), "r"(b.reg(3))); + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(const Other_fragment_& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); } + } + // Do the HMMA. + template + inline __device__ void mma(const Fragment_a& a, const Fragment_b& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Fragment, int M, int N > +template inline __device__ void clear(Fragment (&frag)[M][N]) { - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < N; ++ni ) { - frag[mi][ni].clear(); - } +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + frag[mi][ni].clear(); } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Accumulator_type, int WARPS_K > -struct Clear_accumulator { -}; +template +struct Clear_accumulator {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int WARPS_K > +template struct Clear_accumulator { - template< typename Acc, int M, int N > + template static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { fmha::clear(acc); } @@ -212,21 +195,20 @@ struct Clear_accumulator { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) { - - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < N; ++ni ) { - acc[mi][ni].mma(a[mi], b[ni]); - } +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + acc[mi][ni].mma(a[mi], b[ni]); } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The number of rows in the CTA tile. 
int M_, // The number of cols in the CTA tile. @@ -240,46 +222,44 @@ template< // The number of warps in the K dimension of the GEMM loop. int WARPS_K_> struct Cta_tile_ { - - enum { M = M_, N = N_, K = K_ }; - // The number of warps. - enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ }; - // The number of warps per CTA. - enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K }; - // The number of threads per warp. - enum { THREADS_PER_WARP = 32 }; - // The number of threads per CTA. - enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP }; + enum { M = M_, N = N_, K = K_ }; + // The number of warps. + enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ }; + // The number of warps per CTA. + enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K }; + // The number of threads per warp. + enum { THREADS_PER_WARP = 32 }; + // The number of threads per CTA. + enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Hmma_tile { - // The number of elements computed with a single warp-MMA. - enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 }; - - // The number of elements computed with a single CTA-MMA. - enum { - M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, - N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, - K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K - }; - - // The number of MMAs needed to compute the GEMM. - enum { - MMAS_M = Div_up::VALUE, - MMAS_N = Div_up::VALUE, - MMAS_K = Div_up::VALUE, - }; - - // The number of elements computed per warp. - enum { - M_PER_WARP = MMAS_M * M_PER_MMA, - N_PER_WARP = MMAS_N * N_PER_MMA, - K_PER_WARP = MMAS_K * K_PER_MMA, - }; - + // The number of elements computed with a single warp-MMA. + enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 }; + + // The number of elements computed with a single CTA-MMA. + enum { + M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K + }; + + // The number of MMAs needed to compute the GEMM. + enum { + MMAS_M = Div_up::VALUE, + MMAS_N = Div_up::VALUE, + MMAS_K = Div_up::VALUE, + }; + + // The number of elements computed per warp. 
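The Cta_tile_ and Hmma_tile constants are pure compile-time arithmetic over the 16x16x16 warp-MMA shape. A standalone sketch with hypothetical tile parameters (M = N = K = 64 and a 4x1x1 warp layout, chosen only for illustration):

template <int X, int Y> struct Div_up_ { enum { VALUE = (X + Y - 1) / Y }; };

enum { M = 64, N = 64, K = 64, WARPS_M = 4, WARPS_N = 1, WARPS_K = 1 };
enum { THREADS_PER_CTA = WARPS_M * WARPS_N * WARPS_K * 32 };  // 4 warps -> 128 threads
enum { M_PER_MMA_PER_CTA = 16 * WARPS_M };                    // 64 rows covered per CTA-wide MMA
enum { N_PER_MMA_PER_CTA = 16 * WARPS_N };                    // 16 cols covered per CTA-wide MMA
enum { MMAS_M = Div_up_<M, M_PER_MMA_PER_CTA>::VALUE };       // 64 / 64 = 1
enum { MMAS_N = Div_up_<N, N_PER_MMA_PER_CTA>::VALUE };       // 64 / 16 = 4
enum { MMAS_K = Div_up_<K, 16 * WARPS_K>::VALUE };            // 64 / 16 = 4
static_assert(THREADS_PER_CTA == 128 && MMAS_M == 1 && MMAS_N == 4 && MMAS_K == 4, "");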
+ enum { + M_PER_WARP = MMAS_M * M_PER_MMA, + N_PER_WARP = MMAS_N * N_PER_MMA, + K_PER_WARP = MMAS_K * K_PER_MMA, + }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -296,18 +276,14 @@ constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template using Cta_tile_extd = Cta_tile_; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -using Cta_tile_with_k_with_padding = Cta_tile_extd::VALUE, - Cta_tile_::WARPS_M, - Cta_tile_::WARPS_N, - Cta_tile_::WARPS_K>; +template +using Cta_tile_with_k_with_padding = Cta_tile_extd::VALUE, + Cta_tile_::WARPS_M, Cta_tile_::WARPS_N, Cta_tile_::WARPS_K>; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h b/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h index 5c86dd84e..ba55c1a16 100644 --- a/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h +++ b/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -31,7 +31,7 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The number of bits per element. @@ -41,416 +41,394 @@ template< // The number of columns. int COLS, // The number of matrics. - int NUM_MATS = 3 -> + int NUM_MATS = 3> struct Gmem_tile_qkv { - - // The size of each LDG. - enum { BYTES_PER_LDG = 16 }; - // The size of a row in bytes. - enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; - - // The number of threads to load a "row" of the matrix. - enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; - - // The number of "rows" loaded per LDG. - enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // The number of LDGs needed to load a chunk of the Q matrix. - enum { LDGS = fmha::Div_up::VALUE }; - - // Ctor. - template< typename Params, typename BInfo > - inline __device__ Gmem_tile_qkv(const Params ¶ms, const int qkv_offset, const BInfo &binfo, const int tidx) - : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes) - , actual_seqlen(binfo.actual_seqlen) - , qkv_ptr_(reinterpret_cast(params.qkv_ptr)) { - - // Compute the position in the sequence (within the CTA for the moment). - int row = tidx / THREADS_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % THREADS_PER_ROW; - - // Store the row as we need it to disable the loads. - row_ = row; - - // The row offset in the batched GEMM. 
For each seq element, we store QKV in that order. - int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; - // Add the block index. - row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; - - // Assemble the final pointer. - qkv_ptr_ += row_offset + col * BYTES_PER_LDG; - } - - // Store data to shared memory. - template< typename Smem_tile > - inline __device__ void commit(Smem_tile &smem_tile) { - smem_tile.store(fetch_); - } - - // Load data from memory. - template< typename Smem_tile > - inline __device__ void load(Smem_tile &smem_tile) { - const void *ptrs[LDGS]; - uint32_t preds[LDGS]; - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; - preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)); - fetch_[ii] = make_uint4(0, 0, 0, 0); - } - - // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) - Ldg_functor fct(fetch_, ptrs); - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - fct.load(ii, preds[ii]); - } - } - - // Store data to memory. - inline __device__ void store(const uint4 (&data)[LDGS]) { - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; - if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) { - fmha::stg(ptr, data[ii]); - } - } + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // Ctor. + template + inline __device__ Gmem_tile_qkv(const Params ¶ms, const int qkv_offset, const BInfo &binfo, const int tidx) + : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes), + actual_seqlen(binfo.actual_seqlen), + qkv_ptr_(reinterpret_cast(params.qkv_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Store the row as we need it to disable the loads. + row_ = row; + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; + // Add the block index. + row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; + + // Assemble the final pointer. + qkv_ptr_ += row_offset + col * BYTES_PER_LDG; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile &smem_tile) { + smem_tile.store(fetch_); + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile &smem_tile) { + const void *ptrs[LDGS]; + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)); + fetch_[ii] = make_uint4(0, 0, 0, 0); } - // Move the pointer to the next location. 
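The Gmem_tile_qkv constants and the constructor's row offset are easier to follow with numbers plugged in. A sketch assuming 16-bit elements, COLS = 64, ROWS = 64 and a 128-thread CTA (illustrative values, not taken from this diff):

enum { BITS_PER_ELEMENT = 16, COLS = 64, ROWS = 64, THREADS_PER_CTA = 128, NUM_MATS = 3 };
enum { BYTES_PER_LDG = 16 };
enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };       // 128 B per (matrix, head) row
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };   // 8 threads cover one row
enum { ROWS_PER_LDG = THREADS_PER_CTA / THREADS_PER_ROW };  // 16 rows per LDG wave
enum { LDGS = (ROWS + ROWS_PER_LDG - 1) / ROWS_PER_LDG };   // 4 LDGs to cover 64 rows
static_assert(THREADS_PER_ROW == 8 && ROWS_PER_LDG == 16 && LDGS == 4, "");

// Per-thread base offset, mirroring the constructor above: tokens are laid out with
// Q, K and V packed per sequence element, and each (matrix, head) slice is one row wide.
inline long long qkv_offset(long long row, long long qkv_stride_in_bytes, long long sum_s,
                            int qkv_id /*0=Q,1=K,2=V*/, int h, int bidh, int col) {
  long long off = row * qkv_stride_in_bytes;
  off += (sum_s * NUM_MATS + qkv_id) * h + bidh > 0  // keep 64-bit math before scaling
             ? ((sum_s * NUM_MATS + qkv_id) * h + bidh) * BYTES_PER_ROW
             : ((sum_s * NUM_MATS + qkv_id) * h + bidh) * BYTES_PER_ROW;
  return off + col * BYTES_PER_LDG;
}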
- inline __device__ void move() { - qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_; - actual_seqlen -= ROWS; + // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) + Ldg_functor fct(fetch_, ptrs); +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + fct.load(ii, preds[ii]); } - - inline __device__ void move(int steps) { - qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps; - actual_seqlen -= ROWS * steps; + } + + // Store data to memory. + inline __device__ void store(const uint4 (&data)[LDGS]) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + if ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) { + fmha::stg(ptr, data[ii]); + } } - - // The stride between rows for the QKV matrice. - int64_t params_qkv_stride_in_bytes_; - // The pointer. - char *qkv_ptr_; - // The fetch registers. - uint4 fetch_[LDGS]; - // Keep track of the row the thread is processing as we move the tile. - int row_; - // The length of the sequence loaded by that memory tile. - int actual_seqlen; + } + + // Move the pointer to the next location. + inline __device__ void move() { + qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_; + actual_seqlen -= ROWS; + } + + inline __device__ void move(int steps) { + qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps; + actual_seqlen -= ROWS * steps; + } + + // The stride between rows for the QKV matrice. + int64_t params_qkv_stride_in_bytes_; + // The pointer. + char *qkv_ptr_; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row the thread is processing as we move the tile. + int row_; + // The length of the sequence loaded by that memory tile. + int actual_seqlen; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Cta_tile > +template struct Gmem_tile_o { - - // The mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The size of each element. - enum { BYTES_PER_ELEMENT = 2 }; - // The size of a row in bytes. - enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; - - // The number of threads to store a "row" of the matrix. - enum { THREADS_PER_ROW = 16 }; - // The size of each STG. - enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW }; - - // The number of "rows" stored per iteration of the loop. The output of 1 MMA. - enum { ROWS = Cta_tile::M }; - // The number of "rows" stored per iteration of the loop. The output of 1 MMA. - enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA }; - // The number of outter loop for the stores. - enum { LOOPS = ROWS / ROWS_PER_LOOP }; - - // The number of "rows" stored per STG. - enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // Do we have to guard against partial writes/reads. - enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; - // The number of STGs needed to store a chunk of the Q matrix. - enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; - // The number of STGs needed to store a chunk of the Q matrix in total. - enum { STGS = STGS_PER_LOOP * LOOPS }; - - // Ctor. - template - inline __device__ Gmem_tile_o(const Params ¶ms, const BInfo &binfo, int tidx) - : params_o_stride_in_bytes_(params.o_stride_in_bytes) - , actual_seqlen_(binfo.actual_seqlen) - , o_ptr_(reinterpret_cast(params.o_ptr)) { - - // Compute the position in the sequence (within the CTA for the moment). 
- int row = tidx / THREADS_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % THREADS_PER_ROW; - - // Store the row as we need it to disable loads. - row_ = row; - - // The row offset in the batched GEMM. - int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW; - // Assemble the final pointer. - o_ptr_ += row_offset + col * BYTES_PER_STG; - - // Is that thread active on the last STG? - if( HAS_INCOMPLETE_STG ) { - is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; - } - } - - // Store data to global memory. - inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { - - #pragma unroll - for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { - int jj = mi * STGS_PER_LOOP + ii; - if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) { - break; - } - - float x = reinterpret_cast(src[ii].x); - float y = reinterpret_cast(src[ii].y); - float z = reinterpret_cast(src[ii].z); - float w = reinterpret_cast(src[ii].w); - uint2 out = float4_to_half4(x, y, z, w); - if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out); - } - } - } - - // Move the pointer to the next location. - inline __device__ void move() { - row_ += ROWS; - o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + // The mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = 2 }; + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = 16 }; + // The size of each STG. + enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA }; + // The number of outter loop for the stores. + enum { LOOPS = ROWS / ROWS_PER_LOOP }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + + // Ctor. + template + inline __device__ Gmem_tile_o(const Params ¶ms, const BInfo &binfo, int tidx) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(binfo.actual_seqlen), + o_ptr_(reinterpret_cast(params.o_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Store the row as we need it to disable loads. + row_ = row; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW; + // Assemble the final pointer. + o_ptr_ += row_offset + col * BYTES_PER_STG; + + // Is that thread active on the last STG? 
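Gmem_tile_o::store() bit-casts the four accumulator lanes back to float and packs them into four fp16 values with fmha::float4_to_half4, whose definition is not part of this diff. A plausible equivalent built only on standard CUDA intrinsics might look like this (an assumption, not the actual helper):

#include <cuda_fp16.h>

// Hypothetical stand-in for fmha::float4_to_half4: packs four fp32 values into
// four fp16 values carried in a uint2 (two 16-bit halves per 32-bit word).
__device__ inline uint2 pack_float4_to_half4(float x, float y, float z, float w) {
  __half2 lo = __floats2half2_rn(x, y);  // (x, y) rounded to two fp16 in one 32-bit register
  __half2 hi = __floats2half2_rn(z, w);  // (z, w) rounded to two fp16 in one 32-bit register
  uint2 out;
  out.x = *reinterpret_cast<unsigned int *>(&lo);
  out.y = *reinterpret_cast<unsigned int *>(&hi);
  return out;
}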
+ if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; } - - inline __device__ void move(const int steps) { - row_ += ROWS * steps; - o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps; + } + + // Store data to global memory. + inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_) { + break; + } + + float x = reinterpret_cast(src[ii].x); + float y = reinterpret_cast(src[ii].y); + float z = reinterpret_cast(src[ii].z); + float w = reinterpret_cast(src[ii].w); + uint2 out = float4_to_half4(x, y, z, w); + if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_)) { + fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out); + } } - - // The stride between rows for the QKV matrice. - int64_t params_o_stride_in_bytes_; - // The pointer. - char *o_ptr_; - // Is the thread active for the last STG? - int is_active_for_last_stg_; - // Keep track of the row to disable loads. - int row_; - // The length of the sequence loaded by that memory tile. - int actual_seqlen_; + } + + // Move the pointer to the next location. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + inline __device__ void move(const int steps) { + row_ += ROWS * steps; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The pointer. + char *o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + // Keep track of the row to disable loads. + int row_; + // The length of the sequence loaded by that memory tile. + int actual_seqlen_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Cta_tile, int BYTES_PER_ELEMENT > +template struct Gmem_tile_mma_sd { - - // The mma tile. - using Mma_tile = fmha::Hmma_tile; - - // Each STG stores 8 elements. - enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 }; - // The number of MMAs in the M dimension. - enum { MMAS_M = Mma_tile::MMAS_M }; - // The number of MMAs in the N dimension. - enum { MMAS_N = Mma_tile::MMAS_N }; - // The number of rows computed per MMA per thread block. - enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA }; - // The number of cols computed per MMA per thread block. - enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA }; - // The number of threads per block. - enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA }; - // The size of each row in bytes. I.e. how many bytes are stored per STG. - enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG }; - // The fixed sequence length. - enum { SEQLEN = Cta_tile::N }; - // The distance between two blocks (in bytes). - enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT }; - // The distance between elements stored per loop (in bytes). - enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW }; - - // The type of elements stored per STG. - using Type = typename fmha::Uint_from_size_in_bytes::Type; - - // Ctor. - template - inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx) - : ptr_(static_cast(ptr)) { - - // The block index. 
- size_t bidx = bidb * params.h + bidh; - - // Set store location for each thread at the beginning of the loop - ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG; - } - - // Store to global memory. - inline __device__ void store(const Type &data, const int mi, const int ni) { - size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - fmha::stg(ptr_ + offset, data); - } - - // Load from global memory. - inline __device__ void load(Type &data, const int mi, const int ni) { - size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - fmha::ldg(data, ptr_ + offset); - } - - // Move to the next tile. - inline __device__ void move() { - ptr_ += LOOP_STRIDE_BYTES; - } - inline __device__ void move(const int steps) { - ptr_ += LOOP_STRIDE_BYTES * steps; - } - - // The pointer in global memory. - char *ptr_; + // The mma tile. + using Mma_tile = fmha::Hmma_tile; + + // Each STG stores 8 elements. + enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 }; + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + // The number of MMAs in the N dimension. + enum { MMAS_N = Mma_tile::MMAS_N }; + // The number of rows computed per MMA per thread block. + enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA }; + // The number of cols computed per MMA per thread block. + enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA }; + // The number of threads per block. + enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA }; + // The size of each row in bytes. I.e. how many bytes are stored per STG. + enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG }; + // The fixed sequence length. + enum { SEQLEN = Cta_tile::N }; + // The distance between two blocks (in bytes). + enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT }; + // The distance between elements stored per loop (in bytes). + enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW }; + + // The type of elements stored per STG. + using Type = typename fmha::Uint_from_size_in_bytes::Type; + + // Ctor. + template + inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx) + : ptr_(static_cast(ptr)) { + // The block index. + size_t bidx = bidb * params.h + bidh; + + // Set store location for each thread at the beginning of the loop + ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG; + } + + // Store to global memory. + inline __device__ void store(const Type &data, const int mi, const int ni) { + size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + fmha::stg(ptr_ + offset, data); + } + + // Load from global memory. + inline __device__ void load(Type &data, const int mi, const int ni) { + size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + fmha::ldg(data, ptr_ + offset); + } + + // Move to the next tile. + inline __device__ void move() { ptr_ += LOOP_STRIDE_BYTES; } + inline __device__ void move(const int steps) { ptr_ += LOOP_STRIDE_BYTES * steps; } + + // The pointer in global memory. + char *ptr_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > +template > struct Gmem_tile_mma_s : public Base { - - // The number of mmas in the vertical dimension. - enum { M = Base::MMAS_M }; - // The number of mmas in the horizontal dimension. - enum { N = Base::MMAS_N }; - // The type of the vectors stored by each STG. - using Type = typename Base::Type; - - // Ctor. 
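Gmem_tile_mma_sd addresses one SEQLEN x SEQLEN matrix of S values per (batch, head) pair. With assumed numbers (SEQLEN = 512, 2-byte elements, a 256-thread CTA), the strides work out as follows:

enum { SEQLEN = 512, BYTES_PER_ELEMENT = 2, THREADS_PER_CTA = 256 };
enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };                     // 16 B (8 elements) per STG
enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };           // 4096 B written per CTA-wide STG
enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };  // 512 KB per (batch, head)
static_assert(BYTES_PER_STG == 16 && BYTES_PER_ROW == 4096 && BLOCK_STRIDE_BYTES == 512 * 1024, "");

// Per-thread address of the (mi, ni) MMA tile, mirroring the ctor and store() above.
inline long long mma_sd_offset(int bidb, int h, int bidh, int tidx, int mi, int mmas_n, int ni) {
  long long base = (long long)(bidb * h + bidh) * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;
  return base + (long long)(mi * mmas_n + ni) * BYTES_PER_ROW;
}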
- template< typename Params, typename Block_info > - inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx) - : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) { - } - - // Store to global memory. - template - inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - - float tmp00 = softmax[2 * mi + 0][4 * ni + 0]; - float tmp01 = softmax[2 * mi + 0][4 * ni + 1]; - float tmp02 = softmax[2 * mi + 0][4 * ni + 2]; - float tmp03 = softmax[2 * mi + 0][4 * ni + 3]; - - float tmp10 = softmax[2 * mi + 1][4 * ni + 0]; - float tmp11 = softmax[2 * mi + 1][4 * ni + 1]; - float tmp12 = softmax[2 * mi + 1][4 * ni + 2]; - float tmp13 = softmax[2 * mi + 1][4 * ni + 3]; - - uint4 dst; - dst.x = fmha::float2_to_half2(tmp00, tmp01); - dst.y = fmha::float2_to_half2(tmp02, tmp03); - dst.z = fmha::float2_to_half2(tmp10, tmp11); - dst.w = fmha::float2_to_half2(tmp12, tmp13); - if( mask.is_valid(mi, ni, 0, 0) ) { - Base::store(dst, mi, ni); - } - } + // The number of mmas in the vertical dimension. + enum { M = Base::MMAS_M }; + // The number of mmas in the horizontal dimension. + enum { N = Base::MMAS_N }; + // The type of the vectors stored by each STG. + using Type = typename Base::Type; + + // Ctor. + template + inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info &binfo, const int tidx) + : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {} + + // Store to global memory. + template + inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + float tmp00 = softmax[2 * mi + 0][4 * ni + 0]; + float tmp01 = softmax[2 * mi + 0][4 * ni + 1]; + float tmp02 = softmax[2 * mi + 0][4 * ni + 2]; + float tmp03 = softmax[2 * mi + 0][4 * ni + 3]; + + float tmp10 = softmax[2 * mi + 1][4 * ni + 0]; + float tmp11 = softmax[2 * mi + 1][4 * ni + 1]; + float tmp12 = softmax[2 * mi + 1][4 * ni + 2]; + float tmp13 = softmax[2 * mi + 1][4 * ni + 3]; + + uint4 dst; + dst.x = fmha::float2_to_half2(tmp00, tmp01); + dst.y = fmha::float2_to_half2(tmp02, tmp03); + dst.z = fmha::float2_to_half2(tmp10, tmp11); + dst.w = fmha::float2_to_half2(tmp12, tmp13); + if (mask.is_valid(mi, ni, 0, 0)) { + Base::store(dst, mi, ni); } + } } - - // Store to global memory. - template - inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){ - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 dst; - dst.x = frag[ni][mi].reg(0); - dst.y = frag[ni][mi].reg(2); - dst.z = frag[ni][mi].reg(1); - dst.w = frag[ni][mi].reg(3); - if( mask.any_valid(mi, ni) ) { - Base::store(dst, mi, ni); - } - } + } + + // Store to global memory. + template + inline __device__ void store(const Fragment (&frag)[N][M], const Mask &mask) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + uint4 dst; + dst.x = frag[ni][mi].reg(0); + dst.y = frag[ni][mi].reg(2); + dst.z = frag[ni][mi].reg(1); + dst.w = frag[ni][mi].reg(3); + if (mask.any_valid(mi, ni)) { + Base::store(dst, mi, ni); } + } } - - // Load from global memory. 
- template - inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - regs[mi][ni] = make_uint4(0, 0, 0, 0); - if( mask.any_valid(mi, ni) ) { - Base::load(regs[mi][ni], mi, ni); - } - } + } + + // Load from global memory. + template + inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + regs[mi][ni] = make_uint4(0, 0, 0, 0); + if (mask.any_valid(mi, ni)) { + Base::load(regs[mi][ni], mi, ni); } + } } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The base class. - typename Base = fmha::Gmem_tile_qkv -> + typename Base = fmha::Gmem_tile_qkv > struct Gmem_tile_dout : public Base { + // Ctor. + template + inline __device__ Gmem_tile_dout(const Params ¶ms, const BInfo &binfo, int tidx) : Base(params, 0, binfo, tidx) { + this->qkv_ptr_ = reinterpret_cast(params.o_ptr); + this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move - // Ctor. - template - inline __device__ Gmem_tile_dout(const Params ¶ms, const BInfo &binfo, int tidx) - : Base(params, 0, binfo, tidx) { + // Compute the position of the thread in the row. + int col = tidx % Base::THREADS_PER_ROW; - this->qkv_ptr_ = reinterpret_cast(params.o_ptr); - this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move + // The row offset in the batched GEMM. For each seq element, we store O in that order. + int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % Base::THREADS_PER_ROW; - - // The row offset in the batched GEMM. For each seq element, we store O in that order. - int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW; - - // Assemble the final pointer. - this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG; - } + // Assemble the final pointer. + this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Cta_tile, typename Base = fmha::Gmem_tile_o > +template > struct Gmem_tile_dq : public Base { - - // Ctor. - template - inline __device__ Gmem_tile_dq(const Params ¶ms, const BInfo &binfo, int tidx) - : Base(params, binfo, tidx) { - this->o_ptr_ = reinterpret_cast(params.dqkv_ptr); - this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes; // needed for move - - // Compute the position of the thread in the row. - int col = tidx % Base::THREADS_PER_ROW; - - // The row offset in the batched GEMM. For each seq element, we store O in that order. - int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes + - (binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW; - - // Assemble the final pointer. - this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG; - } + // Ctor. + template + inline __device__ Gmem_tile_dq(const Params ¶ms, const BInfo &binfo, int tidx) : Base(params, binfo, tidx) { + this->o_ptr_ = reinterpret_cast(params.dqkv_ptr); + this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes; // needed for move + + // Compute the position of the thread in the row. 
+ int col = tidx % Base::THREADS_PER_ROW; + + // The row offset in the batched GEMM. For each seq element, we store O in that order. + int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes + + (binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW; + + // Assemble the final pointer. + this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha - diff --git a/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h b/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h index d51b47c53..7b0c00d16 100644 --- a/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h +++ b/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,71 +27,73 @@ #pragma once +#include "gmem_tile.h" +#include "smem_tile.h" + //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct FMHA_kernel_traits { - - // The CTA description for the 1st GEMM. - using Cta_tile_p = fmha::Cta_tile_extd; - // The CTA description for the 2nd GEMM. - using Cta_tile_o = fmha::Cta_tile_extd; - - // Do we use one buffer for K and V. - enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u }; - // Do we keep K in registers. - enum { K_IN_REGS = (FLAGS & 0x10u) == 0u }; - - // The global memory tile to load Q. - using Gmem_tile_q = fmha::Gmem_tile_qkv; - - // The shared memory tile to swizzle Q. - using Smem_tile_q = fmha::Smem_tile_a; - - // The global memory tile to load K. - using Gmem_tile_k = fmha::Gmem_tile_qkv; - // The shared memory tile to swizzle K. - using Smem_tile_k = fmha::Smem_tile_b; - - // The global memory tile to load V. - using Gmem_tile_v = fmha::Gmem_tile_qkv; - // The shared memory tile to swizzle V. - using Smem_tile_v = fmha::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = fmha::Gmem_tile_o; - // The shared memory tile for O. - using Smem_tile_o = fmha::Smem_tile_o; - - // The global memory tile to load/store S. - using Gmem_tile_s = fmha::Gmem_tile_mma_s; - - // The shared memory tile to transpose S. - using Smem_tile_st = fmha::Smem_tile_mma_transposed; - - using Gmem_tile_do = fmha::Gmem_tile_dout; - - // Make sure the number of threads match. - static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); - - // The number of threads. - enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; - // Make sure the number of threads matches both CTAs. - static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); - - // The amount of shared memory needed to load Q and K. 
- enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; - // The extra amount of shared memory needed to load V. - enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; - // The amount of shared memory needed for Q, K and V.. - enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; - // The amount of shared memory needed to load Q and store O. - enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; - - // The amount of shared memory needed for Q, K, V and O. - enum { BYTES_PER_SMEM = fmha::Max::VALUE }; - // Make sure we have enough shared memory. - static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); + // The CTA description for the 1st GEMM. + using Cta_tile_p = fmha::Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = fmha::Cta_tile_extd; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u }; + // Do we keep K in registers. + enum { K_IN_REGS = (FLAGS & 0x10u) == 0u }; + + // The global memory tile to load Q. + using Gmem_tile_q = fmha::Gmem_tile_qkv; + + // The shared memory tile to swizzle Q. + using Smem_tile_q = fmha::Smem_tile_a; + + // The global memory tile to load K. + using Gmem_tile_k = fmha::Gmem_tile_qkv; + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_b; + + // The global memory tile to load V. + using Gmem_tile_v = fmha::Gmem_tile_qkv; + // The shared memory tile to swizzle V. + using Smem_tile_v = fmha::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = fmha::Gmem_tile_o; + // The shared memory tile for O. + using Smem_tile_o = fmha::Smem_tile_o; + + // The global memory tile to load/store S. + using Gmem_tile_s = fmha::Gmem_tile_mma_s; + + // The shared memory tile to transpose S. + using Smem_tile_st = fmha::Smem_tile_mma_transposed; + + using Gmem_tile_do = fmha::Gmem_tile_dout; + + // Make sure the number of threads match. + static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + // The amount of shared memory needed to load Q and store O. + enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K, V and O. + enum { BYTES_PER_SMEM = fmha::Max::VALUE }; + // Make sure we have enough shared memory. 
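The shared-memory budget in FMHA_kernel_traits is the maximum of what the two GEMMs need: Q plus K (plus V unless K and V share one buffer) for the first, and Q plus O for the second. A sketch with hypothetical per-tile byte counts (the real values come from the Smem_tile types):

// Illustrative tile sizes only; SHARE_SMEM_FOR_K_AND_V corresponds to FLAGS & 0x08.
enum { BYTES_Q = 16 * 1024, BYTES_K = 16 * 1024, BYTES_V = 16 * 1024, BYTES_O = 8 * 1024 };
enum { SHARE_SMEM_FOR_K_AND_V = 1 };
enum { BYTES_QK = BYTES_Q + BYTES_K };
enum { BYTES_V_EXTRA = SHARE_SMEM_FOR_K_AND_V ? 0 : BYTES_V };  // V reuses K's buffer when shared
enum { BYTES_QKV = BYTES_QK + BYTES_V_EXTRA };
enum { BYTES_QO = BYTES_Q + BYTES_O };
enum { BYTES_SMEM = BYTES_QKV > BYTES_QO ? BYTES_QKV : BYTES_QO };
static_assert(BYTES_SMEM == 32 * 1024, "Q + K dominates with these example sizes");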
+ static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/fmha/src/fmha/mask.h b/apex/contrib/csrc/fmha/src/fmha/mask.h index 020258a02..2ee2b8fd8 100644 --- a/apex/contrib/csrc/fmha/src/fmha/mask.h +++ b/apex/contrib/csrc/fmha/src/fmha/mask.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -29,53 +29,46 @@ namespace fmha { - -template +template struct Mask { - using Mma_tile = fmha::Hmma_tile; - - template - __device__ Mask(const Params ¶ms, const BInfo &blockInfo, int tidx) { - - actual_seqlen = blockInfo.actual_seqlen; + using Mma_tile = fmha::Hmma_tile; - const int warp = tidx / Cta_tile::THREADS_PER_WARP; - const int lane = tidx % Cta_tile::THREADS_PER_WARP; + template + __device__ Mask(const Params ¶ms, const BInfo &blockInfo, int tidx) { + actual_seqlen = blockInfo.actual_seqlen; - static_assert(Cta_tile::WARPS_K == 1, ""); + const int warp = tidx / Cta_tile::THREADS_PER_WARP; + const int lane = tidx % Cta_tile::THREADS_PER_WARP; - // find the warp in the Cta tile - const int warp_n = (warp / Cta_tile::WARPS_M); - const int warp_m = (warp % Cta_tile::WARPS_M); - // decompose warp into 8x4 tile - const int quad = lane / 4; - const int tid = (lane % 4) * 2; - row = warp_m * 16 + quad; - col = warp_n * 16 + tid; - } + static_assert(Cta_tile::WARPS_K == 1, ""); - inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const { + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + row = warp_m * 16 + quad; + col = warp_n * 16 + tid; + } - // ii and jj iterate over the 2x4 fragment - const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen; - //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen; - return col_valid; - // return row_valid && col_valid; - } + inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const { + // ii and jj iterate over the 2x4 fragment + const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen; + //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen; + return col_valid; + // return row_valid && col_valid; + } - //BERT Mask: if upper left is invalid, none are valid - inline __device__ bool any_valid(int mi, int ni) const { - return is_valid(mi, ni, 0, 0); - } + // BERT Mask: if upper left is invalid, none are valid + inline __device__ bool 
any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); } - inline __device__ void load(int it) { - row_offset = it * Cta_tile::M + row; - } - int row_offset; + inline __device__ void load(int it) { row_offset = it * Cta_tile::M + row; } + int row_offset; - int row; - int col; - int actual_seqlen; + int row; + int col; + int actual_seqlen; }; } // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha/smem_tile.h b/apex/contrib/csrc/fmha/src/fmha/smem_tile.h index 80879140a..6428c7d2d 100644 --- a/apex/contrib/csrc/fmha/src/fmha/smem_tile.h +++ b/apex/contrib/csrc/fmha/src/fmha/smem_tile.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,330 +27,306 @@ #pragma once -#include #include +#include namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The description of the tile computed by this CTA. - typename Cta_tile, + typename Cta_tile, // The number of rows in the 2D shared memory buffer. - int M_, + int M_, // The number of cols. - int N_, + int N_, // The size in bits of each element. - int BITS_PER_ELEMENT_, + int BITS_PER_ELEMENT_, // The number of bytes per STS. int BYTES_PER_STS_ = 16, // The number of buffers. (Used in multistage and double buffer cases.) int BUFFERS_PER_TILE_ = 1, // Do we enable the fast path for LDS.128 and friends. - int ENABLE_LDS_FAST_PATH_ = 0, - // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + int ENABLE_LDS_FAST_PATH_ = 0, + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. int ROWS_PER_XOR_PATTERN_ = 8, - // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. int COLS_PER_XOR_PATTERN_ = 1, // Use or not predicates - bool USE_PREDICATES_ = true -> + bool USE_PREDICATES_ = true> struct Smem_tile_without_skews { - - // The size in bits of each element. - enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; - // The size in bytes of a single STS. - enum { BYTES_PER_STS = BYTES_PER_STS_ }; - // The number of elements per STS. - enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; - // To support arbitrary N, we pad some values to a power-of-2. - enum { N_WITH_PADDING = Next_power_of_two::VALUE }; - // The number of bytes per row without packing of rows. - enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; - // The number of bytes per row -- we want at least 128B per row. - enum { BYTES_PER_ROW = Max::VALUE }; - // The number of rows in shared memory (two rows may be packed into a single one). 
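The (row, col) decomposition in the Mask constructor views each 32-lane warp as an 8x4 grid of quads, with every thread owning two consecutive columns of a 16x16 MMA tile. A tiny runnable illustration for warp_m = warp_n = 0:

#include <cstdio>

int main() {
  for (int lane = 0; lane < 32; ++lane) {
    int quad = lane / 4;       // 8 quad rows per warp
    int tid = (lane % 4) * 2;  // 4 column pairs per quad row
    // row = warp_m * 16 + quad, col = warp_n * 16 + tid in the Mask above.
    printf("lane %2d -> row %2d, cols %2d-%2d\n", lane, quad, tid, tid + 1);
  }
  return 0;
}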
- enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW }; - - // The number of threads per row. - enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; - // The number of threads per row. - enum { THREADS_PER_ROW = Min::VALUE }; - - // The number of STS per row. - enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; - // It must be at least one. - static_assert(STS_PER_ROW >= 1, ""); - // The number of rows written with a single STS. - enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;) - static_assert(ROWS_PER_STS >= 1, ""); - // The number of STS needed to store all rows. - enum { STS_PER_COL = Div_up::VALUE }; - // The number of STS in total. - enum { STS = STS_PER_COL * STS_PER_ROW }; - - // The size of one buffer in bytes in shared memory. - enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; - // The number of buffers. - enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; - // The size in bytes of total buffers. - enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; - // The boundary for smem_read_offset and smem_write_offset increment. - enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; - - // Do we enable the LDS.128 fast path? - enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ }; - static_assert(ENABLE_LDS_FAST_PATH == 0); - // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. - enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; - // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. - enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; - // Use or not predicates - enum { USE_PREDICATES = USE_PREDICATES_ }; - - // The type of elements that are stored in shared memory by each thread. - using Store_type = typename Uint_from_size_in_bytes::Type; - - // Ctor. - inline __device__ Smem_tile_without_skews(void *smem, int tidx) - : smem_(__nvvm_get_smem_pointer(smem)) { - - // The row written by a thread. See doc/mma_smem_layout.xlsx. - int smem_write_row = tidx / THREADS_PER_ROW; - - // The XOR pattern. - int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; - // Compute the column and apply the XOR pattern. - int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; - - // The offset. - this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS; - - // TODO: Why not merge it with the read offset? - this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); - this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); - } - - // Compute the store pointers. - template< int N > - inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { - #pragma unroll - for( int ii = 0; ii < N; ++ii ) { - // Decompose the STS into row/col. - int row = ii / STS_PER_ROW; - int col = ii % STS_PER_ROW; - - // Assemble the offset. - int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW; - - // Take the column into account. - if( STS_PER_ROW > 1 ) { - offset += col*THREADS_PER_ROW*BYTES_PER_STS; - } - - // Apply the XOR pattern if needed. 
- if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) { - const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; - offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; - } - - // Assemble the final pointer :) - ptrs[ii] = smem_ + offset + smem_write_buffer_; - } - } - - inline __device__ void debug_reset() { - for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { - for( int row = 0; row < ROWS; ++row ) { - for( int col = 0; col < BYTES_PER_ROW; col += 4 ) { - if( threadIdx.x == 0 ) { - uint32_t val = 0x0; - sts(val, smem_ + row*BYTES_PER_ROW + col + buffer); - } - } - } - } + // The size in bits of each element. + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + // The size in bytes of a single STS. + enum { BYTES_PER_STS = BYTES_PER_STS_ }; + // The number of elements per STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + // To support arbitrary N, we pad some values to a power-of-2. + enum { N_WITH_PADDING = Next_power_of_two::VALUE }; + // The number of bytes per row without packing of rows. + enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; + // The number of bytes per row -- we want at least 128B per row. + enum { BYTES_PER_ROW = Max::VALUE }; + // The number of rows in shared memory (two rows may be packed into a single one). + enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW }; + + // The number of threads per row. + enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; + // The number of threads per row. + enum { THREADS_PER_ROW = Min::VALUE }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + // It must be at least one. + static_assert(STS_PER_ROW >= 1, ""); + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;) + static_assert(ROWS_PER_STS >= 1, ""); + // The number of STS needed to store all rows. + enum { STS_PER_COL = Div_up::VALUE }; + // The number of STS in total. + enum { STS = STS_PER_COL * STS_PER_ROW }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // Do we enable the LDS.128 fast path? + enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ }; + static_assert(ENABLE_LDS_FAST_PATH == 0); + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; + // Use or not predicates + enum { USE_PREDICATES = USE_PREDICATES_ }; + + // The type of elements that are stored in shared memory by each thread. + using Store_type = typename Uint_from_size_in_bytes::Type; + + // Ctor. + inline __device__ Smem_tile_without_skews(void *smem, int tidx) : smem_(__nvvm_get_smem_pointer(smem)) { + // The row written by a thread. See doc/mma_smem_layout.xlsx. + int smem_write_row = tidx / THREADS_PER_ROW; + + // The XOR pattern. 
+ int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; + // Compute the column and apply the XOR pattern. + int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + + // The offset. + this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; + + // TODO: Why not merge it with the read offset? + this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int row = ii / STS_PER_ROW; + int col = ii % STS_PER_ROW; + + // Assemble the offset. + int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW; + + // Take the column into account. + if (STS_PER_ROW > 1) { + offset += col * THREADS_PER_ROW * BYTES_PER_STS; + } + + // Apply the XOR pattern if needed. + if (ROWS_PER_STS < ROWS_PER_XOR_PATTERN) { + const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; + offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; + } + + // Assemble the final pointer :) + ptrs[ii] = smem_ + offset + smem_write_buffer_; } - - // Print the content of the tile (only for debug ;)). - inline __device__ void debug_print() const { - for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { - for( int row = 0; row < ROWS; ++row ) { - for( int col = 0; col < BYTES_PER_ROW; col += 4 ) { - if( threadIdx.x == 0 ) { - uint32_t val; - lds(val, smem_ + row*BYTES_PER_ROW + col + buffer); - printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n", - blockIdx.x, - blockIdx.y, - blockIdx.z, - smem_, - buffer, - row, - col, - val); - } - } - } + } + + inline __device__ void debug_reset() { + for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for (int row = 0; row < ROWS; ++row) { + for (int col = 0; col < BYTES_PER_ROW; col += 4) { + if (threadIdx.x == 0) { + uint32_t val = 0x0; + sts(val, smem_ + row * BYTES_PER_ROW + col + buffer); + } } + } } - - // Move the read offset to next buffer. - inline __device__ void move_to_next_read_buffer() { - if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) { - this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; - } else if( BUFFERS_PER_TILE > 1 ) { - this->smem_read_buffer_ += BYTES_PER_BUFFER; + } + + // Print the content of the tile (only for debug ;)). + inline __device__ void debug_print() const { + for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for (int row = 0; row < ROWS; ++row) { + for (int col = 0; col < BYTES_PER_ROW; col += 4) { + if (threadIdx.x == 0) { + uint32_t val; + lds(val, smem_ + row * BYTES_PER_ROW + col + buffer); + printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n", blockIdx.x, + blockIdx.y, blockIdx.z, smem_, buffer, row, col, val); + } } + } } - - // Move the read offset to next buffer. TODO: Remove this member function!!! - inline __device__ void move_next_read_buffer() { - this->move_to_next_read_buffer(); + } + + // Move the read offset to next buffer. 
+ inline __device__ void move_to_next_read_buffer() { + if (BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) { + this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_read_buffer_ += BYTES_PER_BUFFER; } + } - // Move the read offset to next N buffer (circular-buffer). - inline __device__ void move_to_next_read_buffer(int N) { - if( BUFFERS_PER_TILE > 1 ) { - this->smem_read_buffer_ += N * BYTES_PER_BUFFER; - this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; - } - } - - // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!! - inline __device__ void move_next_read_buffer(int N) { - this->move_to_next_read_buffer(N); - } - - // Move the write offset to next buffer. - inline __device__ void move_to_next_write_buffer() { - if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) { - this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; - } else if( BUFFERS_PER_TILE > 1 ) { - this->smem_write_buffer_ += BYTES_PER_BUFFER; - } - } + // Move the read offset to next buffer. TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer() { this->move_to_next_read_buffer(); } - // Move the write offset to next buffer. TODO: Remove that member function! - inline __device__ void move_next_write_buffer() { - this->move_to_next_write_buffer(); + // Move the read offset to next N buffer (circular-buffer). + inline __device__ void move_to_next_read_buffer(int N) { + if (BUFFERS_PER_TILE > 1) { + this->smem_read_buffer_ += N * BYTES_PER_BUFFER; + this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; } + } - // Move the read offset. - inline __device__ void move_read_offset(int delta) { - this->smem_read_offset_ += delta; - } + // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer(int N) { this->move_to_next_read_buffer(N); } - // Move the write offset. - inline __device__ void move_write_offset(int delta) { - this->smem_write_offset_ += delta; + // Move the write offset to next buffer. + inline __device__ void move_to_next_write_buffer() { + if (BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) { + this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_write_buffer_ += BYTES_PER_BUFFER; } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { - uint32_t smem_ptrs[N]; - this->compute_store_pointers(smem_ptrs); - sts(smem_ptrs, data); - } - - // Store to the tile in shared memory. - template< int N, int M > - inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) { - uint32_t smem_ptrs[N]; - this->compute_store_pointers(smem_ptrs); - sts(smem_ptrs, data, preds); - } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { - this->store(data, preds); - } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { - uint32_t tmp[1] = { preds }; - this->store(gmem_ptrs, tmp); - } - - // The shared memory pointer. - uint32_t smem_; - // The read offset. Reserve 4 offsets if needed. 
- int smem_read_offset_; - // The write offset. - int smem_write_offset_; - // The buffer base offset for read. - int smem_read_buffer_; - // The buffer base offset for write. - int smem_write_buffer_; + } + + // Move the write offset to next buffer. TODO: Remove that member function! + inline __device__ void move_next_write_buffer() { this->move_to_next_write_buffer(); } + + // Move the read offset. + inline __device__ void move_read_offset(int delta) { this->smem_read_offset_ += delta; } + + // Move the write offset. + inline __device__ void move_write_offset(int delta) { this->smem_write_offset_ += delta; } + + // Store to the tile in shared memory. + template + inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { + this->store(data, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(const void *(&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { + uint32_t tmp[1] = {preds}; + this->store(gmem_ptrs, tmp); + } + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. - typename Cta_tile, + typename Cta_tile, // The layout of the tile. - typename Layout, + typename Layout, // The size of the STS. int BYTES_PER_STS = 16, // The number of buffers per tile. int BUFFERS_PER_TILE = 1, // Use or not predicates - bool USE_PREDICATES = true -> -struct Smem_tile_a { -}; + bool USE_PREDICATES = true> +struct Smem_tile_a {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int MMAS_K, int MMAS_K_WITH_PADDING > +template struct Compute_reset_mask { - // The potential mask. - enum { HALF = MMAS_K_WITH_PADDING / 2 }; - // The remainder. - enum { MOD = MMAS_K % HALF }; - // The final value. - enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask::VALUE }; + // The potential mask. + enum { HALF = MMAS_K_WITH_PADDING / 2 }; + // The remainder. + enum { MOD = MMAS_K % HALF }; + // The final value. + enum { VALUE = (MMAS_K == MOD ? 
0 : HALF) | Compute_reset_mask::VALUE }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int MMAS_K_WITH_PADDING > +template struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> { - enum { VALUE = 0 }; + enum { VALUE = 0 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int MMAS_K > +template struct Compute_reset_mask { - enum { VALUE = MMAS_K - 1 }; + enum { VALUE = MMAS_K - 1 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template struct Rows_per_xor_pattern_a { - // The size in bits. - enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A }; - // The number of rows. - enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; + // The size in bits. + enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A }; + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > -struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a { -}; +template +struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. @@ -358,179 +334,155 @@ template< // The number of buffers per tile. int BUFFERS_PER_TILE, // How many rows to use for the XOR pattern to avoid bank conflicts? - int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a::VALUE -> -struct Smem_tile_row_a : public Smem_tile_without_skews { - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_a; - - // When we use padding to reach a power of two, special care has to be taken. - using Cta_tile_with_padding = Cta_tile_with_k_with_padding; - // The number of MMAs. - using Mma_tile_with_padding = fmha::Hmma_tile; - - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = 16 }; - - // Ctor. - inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) { - - // For documentation on the layout, see doc/mma_smem_layout.xlsx. - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - - static_assert(WARPS_M == 1); - static_assert(WARPS_N == 4 || WARPS_N == 8); - static_assert(WARPS_K == 1); - static_assert(Base::ROWS_PER_XOR_PATTERN == 8); - - // The row and column read by the thread. - int smem_read_row = (tidx & 0x0f); - int smem_read_col = (tidx & 0x07); - smem_read_col ^= (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS; + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a::VALUE> +struct Smem_tile_row_a : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. 
+ using Mma_tile_with_padding = fmha::Hmma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + + static_assert(WARPS_M == 1); + static_assert(WARPS_N == 4 || WARPS_N == 8); + static_assert(WARPS_K == 1); + static_assert(Base::ROWS_PER_XOR_PATTERN == 8); + + // The row and column read by the thread. + int smem_read_row = (tidx & 0x0f); + int smem_read_col = (tidx & 0x07); + smem_read_col ^= (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; } + } - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). + int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + + // Store the value into the fragment. + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + a[mi].reg(2) = tmp.z; + a[mi].reg(3) = tmp.w; } - // Load from shared memory. - inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { - #pragma unroll - for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) { - // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). - int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; - - // Load using LDSM.M88.4. - uint4 tmp; - ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); - - // Store the value into the fragment. - a[mi].reg(0) = tmp.x; - a[mi].reg(1) = tmp.y; - a[mi].reg(2) = tmp.z; - a[mi].reg(3) = tmp.w; - } - - // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. - static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); - if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) { - this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) { - this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) { - this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) { - this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; - } + // Move the offset to the next possition. 
See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; } + } - // Reset the read offset. - inline __device__ void reset_read_offset() { - // The number of MMAs in the K dimension. - enum { MMAS_K = Mma_tile::MMAS_K }; - // The number of MMAs in the K dimension when we include padding. - enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; - // Assemble the mask. - enum { MASK = Compute_reset_mask::VALUE }; - - // Reset the read offset. - this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; - } + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. int BYTES_PER_STS, // The number of buffers per tile. - int BUFFERS_PER_TILE -> + int BUFFERS_PER_TILE> struct Smem_tile_a - : public Smem_tile_row_a { - // The base class. - using Base = Smem_tile_row_a; - - // Ctor. - inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) { - } + : public Smem_tile_row_a { + // The base class. + using Base = Smem_tile_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {} }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. - typename Cta_tile, + typename Cta_tile, // The layout of the tile. - typename Layout, + typename Layout, // The size of the STS. int BYTES_PER_STS = 16, // The number of buffers per tile. int BUFFERS_PER_TILE = 1, // Use or not predicates - bool USE_PREDICATES = true -> -struct Smem_tile_b { -}; + bool USE_PREDICATES = true> +struct Smem_tile_b {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template struct Rows_per_xor_pattern_b { - // The size in bits. - enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B }; - // The number of rows. - enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; + // The size in bits. + enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B }; + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 
4 : 8) }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > -struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b { -}; +template +struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. @@ -538,164 +490,139 @@ template< // The number of buffers per tile. int BUFFERS_PER_TILE, // How many rows to use for the XOR pattern to avoid bank conflicts? - int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b::VALUE -> -struct Smem_tile_col_b : public Smem_tile_without_skews { - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_b< Col>; - - // When we use padding to reach a power of two, special care has to be taken. - using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>; - // The number of MMAs. - using Mma_tile_with_padding = fmha::Hmma_tile; - - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = 16 }; - - // The number of STS per thread - enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; - // The number of STS per thread must be at least 1. - enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; - - // Ctor. - inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) { - - // For documentation on the layout, see doc/mma_smem_layout.xlsx. - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - static_assert(Base::ROWS_PER_XOR_PATTERN == 8); - static_assert(WARPS_M == 1); - static_assert(WARPS_N == 4 || WARPS_N == 8); - static_assert(WARPS_K == 1); - - // The masks to select the warps. - const int WARP_MASK_N = Warp_masks::N; - - // The divisor for the warps. - const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; - - // The row and column read by the thread. - int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA + - (tidx & 0x07) + - (tidx & 0x10) / 2; - int smem_read_col = (tidx & 0x07); - smem_read_col ^= (tidx & 0x08) / 8; - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS; - } + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b::VALUE> +struct Smem_tile_col_b : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = fmha::Hmma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // The number of STS per thread + enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. 
+ const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + static_assert(Base::ROWS_PER_XOR_PATTERN == 8); + static_assert(WARPS_M == 1); + static_assert(WARPS_N == 4 || WARPS_N == 8); + static_assert(WARPS_K == 1); - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } + // The masks to select the warps. + const int WARP_MASK_N = Warp_masks::N; + + // The divisor for the warps. + const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA + (tidx & 0x07) + (tidx & 0x10) / 2; + int smem_read_col = (tidx & 0x07); + smem_read_col ^= (tidx & 0x08) / 8; + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; } + } - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). - int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; - - // Load using LDSM.M88.4. - uint4 tmp; - ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); - - // Store the value into the fragment. - b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - } + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). + int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + + // Store the value into the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } - // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. - static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); - if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) { - this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) { - this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) { - this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) { - this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; - } + // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. 
+ static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; // Reset the read offset. - inline __device__ void reset_read_offset() { - // The number of MMAs in the K dimension. - enum { MMAS_K = Mma_tile::MMAS_K }; - // The number of MMAs in the K dimension when we include padding. - enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; - // Assemble the mask. - enum { MASK = Compute_reset_mask::VALUE }; - - // Reset the read offset. - this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; - } + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. int BYTES_PER_STS, // The number of buffers per tile. - int BUFFERS_PER_TILE -> -struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE > - : public Smem_tile_col_b { - - // The base class. - using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>; - - // Ctor. - inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) { - } + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_col_b { + // The base class. + using Base = Smem_tile_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {} }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > -struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> { -}; +template +struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b {}; //////////////////////////////////////////////////////////////////////////////////////////////////// - -template< +template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. @@ -705,582 +632,552 @@ template< // How many rows to use for the XOR pattern to avoid bank conflicts? int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b::VALUE, // How many cols to use for the XOR pattern to avoid bank conflicts? - int COLS_PER_XOR_PATTERN_ = 1 -> -struct Smem_tile_row_b : public Smem_tile_without_skews { - - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_b; - - // Can we use LDSM? No if the data type is 32-bit large. - enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 }; - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = USE_LDSMT ? 
16 : 4 }; - // The number of elements per LDS. - enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B }; - - // The number of STS per thread - enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; - // The number of STS per thread must be at least 1. - enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; - - // Ctor. - inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) { - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - static_assert(WARPS_K == 1); - static_assert(WARPS_M == 4 || WARPS_M == 8); - static_assert(WARPS_N == 1); - - // The masks to select the warps. - const int WARP_MASK_N = Warp_masks::N; - const int WARP_MASK_K = Warp_masks::K; - - // The divisor for the warps. - const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; - const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; - - // The row/col read by the thread. - int smem_read_row, smem_read_col; - - static_assert(USE_LDSMT); - static_assert(Base::ROWS_PER_XOR_PATTERN == 8); - - smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + - (tidx & 0x07) + (tidx & 0x08); - smem_read_col = (tidx & 0x07); - smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS; - - // Fill zeroes for group conv - } + int COLS_PER_XOR_PATTERN_ = 1> +struct Smem_tile_row_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // Can we use LDSM? No if the data type is 32-bit large. + enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 }; + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 }; + // The number of elements per LDS. + enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B }; + + // The number of STS per thread + enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) { + // The number of warps. + const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + static_assert(WARPS_K == 1); + static_assert(WARPS_M == 4 || WARPS_M == 8); + static_assert(WARPS_N == 1); - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // The size of each element in bits. - const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; - // The size in bytes of the data needed to compute an MMA per CTA. - const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; - - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( BYTES_PER_MMA_PER_CTA >= 128 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } else if( BYTES_PER_MMA_PER_CTA == 64 ) { - // Nothing to do! 
- } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } + // The masks to select the warps. + const int WARP_MASK_N = Warp_masks::N; + const int WARP_MASK_K = Warp_masks::K; - // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) - if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && - Mma_tile::MMAS_N % 2 == 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } + // The divisor for the warps. + const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row/col read by the thread. + int smem_read_row, smem_read_col; + + static_assert(USE_LDSMT); + static_assert(Base::ROWS_PER_XOR_PATTERN == 8); + + smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + (tidx & 0x07) + (tidx & 0x08); + smem_read_col = (tidx & 0x07); + smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + + // Fill zeroes for group conv + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // The size of each element in bits. + const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (BYTES_PER_MMA_PER_CTA >= 128) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } } - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { - // The size of each element in bits. - const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; - // The size in bytes of the data needed to compute an MMA per CTA. - const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; - - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Prepare the offset. - int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW; - if ( BYTES_PER_MMA_PER_CTA == 32 ) { - offset += this->smem_read_offset_; - } else if ( BYTES_PER_MMA_PER_CTA == 64 ) { - offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2; - } else { - offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA; - } - - // Load the data using LDSM.MT88.2. 
- uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset; - uint4 tmp; - if( USE_LDSMT ) { - ldsmt(tmp, ptr); - } else { - lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW); - lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW); - lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW); - lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW); - } - - // Store those values in the fragment. - b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - - // Move the pointer for the next ni. I expect the compiler to not recompute those. - if( BYTES_PER_MMA_PER_CTA >= 128 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } else if( BYTES_PER_MMA_PER_CTA == 64 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } - // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) - if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && - Mma_tile::MMAS_N % 2 == 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { + // The size of each element in bits. + const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Prepare the offset. + int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW; + if (BYTES_PER_MMA_PER_CTA == 32) { + offset += this->smem_read_offset_; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2; + } else { + offset += this->smem_read_offset_ + (ni)*BYTES_PER_MMA_PER_CTA; + } + + // Load the data using LDSM.MT88.2. + uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset; + uint4 tmp; + if (USE_LDSMT) { + ldsmt(tmp, ptr); + } else { + lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW); + lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW); + lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW); + lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW); + } + + // Store those values in the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + + // Move the pointer for the next ni. I expect the compiler to not recompute those. + if (BYTES_PER_MMA_PER_CTA >= 128) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 
2 : 6); + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. int BYTES_PER_STS, // The number of buffers per tile. - int BUFFERS_PER_TILE -> + int BUFFERS_PER_TILE> struct Smem_tile_b - : public Smem_tile_row_b { - - // The base class. - using Base = Smem_tile_row_b; + : public Smem_tile_row_b { + // The base class. + using Base = Smem_tile_row_b; - // Ctor. - inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) { - } + // Ctor. + inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {} }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Smem_tile_v : public fmha::Smem_tile_without_skews { - - // The base class. - using Base = Smem_tile_without_skews; - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The fragment. - using Fragment = Fragment_b< fmha::Col>; - - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = 16 }; - - // Ctor. - inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) { - - // The row/col read by the thread. - int read_row, read_col; - - static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); - - read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); - read_col = (tidx & 0x07); - read_col ^= (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; - } - - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { + // The base class. + using Base = Smem_tile_without_skews; + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The fragment. + using Fragment = Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) { + // The row/col read by the thread. + int read_row, read_col; + + static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && + (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); + + read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); + read_col = (tidx & 0x07); + read_col ^= (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Jump by 16 * #warps row. - int row = ki * 16 * Cta_tile::WARPS_K; - - // Load the data using LDSM.MT88.2. - uint4 tmp; - fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); - b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - - // Move the pointer for the next ni. I expect the compiler to not recompute those. - if( Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); - } else { - assert(false); // Not implemented! 
- } - } + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Jump by 16 * #warps row. + int row = ki * 16 * Cta_tile::WARPS_K; + + // Load the data using LDSM.MT88.2. + uint4 tmp; + fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + + // Move the pointer for the next ni. I expect the compiler to not recompute those. + if (Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else { + assert(false); // Not implemented! + } } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Smem_tile_o { - - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The accumulators. - using Accumulator = fmha::Fragment_accumulator; - // The accumulators. - using Data_type = typename Accumulator::Data_type; - - // The size of each element. - enum { BYTES_PER_ELEMENT = sizeof(Data_type) }; - // The size of each STS. - enum { BYTES_PER_STS = 8 }; - // The size of each row in shared memory. - enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT }; - - // The size of each LDS. - enum { BYTES_PER_LDS = 16 }; - enum { THREADS_PER_ROW = 16 }; - - // The number of rows. - enum { ROWS = Cta_tile::M }; - // The number of "rows" to process per loop iteration (in the "epilogue"). - enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA }; - // The number of outer loops. - enum { LOOPS = ROWS / ROWS_PER_LOOP }; - // Make sure it matches our expectations. - static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); - - // The number of rows loaded per LDS. - enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // Do we have to guard against partial writes/reads. - enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; - // The total number of LDS per loop. - enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; - - // The amount of shared memory. - enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW }; - - // The write pointer. - uint32_t smem_write_, smem_read_; - // Is the thread active for the last LDS of the series? - int is_active_for_last_lds_; - - static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K); - static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); - - // Ctor. - inline __device__ Smem_tile_o(void *smem, int tidx) { - - // Get a 32-bit value for the shared memory address. - uint32_t smem_ = __nvvm_get_smem_pointer(smem); - - static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); - - int write_row = (tidx & 0x1c) / 4; - int write_col = (tidx); - - // Assemble the write pointer. - smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - - // The element read by each thread. - int read_row = tidx / THREADS_PER_ROW; - int read_col = tidx % THREADS_PER_ROW; - - // Take the XOR pattern into account for the column. - read_col ^= 2 * (read_row & 0x7); - - // Assemble the read pointer. - this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - - // Is that thread active on the last LDS? - if( HAS_INCOMPLETE_LDS ) { - this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; - } + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The accumulators. 
+ using Data_type = typename Accumulator::Data_type; + + // The size of each element. + enum { BYTES_PER_ELEMENT = sizeof(Data_type) }; + // The size of each STS. + enum { BYTES_PER_STS = 8 }; + // The size of each row in shared memory. + enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT }; + + // The size of each LDS. + enum { BYTES_PER_LDS = 16 }; + enum { THREADS_PER_ROW = 16 }; + + // The number of rows. + enum { ROWS = Cta_tile::M }; + // The number of "rows" to process per loop iteration (in the "epilogue"). + enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA }; + // The number of outer loops. + enum { LOOPS = ROWS / ROWS_PER_LOOP }; + // Make sure it matches our expectations. + static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; + // The total number of LDS per loop. + enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; + + // The amount of shared memory. + enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW }; + + // The write pointer. + uint32_t smem_write_, smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; + + static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K); + static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); + + // Ctor. + inline __device__ Smem_tile_o(void *smem, int tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && + (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); + + int write_row = (tidx & 0x1c) / 4; + int write_col = (tidx); + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= 2 * (read_row & 0x7); + + // Assemble the read pointer. + this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; } + } - // Load the output fragments. - inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { - #pragma unroll - for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) { - - // Load the elements before the reduction (split-K). - uint4 tmp[Cta_tile::WARPS_K]; - #pragma unroll - for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) { - int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; - if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) { - fmha::lds(tmp[jj], this->smem_read_ + imm); - } - } - - // Perform the reduction. - out[ii] = tmp[0]; - #pragma unroll - for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) { - out[ii] = fmha::fadd4(out[ii], tmp[jj]); - } + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). 
+ uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; + if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_)) { + fmha::lds(tmp[jj], this->smem_read_ + imm); } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + out[ii] = fmha::fadd4(out[ii], tmp[jj]); + } } - // Store the accumulators. - template - inline __device__ void store(const Accumulator (&acc)[M][N], int mi) { - enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - - // The number of MMAs that are stored per loop iteration. - enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; - - // Store 1st column of the different MMAs. - #pragma unroll - for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) { - // Precompute the immediates to jump between rows. - int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; - int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; - uint2 tmp0, tmp1; - tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); - tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); - - tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); - tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); - - // Store. - fmha::sts(this->smem_write_ + row_0, tmp0); - fmha::sts(this->smem_write_ + row_1, tmp1); - } - - // Swizzle the write pointer using a XOR of 16B. - this->smem_write_ ^= 32; - - // Store 2nd column of the different MMAs. - #pragma unroll - for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) { - // Precompute the immediates to jump between rows. - int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; - int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; - - uint2 tmp0, tmp1; - tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); - tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); - - tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); - tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); - // Store. - fmha::sts(this->smem_write_ + row_0, tmp0); - fmha::sts(this->smem_write_ + row_1, tmp1); - } - - // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. - this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32; - } + } + // Store the accumulators. + template + inline __device__ void store(const Accumulator (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + + // Swizzle the write pointer using a XOR of 16B. + this->smem_write_ ^= 32; + +// Store 2nd column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. 
+ int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. + this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32; } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Smem_tile_mma { - - using Mma_tile = fmha::Hmma_tile; - using Fragment = fmha::Fragment_a; - - enum { COLS = Cta_tile::N }; - enum { BYTES_PER_ELT = 2 }; - enum { BYTES_PER_STS = 4 }; - enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO - enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; - - enum { WARPS_M = Cta_tile::WARPS_M }; - enum { WARPS_N = Cta_tile::WARPS_N }; - enum { WARPS_K = Cta_tile::WARPS_K }; - - static_assert(WARPS_K == 1); - inline __device__ Smem_tile_mma(char *smem, int tidx) { - smem_ = __nvvm_get_smem_pointer(smem); - - int write_col, write_row; - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); - if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) { - write_row = (tidx & 0x1c) / 4; - write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); - } else { - write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; - write_col = (tidx & 0x03); - } - write_col ^= (write_row & 0x07) * 4; - - write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + using Mma_tile = fmha::Hmma_tile; + using Fragment = fmha::Fragment_a; + + enum { COLS = Cta_tile::N }; + enum { BYTES_PER_ELT = 2 }; + enum { BYTES_PER_STS = 4 }; + enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO + enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; + + enum { WARPS_M = Cta_tile::WARPS_M }; + enum { WARPS_N = Cta_tile::WARPS_N }; + enum { WARPS_K = Cta_tile::WARPS_K }; + + static_assert(WARPS_K == 1); + inline __device__ Smem_tile_mma(char *smem, int tidx) { + smem_ = __nvvm_get_smem_pointer(smem); + + int write_col, write_row; + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); + if (WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); + } else { + write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x03); } - - template - inline __device__ void store(const uint4 (®s)[M][N]) { - static_assert(COLS == Cta_tile::N); - for( int mi = 0; mi < M; mi++ ) { - for( int ni = 0; ni < N; ni++ ) { - size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); - } - } + write_col ^= (write_row & 0x07) * 4; + + write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + } + + template + inline __device__ void store(const uint4 (®s)[M][N]) { + static_assert(COLS == Cta_tile::N); + for (int mi = 0; mi < M; mi++) { + for (int ni = 0; 
ni < N; ni++) { + size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + } } + } - uint32_t smem_; - uint32_t write_offset_; - uint32_t warp_m; - uint32_t warp_n; - uint32_t lane; + uint32_t smem_; + uint32_t write_offset_; + uint32_t warp_m; + uint32_t warp_n; + uint32_t lane; }; -template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>> +template > struct Smem_tile_mma_transposed : public Base { - enum { BYTES_PER_LDS = 16 }; - enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; - enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; - enum { WARPS_M = Base::WARPS_M }; - enum { WARPS_N = Base::WARPS_N }; + enum { BYTES_PER_LDS = 16 }; + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; + enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; + enum { WARPS_M = Base::WARPS_M }; + enum { WARPS_N = Base::WARPS_N }; + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); + using Fragment = typename Base::Fragment; + inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) { static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); - using Fragment = typename Base::Fragment; - inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) { - - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); - int read_row, read_col; - read_row = (tidx & 0x0f); - read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; - - read_col ^= (read_row & 0x07); - read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - } - - template - inline __device__ void load(Fragment (&frag)[M][N]) { - static_assert(Base::COLS == Cta_tile::N); - for( int mi = 0; mi < M; mi++ ) { - for( int ni = 0; ni < N; ni++ ) { - size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); - frag[mi][ni].reg(0) = dst.x; - frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! - frag[mi][ni].reg(2) = dst.y; - frag[mi][ni].reg(3) = dst.w; - } - } + int read_row, read_col; + read_row = (tidx & 0x0f); + read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; + + read_col ^= (read_row & 0x07); + read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + template + inline __device__ void load(Fragment (&frag)[M][N]) { + static_assert(Base::COLS == Cta_tile::N); + for (int mi = 0; mi < M; mi++) { + for (int ni = 0; ni < N; ni++) { + size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint4 dst; + fmha::ldsmt(dst, this->smem_ + offset); + frag[mi][ni].reg(0) = dst.x; + frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! 
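+        // ldsmt is the transposed ldmatrix.x4 load, so dst.{x,y,z,w} each hold one 8x8 tile;
+        // storing them back as x, z, y, w re-orders the tiles into the column-major register
+        // order that Fragment_a expects (see the note above).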
+ frag[mi][ni].reg(2) = dst.y; + frag[mi][ni].reg(3) = dst.w; + } } + } - uint32_t read_offset_; + uint32_t read_offset_; }; -template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>> +template > struct Smem_tile_mma_epilogue : public Base { - enum { BYTES_PER_LDS = 16 }; - enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; - enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; - enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; - static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW); - enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS }; - static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M); - enum { WARPS_M = Base::WARPS_M }; - enum { WARPS_N = Base::WARPS_N }; - static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); - - using Acc = fmha::Fragment_accumulator; - - inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) { - const int read_row = tidx / THREADS_PER_ROW; - int read_col = tidx % THREADS_PER_ROW; - read_col ^= (read_row & 0x07); - read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + enum { BYTES_PER_LDS = 16 }; + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; + enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; + static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW); + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS }; + static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M); + enum { WARPS_M = Base::WARPS_M }; + enum { WARPS_N = Base::WARPS_N }; + static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); + + using Acc = fmha::Fragment_accumulator; + + inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) { + const int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + read_col ^= (read_row & 0x07); + read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + inline __device__ void load(uint4 (&data)[NUM_LDS]) { + for (int ii = 0; ii < NUM_LDS; ii++) { + size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + fmha::lds(data[ii], this->smem_ + offset); } + } - inline __device__ void load(uint4 (&data)[NUM_LDS]) { - for( int ii = 0; ii < NUM_LDS; ii++ ) { - size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - fmha::lds(data[ii], this->smem_ + offset); - } - } - - template - inline __device__ void store(const Acc (&acc)[M][N]){ - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // 1st row - 4 elements per row. - float tmp00 = acc[mi][ni].elt(0); - float tmp01 = acc[mi][ni].elt(1); - float tmp02 = acc[mi][ni].elt(4); - float tmp03 = acc[mi][ni].elt(5); - // 2nd row - 4 elements per row. 
- float tmp10 = acc[mi][ni].elt(2); - float tmp11 = acc[mi][ni].elt(3); - float tmp12 = acc[mi][ni].elt(6); - float tmp13 = acc[mi][ni].elt(7); - - uint32_t x = fmha::float2_to_half2(tmp00, tmp01); - uint32_t y = fmha::float2_to_half2(tmp02, tmp03); - uint32_t z = fmha::float2_to_half2(tmp10, tmp11); - uint32_t w = fmha::float2_to_half2(tmp12, tmp13); - - size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z); - offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); - } - } + template + inline __device__ void store(const Acc (&acc)[M][N]) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + // 1st row - 4 elements per row. + float tmp00 = acc[mi][ni].elt(0); + float tmp01 = acc[mi][ni].elt(1); + float tmp02 = acc[mi][ni].elt(4); + float tmp03 = acc[mi][ni].elt(5); + // 2nd row - 4 elements per row. + float tmp10 = acc[mi][ni].elt(2); + float tmp11 = acc[mi][ni].elt(3); + float tmp12 = acc[mi][ni].elt(6); + float tmp13 = acc[mi][ni].elt(7); + + uint32_t x = fmha::float2_to_half2(tmp00, tmp01); + uint32_t y = fmha::float2_to_half2(tmp02, tmp03); + uint32_t z = fmha::float2_to_half2(tmp10, tmp11); + uint32_t w = fmha::float2_to_half2(tmp12, tmp13); + + size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z); + offset ^= 4 * Base::BYTES_PER_STS; + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); + } } - - template - inline __device__ void store(const uint4 (®s)[M][N]) { - for( int mi = 0; mi < M; mi++ ) { - for( int ni = 0; ni < N; ni++ ) { - size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); - offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); - } - } + } + + template + inline __device__ void store(const uint4 (®s)[M][N]) { + for (int mi = 0; mi < M; mi++) { + for (int ni = 0; ni < N; ni++) { + size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + offset ^= 4 * Base::BYTES_PER_STS; + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + } } + } - uint32_t read_offset_; + uint32_t read_offset_; }; } // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha/softmax.h b/apex/contrib/csrc/fmha/src/fmha/softmax.h index 153f42d57..ffe91fad9 100644 --- a/apex/contrib/csrc/fmha/src/fmha/softmax.h +++ b/apex/contrib/csrc/fmha/src/fmha/softmax.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
- * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -32,362 +32,350 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Sum_ { - enum { IS_SUM = 1 }; - static inline __device__ float apply(float x, float y) { - return x + y; - } + enum { IS_SUM = 1 }; + static inline __device__ float apply(float x, float y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Max_ { - enum { IS_SUM = 0 }; - static inline __device__ float apply(float x, float y) { - return x > y ? x : y; - } + enum { IS_SUM = 0 }; + static inline __device__ float apply(float x, float y) { return x > y ? x : y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float apply_exp_(float x, float max) { - return __expf(x - max); -} +inline __device__ float apply_exp_(float x, float max) { return __expf(x - max); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template struct ReadType {}; -template<> struct ReadType<4> { using T = float;}; -template<> struct ReadType<8> { using T = float2;}; +template +struct ReadType {}; +template <> +struct ReadType<4> { + using T = float; +}; +template <> +struct ReadType<8> { + using T = float2; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Smem_tile_reduce { - // Helper class to distribute MMA tiles reduced over rows per warp over quads. - - // The Mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The number of MMAs in M/N dimensions. - enum { MMAS_M = Mma_tile::MMAS_M }; - enum { MMAS_N = Mma_tile::MMAS_N }; - - enum { WARPS_M = Cta_tile::WARPS_M }; - enum { WARPS_N = Cta_tile::WARPS_N }; - - - static constexpr int ROWS = WARPS_M * MMAS_M * 16; - static constexpr int COLS = WARPS_N; - static_assert(COLS == 4 || COLS == 8); - static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; - static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); - static constexpr int ELTS_PER_TILE = ROWS * COLS; - - static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; - static_assert(THREADS_PER_GROUP == 16); // DEBUG - static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; - static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; - static_assert(LOOPS == 1); - - using read_t = typename ReadType::T; - - __device__ inline Smem_tile_reduce(float *smem_, const int tidx) { - - int lane = tidx % 32; - int warp = tidx / 32; - - int warp_m = warp % WARPS_M; - int warp_n = warp / WARPS_M; - - qid_ = lane % 4; - int qp = lane / 4; - - // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. - // This won't affect reading as we assume commutative reduction ops. 
- const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); - smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; - smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; - + // Helper class to distribute MMA tiles reduced over rows per warp over quads. + + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + enum { MMAS_N = Mma_tile::MMAS_N }; + + enum { WARPS_M = Cta_tile::WARPS_M }; + enum { WARPS_N = Cta_tile::WARPS_N }; + + static constexpr int ROWS = WARPS_M * MMAS_M * 16; + static constexpr int COLS = WARPS_N; + static_assert(COLS == 4 || COLS == 8); + static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; + static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); + static constexpr int ELTS_PER_TILE = ROWS * COLS; + + static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; + static_assert(THREADS_PER_GROUP == 16); // DEBUG + static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; + static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; + static_assert(LOOPS == 1); + + using read_t = typename ReadType::T; + + __device__ inline Smem_tile_reduce(float *smem_, const int tidx) { + int lane = tidx % 32; + int warp = tidx / 32; + + int warp_m = warp % WARPS_M; + int warp_n = warp / WARPS_M; + + qid_ = lane % 4; + int qp = lane / 4; + + // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. + // This won't affect reading as we assume commutative reduction ops. + const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); + smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; + smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; + } + + __device__ inline void store(float (&frag)[2 * MMAS_M]) { + if (qid_ == 0) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * WARPS_N; + smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; + smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; + } } - - __device__ inline void store(float (&frag)[2 * MMAS_M]) { - if( qid_ == 0 ) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - int offset = mi * 16 * WARPS_N; - smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; - smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; - } - } + } + + __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * 4; + frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; + frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; } + } - __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - int offset = mi * 16 * 4; - frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; - frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; - } - } - - int qid_; - float *smem_write_; - read_t *smem_read_; - + int qid_; + float *smem_write_; + read_t *smem_read_; }; - -template +template struct Softmax_base { - - // The Mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The number of MMAs in M/N dimensions. - enum { MMAS_M = Mma_tile::MMAS_M }; - enum { MMAS_N = Mma_tile::MMAS_N }; - - // The number of groups of warp such that we have at most 4 warps writing consecutive elements. - enum { GROUPS = fmha::Div_up::VALUE }; - // The number of elements that we are going to store per row. 
- enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS }; - // The number of rows. - enum { ROWS = Cta_tile::M * GROUPS }; - // The total number of elements. - enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW }; - - // Ctor. - template - inline __device__ Softmax_base(const Params ¶ms, void *smem, int bidb, int tidx) - : // packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), - smem_(reinterpret_cast(smem)), tidx_(tidx) { - - // Move to the 1st mask loaded by the thread+ tidx; - // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t); - - // Extract the position in the warp. - int warp = tidx / Cta_tile::THREADS_PER_WARP; - int lane = tidx % Cta_tile::THREADS_PER_WARP; - - // Decompose the warp index into M and N. - int warp_m = warp % Cta_tile::WARPS_M; - int warp_n = warp / Cta_tile::WARPS_M; - - // Decompose the warp-n index into group/position-inside-the-group. - int warp_g = warp_n / ELEMENTS_PER_ROW; - int warp_i = warp_n % ELEMENTS_PER_ROW; - - // The location written by the threads. - int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4; - int write_col = warp_i; - - // Assemble the write pointer. - smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; - - // Assemble the read pointer. - smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; - } - - template - inline __device__ void apply_mask(const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ii = 0; ii < 2; ++ii ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - #pragma unroll - for( int jj = 0; jj < 4; ++jj ) { - if( !mask.is_valid(mi, ni, ii, jj) ) { - elt_[2 * mi + ii][4 * ni + jj] = -INFINITY; - } - } - } + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + enum { MMAS_N = Mma_tile::MMAS_N }; + + // The number of groups of warp such that we have at most 4 warps writing consecutive elements. + enum { GROUPS = fmha::Div_up::VALUE }; + // The number of elements that we are going to store per row. + enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS }; + // The number of rows. + enum { ROWS = Cta_tile::M * GROUPS }; + // The total number of elements. + enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW }; + + // Ctor. + template + inline __device__ Softmax_base(const Params ¶ms, void *smem, int bidb, int tidx) + : // packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), + smem_(reinterpret_cast(smem)), + tidx_(tidx) { + // Move to the 1st mask loaded by the thread+ tidx; + // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t); + + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Decompose the warp index into M and N. + int warp_m = warp % Cta_tile::WARPS_M; + int warp_n = warp / Cta_tile::WARPS_M; + + // Decompose the warp-n index into group/position-inside-the-group. + int warp_g = warp_n / ELEMENTS_PER_ROW; + int warp_i = warp_n % ELEMENTS_PER_ROW; + + // The location written by the threads. + int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4; + int write_col = warp_i; + + // Assemble the write pointer. + smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; + + // Assemble the read pointer. 
+ smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; + } + + template + inline __device__ void apply_mask(const Mask &mask) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + if (!mask.is_valid(mi, ni, ii, jj)) { + elt_[2 * mi + ii][4 * ni + jj] = -INFINITY; } + } } + } } - - // Apply the exp to all the elements. - inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) { - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); - } - } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); + } } - - // Scale all the elements. - inline __device__ void scale(const float (&sum)[MMAS_M * 2]) { - // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. - float inv_sum[MMAS_M * 2]; - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; - } - - // Update the values. - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - elt_[mi][ni] *= inv_sum[mi]; - } - } + } + + // Scale all the elements. + inline __device__ void scale(const float (&sum)[MMAS_M * 2]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. + float inv_sum[MMAS_M * 2]; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; } - // The pointer to the mask. - const char *packed_mask_ptr_; - // Shared memory for the CTA-wide reduction. - float *smem_, *smem_write_, *smem_read_; - // The current thread index. - int tidx_; - // The elements. - float elt_[MMAS_M * 2][MMAS_N * 4]; +// Update the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // The pointer to the mask. + const char *packed_mask_ptr_; + // Shared memory for the CTA-wide reduction. + float *smem_, *smem_write_, *smem_read_; + // The current thread index. + int tidx_; + // The elements. + float elt_[MMAS_M * 2][MMAS_N * 4]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Softmax : public Softmax_base { - - // The base class. - using Base = Softmax_base; - // The fragment. - using Fragment_a = fmha::Fragment_a; - - static_assert(Fragment_a::NUM_REGS == 4); - - enum { WARPS_M = Cta_tile::WARPS_M }; - enum { WARPS_N = Cta_tile::WARPS_N }; - // The MMAs. - enum { MMAS_M = Base::MMAS_M }; - enum { MMAS_N = Base::MMAS_N }; - - // The accumulators. - using Accumulator = fmha::Fragment_accumulator; - using Accumulator_out = Fragment; - static_assert(Accumulator_out::NUM_REGS == 4); - - static_assert(std::is_same::value); - - using Smem_tile_red = Smem_tile_reduce; - static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); - // Ctor. 
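+  // The smem buffer is carved into two Smem_tile_red tiles: the first ELTS_PER_TILE floats
+  // hold the row sums, the next ELTS_PER_TILE floats hold the row maxima (see the ctor below).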
- template - inline __device__ Softmax(const Params ¶ms, void *smem, int bidb, int tidx) - : Base(params, smem, bidb, tidx) - , params_scale_bmm1_(params.scale_bmm1) - , smem_sum_(static_cast(smem), tidx) - , smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) { + // The base class. + using Base = Softmax_base; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + static_assert(Fragment_a::NUM_REGS == 4); + + enum { WARPS_M = Cta_tile::WARPS_M }; + enum { WARPS_N = Cta_tile::WARPS_N }; + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + using Accumulator_out = Fragment; + static_assert(Accumulator_out::NUM_REGS == 4); + + static_assert(std::is_same::value); + + using Smem_tile_red = Smem_tile_reduce; + static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); + // Ctor. + template + inline __device__ Softmax(const Params ¶ms, void *smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1), + smem_sum_(static_cast(smem), tidx), + smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {} + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); + } } - - // Pack the data to a fragment for the next GEMM. - template - inline __device__ void pack(Fragment_a (&dst)[K][M]) const { - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ki = 0; ki < K; ++ki ) { - - // 1st row - 4 elements per row. - float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; - - // Pack to 4 registers. - dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); - dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); - dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); - dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); - } - } + } + + // Scale FP32 fragments + inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) { + const float scalef = reinterpret_cast(this->params_scale_bmm1_); + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // 1st row - 4 elements per row. 
+ this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; + // 2nd row - 4 elements per row. + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; + } } - - // Scale FP32 fragments - inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) { - const float scalef = reinterpret_cast(this->params_scale_bmm1_); - - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - // 1st row - 4 elements per row. - this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; - this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; - this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; - this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; - // 2nd row - 4 elements per row. - this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; - this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; - this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; - this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; - } - } + } + // Scale FP32 fragments + inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // 1st row - 4 elements per row. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); + // 2nd row - 4 elements per row. + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); + } } - // Scale FP32 fragments - inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) { - - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - // 1st row - 4 elements per row. - this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); - this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); - this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); - this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); - // 2nd row - 4 elements per row. 
- this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); - this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); - this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); - this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); - } - } + } + + template + __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red &smem_red) { + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + frag[mi] = this->elt_[mi][0]; + for (int ni = 1; ni < 4 * MMAS_N; ni++) { + frag[mi] = op(frag[mi], this->elt_[mi][ni]); + } } + quad_reduce(frag, frag, op); + smem_red.store(frag); + __syncthreads(); + typename Smem_tile_red::read_t tmp[2 * MMAS_M]; + smem_red.load(tmp); + quad_allreduce(frag, tmp, op); + } - template - __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) { - for( int mi = 0; mi < 2 * MMAS_M; mi++ ) { - frag[mi] = this->elt_[mi][0]; - for( int ni = 1; ni < 4 * MMAS_N; ni++ ) { - frag[mi] = op(frag[mi], this->elt_[mi][ni]); - } - } - quad_reduce(frag, frag, op); - - smem_red.store(frag); - __syncthreads(); - typename Smem_tile_red::read_t tmp[2 * MMAS_M]; - smem_red.load(tmp); - - quad_allreduce(frag, tmp, op); - } - - __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){ - MaxOp max; - reduce_(frag, max, smem_max_); - } - - __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){ - SumOp sum; - reduce_(frag, sum, smem_sum_); - } + __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) { + MaxOp max; + reduce_(frag, max, smem_max_); + } + __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) { + SumOp sum; + reduce_(frag, sum, smem_sum_); + } - const uint32_t params_scale_bmm1_; - Smem_tile_red smem_max_; - Smem_tile_red smem_sum_; + const uint32_t params_scale_bmm1_; + Smem_tile_red smem_max_; + Smem_tile_red smem_sum_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/fmha/src/fmha/utils.h b/apex/contrib/csrc/fmha/src/fmha/utils.h index bedba0eff..31dbf4aaf 100644 --- a/apex/contrib/csrc/fmha/src/fmha/utils.h +++ b/apex/contrib/csrc/fmha/src/fmha/utils.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
- * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -39,521 +39,580 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Row {}; +struct Row {}; struct Col {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int M, bool = (M & (M-1)) == 0 > -struct Next_power_of_two { -}; +template +struct Next_power_of_two {}; -template< int M > -struct Next_power_of_two< M, true > { enum { VALUE = M }; }; -template<> -struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; }; -template<> -struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 13, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; }; -template<> -struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; }; -template<> -struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two<112, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two<144, false> { enum { VALUE = 256 }; }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, bool = (N & (N-1)) == 0 > -struct Prev_power_of_two { +template +struct Next_power_of_two { + enum { VALUE = M }; +}; +template <> +struct Next_power_of_two<3, false> { + enum { VALUE = 4 }; +}; +template <> +struct Next_power_of_two<5, false> { + enum { VALUE = 8 }; +}; +template <> +struct Next_power_of_two<6, false> { + enum { VALUE = 8 }; +}; +template <> +struct Next_power_of_two<7, false> { + enum { VALUE = 8 }; +}; +template <> +struct Next_power_of_two<9, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<10, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<11, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<12, false> { + enum { VALUE = 16 }; }; +template <> +struct Next_power_of_two<13, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<14, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<15, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<24, false> { + enum { VALUE = 32 }; +}; +template <> +struct Next_power_of_two<48, false> { + enum { VALUE = 64 }; +}; +template <> +struct Next_power_of_two<80, false> { + enum { VALUE = 128 }; +}; +template <> +struct Next_power_of_two<96, false> { + enum { VALUE = 128 }; +}; +template <> +struct Next_power_of_two<112, false> { + enum { VALUE = 128 }; +}; +template <> +struct 
Next_power_of_two<144, false> { + enum { VALUE = 256 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > -struct Prev_power_of_two< N, true > { enum { VALUE = N }; }; -template<> -struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; }; -template<> -struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; }; -template<> -struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; }; -template<> -struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; }; +template +struct Prev_power_of_two {}; + +template +struct Prev_power_of_two { + enum { VALUE = N }; +}; +template <> +struct Prev_power_of_two<3, false> { + enum { VALUE = 2 }; +}; +template <> +struct Prev_power_of_two<5, false> { + enum { VALUE = 4 }; +}; +template <> +struct Prev_power_of_two<6, false> { + enum { VALUE = 4 }; +}; +template <> +struct Prev_power_of_two<7, false> { + enum { VALUE = 4 }; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int M, int N > +template struct Div_up { - enum { VALUE = (M + N-1) / N }; + enum { VALUE = (M + N - 1) / N }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int A, int B > +template struct Max { - enum { VALUE = A >= B ? A : B }; + enum { VALUE = A >= B ? A : B }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int A, int B, int C > +template struct Max_3 { - enum { VALUE = Max::VALUE, C>::VALUE }; + enum { VALUE = Max::VALUE, C>::VALUE }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int A, int B > +template struct Min { - enum { VALUE = A <= B ? A : B }; + enum { VALUE = A <= B ? 
A : B }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int SIZE_IN_BYTES > -struct Uint_from_size_in_bytes { -}; +template +struct Uint_from_size_in_bytes {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template<> +template <> struct Uint_from_size_in_bytes<1> { - using Type = uint8_t; + using Type = uint8_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template<> +template <> struct Uint_from_size_in_bytes<2> { - using Type = uint16_t; + using Type = uint16_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template<> +template <> struct Uint_from_size_in_bytes<4> { - using Type = uint32_t; + using Type = uint32_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template<> +template <> struct Uint_from_size_in_bytes<8> { - using Type = uint2; + using Type = uint2; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template<> +template <> struct Uint_from_size_in_bytes<16> { - using Type = uint4; + using Type = uint4; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int WARPS_M, int WARPS_N, int WARPS_K > -struct Warp_masks { +template +struct Warp_masks {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Warp_masks<8, 1, 1> { + enum { M = 0xe0, N = 0x00, K = 0x00 }; +}; +template <> +struct Warp_masks<4, 2, 1> { + enum { M = 0x60, N = 0x80, K = 0x00 }; +}; +template <> +struct Warp_masks<4, 1, 2> { + enum { M = 0x60, N = 0x00, K = 0x80 }; +}; +template <> +struct Warp_masks<4, 1, 1> { + enum { M = 0x60, N = 0x00, K = 0x00 }; +}; +template <> +struct Warp_masks<2, 4, 1> { + enum { M = 0x20, N = 0xc0, K = 0x00 }; +}; +template <> +struct Warp_masks<2, 2, 2> { + enum { M = 0x20, N = 0x40, K = 0x80 }; +}; +template <> +struct Warp_masks<2, 2, 1> { + enum { M = 0x20, N = 0x40, K = 0x00 }; +}; +template <> +struct Warp_masks<2, 1, 2> { + enum { M = 0x20, N = 0x00, K = 0x40 }; +}; +template <> +struct Warp_masks<2, 1, 1> { + enum { M = 0x20, N = 0x00, K = 0x00 }; +}; +template <> +struct Warp_masks<1, 8, 1> { + enum { M = 0x00, N = 0xe0, K = 0x00 }; +}; +template <> +struct Warp_masks<1, 4, 2> { + enum { M = 0x00, N = 0x60, K = 0x80 }; +}; +template <> +struct Warp_masks<1, 4, 1> { + enum { M = 0x00, N = 0x60, K = 0x00 }; +}; +template <> +struct Warp_masks<1, 2, 2> { + enum { M = 0x00, N = 0x20, K = 0x40 }; +}; +template <> +struct Warp_masks<1, 2, 1> { + enum { M = 0x00, N = 0x20, K = 0x00 }; +}; +template <> +struct Warp_masks<1, 1, 4> { + enum { M = 0x00, N = 0x00, K = 0x60 }; +}; +template <> +struct Warp_masks<1, 1, 2> { + enum { M = 0x00, N = 0x00, K = 0x20 }; +}; +template <> +struct Warp_masks<1, 1, 1> { + enum { M = 0x00, N = 0x00, K = 0x00 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template<> -struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; }; -template<> -struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; }; -template<> -struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; }; -template<> -struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; }; -template<> -struct 
Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; }; -template<> -struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; }; -template<> -struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; }; -template<> -struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; }; -template<> -struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; }; -template<> -struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; }; -template<> -struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; }; -template<> -struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; }; -template<> -struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; }; -template<> -struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; }; -template<> -struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; }; -template<> -struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; }; -template<> -struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename T > +template inline __device__ __host__ T div_up(T m, T n) { - return (m + n-1) / n; + return (m + n - 1) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline int clz(int x) { - for( int i = 31; i >= 0; --i ) { - if( (1 << i) & x ) { - return 31 - i; - } + for (int i = 31; i >= 0; --i) { + if ((1 << i) & x) { + return 31 - i; } - return 32; + } + return 32; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline int find_log_2(int x, bool round_up = false) { - int a = 31 - clz(x); - if( round_up ) { - a += (x & (x-1)) ? 1 : 0; - } - return a; + int a = 31 - clz(x); + if (round_up) { + a += (x & (x - 1)) ? 
1 : 0; + } + return a; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); - return c; + uint32_t c; + asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hmul4(uint2 a, uint2 b) { - uint2 c; - c.x = hmul2(a.x, b.x); - c.y = hmul2(a.y, b.y); - return c; + uint2 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmul8(uint4 a, uint4 b) { - uint4 c; - c.x = hmul2(a.x, b.x); - c.y = hmul2(a.y, b.y); - c.z = hmul2(a.z, b.z); - c.w = hmul2(a.w, b.w); - return c; + uint4 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + c.z = hmul2(a.z, b.z); + c.w = hmul2(a.w, b.w); + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { - uint4 c; - c.x = hmul2(a, b.x); - c.y = hmul2(a, b.y); - c.z = hmul2(a, b.z); - c.w = hmul2(a, b.w); - return c; + uint4 c; + c.x = hmul2(a, b.x); + c.y = hmul2(a, b.y); + c.z = hmul2(a, b.z); + c.w = hmul2(a, b.w); + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) { - uint32_t res; + uint32_t res; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb)); + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb)); #else - const uint32_t zero = 0u; - asm volatile( \ - "{\n" \ - "\t .reg .f16x2 sela;\n" \ - "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ - "\t and.b32 %0, sela, %1;\n" - "}\n" : "=r"(res) : "r"(x), "r"(zero)); + const uint32_t zero = 0u; + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" + : "=r"(res) + : "r"(x), "r"(zero)); #endif - return res; + return res; } static inline __device__ uint32_t habs2(uint32_t x) { - uint32_t res; - asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); - return res; + uint32_t res; + asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); + return res; } //////////////////////////////////////////////////////////////////////////////////////////////////// // -template< typename T > +template static inline __device__ T clamp(T x, T lb, T ub) { - return x < lb ? lb : (x > ub ? ub : x); + return x < lb ? lb : (x > ub ? 
ub : x); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t clamp_to_zero(uint16_t x) { - uint16_t mask; - asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); - return mask & x; + uint16_t mask; + asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); + return mask & x; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t float_to_half(float f) { - uint16_t h; - asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); - return h; + uint16_t h; + asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); + return h; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float2_to_half2(float a, float b) { - uint32_t c; + uint32_t c; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); #else - uint16_t lo = float_to_half(a); - uint16_t hi = float_to_half(b); - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); + uint16_t lo = float_to_half(a); + uint16_t hi = float_to_half(b); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); #endif - return c; + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint32_t float_to_half2(float a) { - return float2_to_half2(a,a); -} +static inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a, a); } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint32_t float2_to_half2(const float2 &f) { - return float2_to_half2(f.x, f.y); -} +static inline __device__ uint32_t float2_to_half2(const float2 &f) { return float2_to_half2(f.x, f.y); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) { - uint2 d; - d.x = float2_to_half2(x, y); - d.y = float2_to_half2(z, w); - return d; + uint2 d; + d.x = float2_to_half2(x, y); + d.y = float2_to_half2(z, w); + return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; + uint32_t d; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); #else - d = hrelu2(hfma2(a, b, c)); + d = hrelu2(hfma2(a, b, c)); #endif - return d; + return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t h0_h0(uint32_t x) { - uint32_t y; - asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, 
%1; mov.b32 %0, {lo, lo};}\n" - : "=r"(y) : "r"(x)); - return y; + uint32_t y; + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" : "=r"(y) : "r"(x)); + return y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float h0_to_float(uint32_t h2) { - float f; - asm volatile("{\n" \ - ".reg .f16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %1;\n" \ - "cvt.f32.f16 %0, lo;\n" \ - "}\n" : "=f"(f) : "r"(h2)); - return f; + float f; + asm volatile( + "{\n" + ".reg .f16 lo, hi;\n" + "mov.b32 {lo, hi}, %1;\n" + "cvt.f32.f16 %0, lo;\n" + "}\n" + : "=f"(f) + : "r"(h2)); + return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t h1_h1(uint32_t x) { - uint32_t y; - asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" - : "=r"(y) : "r"(x)); - return y; + uint32_t y; + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" : "=r"(y) : "r"(x)); + return y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { - uint16_t d; - asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); - return d; + uint16_t d; + asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { - return hadd2(a, b); -} +static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { return hadd2(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hadd4(uint2 a, uint2 b) { - uint2 c; - c.x = hadd2(a.x, b.x); - c.y = hadd2(a.y, b.y); - return c; + uint2 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint2 hadd(uint2 a, uint2 b) { - return hadd4(a, b); -} +static inline __device__ uint2 hadd(uint2 a, uint2 b) { return hadd4(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hadd8(uint4 a, uint4 b) { - uint4 c; - c.x = hadd2(a.x, b.x); - c.y = hadd2(a.y, b.y); - c.z = hadd2(a.z, b.z); - c.w = hadd2(a.w, b.w); - return c; + uint4 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + c.z = hadd2(a.z, b.z); + c.w = hadd2(a.w, b.w); + return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 fadd4(uint4 a, uint4 b) { - float4 c; - c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); - c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); - c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); - c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); - return reinterpret_cast(c); + float4 c; + c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); + c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); + c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); + c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); + return reinterpret_cast(c); } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint4 hadd(uint4 a, uint4 b) { - return 
hadd8(a, b); -} +static inline __device__ uint4 hadd(uint4 a, uint4 b) { return hadd8(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float half_to_float(uint16_t h) { - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float2 half2_to_float2(uint32_t x) { - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); - return make_float2(half_to_float(lo), half_to_float(hi)); + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); + return make_float2(half_to_float(lo), half_to_float(hi)); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) { - float2 tmp = half2_to_float2(h); - x = tmp.x; - y = tmp.y; + float2 tmp = half2_to_float2(h); + x = tmp.x; + y = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { - uint16_t d; - asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); - return d; + uint16_t d; + asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); + return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { - uint16_t d; - asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); - return d; + uint16_t d; + asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ float sigmoid(float x) { - return 1.f / (1.f + expf(-x)); -} +static inline __device__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void clear(uint16_t &dst) { - dst = uint16_t(0); -} +inline __device__ void clear(uint16_t &dst) { dst = uint16_t(0); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void clear(uint32_t &dst) { - dst = 0u; -} +inline __device__ void clear(uint32_t &dst) { dst = 0u; } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void clear(uint2 &dst) { - dst = make_uint2(0u, 0u); -} +inline __device__ void clear(uint2 &dst) { dst = make_uint2(0u, 0u); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void clear(uint4 &dst) { - dst = make_uint4(0u, 0u, 0u, 0u); -} +inline __device__ void clear(uint4 &dst) { dst = make_uint4(0u, 0u, 0u, 0u); } //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -562,85 +621,81 @@ inline __device__ void clear(uint4 &dst) { //////////////////////////////////////////////////////////////////////////////////////////////////// enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; 
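+// The enum above fixes the predicate packing that load_() below relies on: 4 predicates per
+// byte and 4 bytes per 32-bit register, so flat predicate i lives in preds[i / PREDS_PER_REG]
+// at bit (i / PREDS_PER_BYTE) % BYTES_PER_REG * 8 + i % PREDS_PER_BYTE
+// (e.g. predicate 9 maps to preds[0], bit 17).
+//
+// Minimal illustrative helper showing that packing; it is a sketch only and is not referenced
+// elsewhere in this file.
+template <int M>
+static inline __device__ __host__ void set_predicate(uint32_t (&preds)[M], int i, bool p) {
+  uint32_t bit = 1u << ((i / PREDS_PER_BYTE) % BYTES_PER_REG * 8 + i % PREDS_PER_BYTE);
+  if (p) {
+    preds[i / PREDS_PER_REG] |= bit;
+  } else {
+    preds[i / PREDS_PER_REG] &= ~bit;
+  }
+}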
- //////////////////////////////////////////////////////////////////////////////////////////////////// // // G E N E R I C P R E D I C A T E D L D G S T S // //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N, int M, typename Functor > +template inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) { - - // The number of complete bytes (where we use all the predicates in a byte). - enum { COMPLETE = N / PREDS_PER_BYTE }; - // Make sure we did allocate enough predicates. - static_assert(Div_up::VALUE <= M, ""); - // The remainder. - enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; - // Make sure we got the math right and the remainder is between 0 and 3. - static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); - // The mask to extract the predicates. - enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; - - // Clear the fetch registers. - #pragma unroll - for( int ii = 0; ii < N; ++ii ) { - fct.clear(ii); + // The number of complete bytes (where we use all the predicates in a byte). + enum { COMPLETE = N / PREDS_PER_BYTE }; + // Make sure we did allocate enough predicates. + static_assert(Div_up::VALUE <= M, ""); + // The remainder. + enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; + // Make sure we got the math right and the remainder is between 0 and 3. + static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); + // The mask to extract the predicates. + enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; + +// Clear the fetch registers. +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + fct.clear(ii); + } + + // Run complete steps. + bool p[PREDS_PER_BYTE]; +#pragma unroll + for (int ii = 0; ii < COMPLETE; ++ii) { + // The predicate. + uint32_t reg = preds[ii / BYTES_PER_REG]; + +// Extract the predicates. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; } - // Run complete steps. - bool p[PREDS_PER_BYTE]; - #pragma unroll - for( int ii = 0; ii < COMPLETE; ++ii ) { - - // The predicate. - uint32_t reg = preds[ii / BYTES_PER_REG]; - - // Extract the predicates. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); - p[jj] = (reg & mask) != 0u; - } - - // Issue the loads. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - fct.load(ii * PREDS_PER_BYTE + jj, p[jj]); - } +// Issue the loads. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + fct.load(ii * PREDS_PER_BYTE + jj, p[jj]); } + } - // Skip the rest of the code if we do not have a remainder. - if( REMAINDER > 0 ) { - - // The mask to extract the predicates. - enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; + // Skip the rest of the code if we do not have a remainder. + if (REMAINDER > 0) { + // The mask to extract the predicates. + enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; - // The predicate register. - uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; + // The predicate register. + uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; - // Extract the predicates. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); - p[jj] = (reg & mask) != 0u; - } +// Extract the predicates. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } - // Issue the loads. 
- #pragma unroll - for( int ii = 0; ii < REMAINDER; ++ii ) { - fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); - } +// Issue the loads. +#pragma unroll + for (int ii = 0; ii < REMAINDER; ++ii) { + fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int M, typename Functor > +template inline __device__ void load_(Functor &fct, uint32_t preds) { - uint32_t tmp[1] = { preds }; - load_(fct, tmp); + uint32_t tmp[1] = {preds}; + load_(fct, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -649,102 +704,88 @@ inline __device__ void load_(Functor &fct, uint32_t preds) { // //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void ldg(uint8_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} +inline __device__ void ldg(uint8_t &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void ldg(uint16_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} +inline __device__ void ldg(uint16_t &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void ldg(uint32_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} +inline __device__ void ldg(uint32_t &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void ldg(uint2 &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} +inline __device__ void ldg(uint2 &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void ldg(uint4 &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} +inline __device__ void ldg(uint4 &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Data_type, int N > +template struct Ldg_functor { - // Ctor. - inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N]) - : fetch_(fetch), ptrs_(ptrs) { - } + // Ctor. + inline __device__ Ldg_functor(Data_type (&fetch)[N], const void *(&ptrs)[N]) : fetch_(fetch), ptrs_(ptrs) {} - // Clear the element. - inline __device__ void clear(int ii) { - fmha::clear(fetch_[ii]); - } + // Clear the element. + inline __device__ void clear(int ii) { fmha::clear(fetch_[ii]); } - // Trigger the loads. - inline __device__ void load(int ii, bool p) { - if( p ) { - ldg(fetch_[ii], ptrs_[ii]); - } + // Trigger the loads. + inline __device__ void load(int ii, bool p) { + if (p) { + ldg(fetch_[ii], ptrs_[ii]); } + } - // The fetch registers. - Data_type (&fetch_)[N]; - // The pointers. - const void* (&ptrs_)[N]; + // The fetch registers. + Data_type (&fetch_)[N]; + // The pointers. 
+ const void *(&ptrs_)[N]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Data_type, int N, int M > -inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - Ldg_functor fct(fetch, ptrs); - load_(fct, preds); +template +inline __device__ void ldg_(Data_type (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) { + Ldg_functor fct(fetch, ptrs); + load_(fct, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N, int M > -inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); +template +inline __device__ void ldg(uint8_t (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N, int M > -inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); +template +inline __device__ void ldg(uint16_t (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N, int M > -inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); +template +inline __device__ void ldg(uint32_t (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N, int M > -inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); +template +inline __device__ void ldg(uint2 (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N, int M > -inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); +template +inline __device__ void ldg(uint4 (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -754,30 +795,27 @@ inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t ( //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint16_t &dst, uint32_t ptr) { - asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); + asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint32_t &dst, uint32_t ptr) { - asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); + asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint2 &dst, uint32_t ptr) { - asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); + asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), 
"=r"(dst.y) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint4 &dst, uint32_t ptr) { - asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x) - , "=r"(dst.y) - , "=r"(dst.z) - , "=r"(dst.w) - : "r"(ptr)); + asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -788,8 +826,7 @@ inline __device__ void lds(uint4 &dst, uint32_t ptr) { inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" - : "=r"(dst) : "r"(ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); #endif } @@ -797,8 +834,7 @@ inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) { inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" - : "=r"(dst) : "r"(ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); #endif } @@ -806,8 +842,7 @@ inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) { inline __device__ void ldsm(uint2 &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" - : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); #endif } @@ -815,8 +850,9 @@ inline __device__ void ldsm(uint2 &dst, uint32_t ptr) { inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" - : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) + : "r"(ptr)); #endif } @@ -824,8 +860,9 @@ inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) { inline __device__ void ldsm(uint4 &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); #endif } @@ -833,8 +870,9 @@ inline __device__ void ldsm(uint4 &dst, uint32_t ptr) { inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); #endif } @@ -844,33 +882,23 @@ inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) { // //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void stg(void *ptr, uint8_t val) { - *reinterpret_cast(ptr) = val; -} +inline __device__ void stg(void *ptr, uint8_t 
val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void stg(void *ptr, uint16_t val) { - *reinterpret_cast(ptr) = val; -} +inline __device__ void stg(void *ptr, uint16_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void stg(void *ptr, uint32_t val) { - *reinterpret_cast(ptr) = val; -} +inline __device__ void stg(void *ptr, uint32_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void stg(void *ptr, uint2 val) { - *reinterpret_cast(ptr) = val; -} +inline __device__ void stg(void *ptr, uint2 val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void stg(void *ptr, uint4 val) { - *reinterpret_cast(ptr) = val; -} +inline __device__ void stg(void *ptr, uint4 val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -879,158 +907,150 @@ inline __device__ void stg(void *ptr, uint4 val) { //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint16_t val) { - asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); + asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint32_t val) { - asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); + asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint2 val) { - asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" - : - : "r"(ptr) - , "r"(val.x) - , "r"(val.y)); + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" : : "r"(ptr), "r"(val.x), "r"(val.y)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint4 val) { - asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" - : - : "r"(ptr) - , "r"(val.x) - , "r"(val.y) - , "r"(val.z) - , "r"(val.w)); + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Data_type, int N > +template inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) { - #pragma unroll - for( int ii = 0; ii < N; ++ii ) { - sts(ptrs[ii], data[ii]); - } +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + sts(ptrs[ii], data[ii]); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) { - sts_(ptrs, data); + sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) { - sts_(ptrs, data); 
+ sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) { - sts_(ptrs, data); + sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) { - sts_(ptrs, data); + sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct MaxOp { -__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } + __device__ inline T operator()(T const &x, T const &y) { return x > y ? x : y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } + __device__ inline T operator()(T const &x, T const &y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template<> +template <> struct Allreduce<2> { -template -static __device__ inline T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + template + static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; -} + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) { - #pragma unroll - for(int mi=0; mi < M; mi++){ - dst[mi] = src[mi]; - dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); - dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); - } +template +__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template __device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) { - float tmp[M]; - #pragma unroll - for(int mi=0; mi < M; mi++){ - tmp[mi] = op(src[mi].x, src[mi].y); - } - quad_reduce(dst, tmp, op); + float tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_reduce(dst, tmp, op); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template __device__ inline void 
quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) { - #pragma unroll - for(int mi=0; mi < M; mi++){ - dst[mi] = src[mi]; - dst[mi] = Allreduce<4>::run(dst[mi], op); - } +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = Allreduce<4>::run(dst[mi], op); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template __device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) { - float tmp[M]; - #pragma unroll - for(int mi=0; mi < M; mi++){ - tmp[mi] = op(src[mi].x, src[mi].y); - } - quad_allreduce(dst, tmp, op); + float tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_allreduce(dst, tmp, op); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu index 517a5b758..8feac023e 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
- * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -31,30 +31,29 @@ using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { - fmha::compute_dv_1xN(params); - fmha::compute_dq_dk_1xN(params); + fmha::compute_dv_1xN(params); + fmha::compute_dq_dk_1xN(params); } void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 128 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(params.h, params.b); - fmha_dgrad_fp16_128_64_sm80_kernel<<>>(params); + constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); + constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; + constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; + constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; + + using Smem_tile_s = fmha::Smem_tile_mma_transposed; + constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; + static_assert(smem_size_s == 16 * 128 * 2); + static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); + + constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; + constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; + constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); + + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_128_64_sm80_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(params.h, params.b); + fmha_dgrad_fp16_128_64_sm80_kernel<<>>(params); } diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu index ac22a1629..8459f84ef 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
- * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -31,30 +31,29 @@ using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { - fmha::compute_dv_1xN(params); - fmha::compute_dq_dk_1xN(params); + fmha::compute_dv_1xN(params); + fmha::compute_dq_dk_1xN(params); } void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 256 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(params.h, params.b); - fmha_dgrad_fp16_256_64_sm80_kernel<<>>(params); + constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); + constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; + constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; + constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; + + using Smem_tile_s = fmha::Smem_tile_mma_transposed; + constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; + static_assert(smem_size_s == 16 * 256 * 2); + static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); + + constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; + constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; + constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); + + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_256_64_sm80_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(params.h, params.b); + fmha_dgrad_fp16_256_64_sm80_kernel<<>>(params); } diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu index 7081438e9..3784b2c6d 100644 --- 
a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -31,30 +31,29 @@ using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>; extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { - fmha::compute_dv_1xN(params); - fmha::compute_dq_dk_1xN(params); + fmha::compute_dv_1xN(params); + fmha::compute_dq_dk_1xN(params); } void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 384 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(params.h, params.b); - fmha_dgrad_fp16_384_64_sm80_kernel<<>>(params); + constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); + constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; + constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; + constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; + + using Smem_tile_s = fmha::Smem_tile_mma_transposed; + constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; + static_assert(smem_size_s == 16 * 384 * 2); + static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); + + constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; + constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; + constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); + + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_384_64_sm80_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 
smem_size)); + } + dim3 grid(params.h, params.b); + fmha_dgrad_fp16_384_64_sm80_kernel<<>>(params); } diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu index 735006cc2..071d09628 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -32,74 +32,72 @@ using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>; extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { - fmha::compute_dv_1xN(params); - fmha::compute_dq_dk_1xN(params); + fmha::compute_dv_1xN(params); + fmha::compute_dq_dk_1xN(params); } -template -__global__ -void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){ - fmha::compute_dv_1xN_nl(params); - fmha::compute_dq_dk_1xN_nl(params); +template +__global__ void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params) { + fmha::compute_dv_1xN_nl(params); + fmha::compute_dq_dk_1xN_nl(params); } void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 512 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(params.h, params.b); - fmha_dgrad_fp16_512_64_sm80_kernel<<>>(params); + constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); + constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; + constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; + constexpr int smem_size_o = 
Kernel_traits::Smem_tile_o::BYTES_PER_TILE; + + using Smem_tile_s = fmha::Smem_tile_mma_transposed; + constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; + static_assert(smem_size_s == 16 * 512 * 2); + static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); + + constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; + constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; + constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); + + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_512_64_sm80_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(params.h, params.b); + fmha_dgrad_fp16_512_64_sm80_kernel<<>>(params); } -void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const int num_chunks, cudaStream_t stream) { +void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const int num_chunks, + cudaStream_t stream) { + constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); + constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; + constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; + constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; + using Smem_tile_s = fmha::Smem_tile_mma_transposed; + constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; + static_assert(smem_size_s == 16 * 512 * 2); + static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - using Smem_tile_s = fmha::Smem_tile_mma_transposed; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 512 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); + constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; + constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; + constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); + auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; - auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; - - if( num_chunks == 2 ) { - kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; - }else if( num_chunks == 3 ) { - kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>; - } else { - assert(false && "Unsupperted number of chunks"); - } + if (num_chunks == 2) { + kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; + } else if (num_chunks == 3) { + kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>; + } else { + assert(false && "Unsupperted number of chunks"); + } - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } - dim3 grid(params.h, params.b, num_chunks); + dim3 grid(params.h, params.b, num_chunks); - kernel<<>>(params); + kernel<<>>(params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); } diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h index 3c4b81742..b441cc3e0 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,530 +27,524 @@ #pragma once -#include "fmha_kernel.h" -#include #include +#include + +#include "fmha_kernel.h" namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dv_1xN(const Params ¶ms) { - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dv = - fmha::Cta_tile_extd; - - static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128); - static_assert(Cta_tile_dv::N == 64); - static_assert(Cta_tile_dv::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dv = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to swizzle Q. - // using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - using Smem_tile_q = fmha::Smem_tile_a; - // The shared memory tile to reload Q as fragment b. - using Smem_tile_qt = fmha::Smem_tile_b; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - // The shared memory tile to swizzle K. - using Smem_tile_k = typename Kernel_traits::Smem_tile_k; - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - // The global memory tile to store dV. - using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle dV. 
- using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS); - static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW); - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - using Smem_tile_st = typename Kernel_traits::Smem_tile_st; - using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; - - // Shared memory. - extern __shared__ char smem_[]; - - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.x; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) - return; - Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for Q. - Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q - // Allocate the shared memory tile loader for Q. - Smem_tile_q smem_q(&smem_[0], tidx); - Smem_tile_qt smem_qt(&smem_[0], tidx); - Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K - // Allocate the shared memory tile loader for K. - Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); - - // Trigger the loads for Q. - gmem_q.load(smem_q); - // Trigger the loads for K. - gmem_k.load(smem_k); - - // Commit the data for Q and K to shared memory. - gmem_q.commit(smem_q); - gmem_k.commit(smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load the fragments for Q. - typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M]; - smem_q.load(frag_q[0], 0); - - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N]; - static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); - static_assert(Mma_tile_dv::MMAS_K == 1); - smem_qt.load(frag_qt[0], 0); - - // Load the fragments for K. We keep the data in registers during the entire kernel. - typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N]; - smem_k.load(frag_k[0], 0); - - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - Gmem_tile_s gmem_s(params, binfo, tidx); - - // Create the object to do the softmax. - using Softmax = fmha::Softmax; - Softmax softmax( - params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx); - - enum { THREADS_PER_ROW = 32 }; - enum { M = Mma_tile_p::MMAS_M }; - enum { N = Mma_tile_p::MMAS_N }; - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dv); - - enum { STEPS = Cta_tile_p::N / Cta_tile_p::M }; - // Load over the entire sequence length. - for( int l = 0; l < STEPS; l++ ) { - const int loop = l * Cta_tile_p::M; - if( loop >= binfo.actual_seqlen ) - break; - - // Load S - uint4 s_regs[M][N]; - gmem_s.load(s_regs, mask); - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_q.load(frag_q[ki & 1], ki); - smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. 
- fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - - // Store s * dmask to smem for transpose - smem_s.store(s_regs); - - // Declare the accumulators for the 1st gemm. - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe - if( l < STEPS - 1) { - smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(smem_q); + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_dv = + fmha::Cta_tile_extd; + + static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128); + static_assert(Cta_tile_dv::N == 64); + static_assert(Cta_tile_dv::K == 16); + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_dv = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + // The shared memory tile to swizzle Q. + // using Smem_tile_q = typename Kernel_traits::Smem_tile_q; + using Smem_tile_q = fmha::Smem_tile_a; + // The shared memory tile to reload Q as fragment b. + using Smem_tile_qt = fmha::Smem_tile_b; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; + // The shared memory tile to swizzle K. + using Smem_tile_k = typename Kernel_traits::Smem_tile_k; + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + // The shared memory tile to swizzle O. + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + + // The global memory tile to store dV. + using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle dV. + using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; + static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS); + static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW); + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + using Smem_tile_st = typename Kernel_traits::Smem_tile_st; + using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; + + // Shared memory. + extern __shared__ char smem_[]; + + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.x; + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + if (binfo.stop_early()) return; + Mask mask(params, binfo, tidx); + + // Allocate the global memory tile loader for Q. + Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q + // Allocate the shared memory tile loader for Q. + Smem_tile_q smem_q(&smem_[0], tidx); + Smem_tile_qt smem_qt(&smem_[0], tidx); + Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K + // Allocate the shared memory tile loader for K. + Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); + + // Trigger the loads for Q. 
+ gmem_q.load(smem_q); + // Trigger the loads for K. + gmem_k.load(smem_k); + + // Commit the data for Q and K to shared memory. + gmem_q.commit(smem_q); + gmem_k.commit(smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + + // Load the fragments for Q. + typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M]; + smem_q.load(frag_q[0], 0); + + typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N]; + static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); + static_assert(Mma_tile_dv::MMAS_K == 1); + smem_qt.load(frag_qt[0], 0); + + // Load the fragments for K. We keep the data in registers during the entire kernel. + typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N]; + smem_k.load(frag_k[0], 0); + + enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; + + Gmem_tile_s gmem_s(params, binfo, tidx); + + // Create the object to do the softmax. + using Softmax = fmha::Softmax; + Softmax softmax(params, + &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], + bidb, tidx); + + enum { THREADS_PER_ROW = 32 }; + enum { M = Mma_tile_p::MMAS_M }; + enum { N = Mma_tile_p::MMAS_N }; + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N]; + fmha::Clear_accumulator::apply(acc_dv); + + enum { STEPS = Cta_tile_p::N / Cta_tile_p::M }; + // Load over the entire sequence length. + for (int l = 0; l < STEPS; l++) { + const int loop = l * Cta_tile_p::M; + if (loop >= binfo.actual_seqlen) break; + + // Load S + uint4 s_regs[M][N]; + gmem_s.load(s_regs, mask); + fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::Clear_accumulator::apply(acc_p); +// Do this part of P^T = (Q * K^T)^T. +#pragma unroll + for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + smem_q.load(frag_q[ki & 1], ki); + smem_k.load(frag_k[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + } + + // Store s * dmask to smem for transpose + smem_s.store(s_regs); + + // Declare the accumulators for the 1st gemm. + // Do the final stage of math. + { + int ki = Mma_tile_p::MMAS_K; + fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + } + // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe + if (l < STEPS - 1) { + smem_q.move_to_next_write_buffer(); + gmem_q.move(); + gmem_q.load(smem_q); + } + + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack(acc_p); + + float s_mat[2 * M][4 * N]; + +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + uint4 &dst = s_regs[mi][ni]; + fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x); + fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y); + fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z); + fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w); + } + } + +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ii = 0; ii < 2; ii++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int jj = 0; jj < 4; jj++) { + float &s_dmask = s_mat[2 * mi + ii][4 * ni + jj]; + const bool drop = reinterpret_cast(s_dmask) & 0x80000000; + const float d_s = drop ? 
0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout; + s_dmask = fabsf(s_dmask); + softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask); + } } - - - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack(acc_p); - - float s_mat[2 * M][4 * N]; - - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 &dst = s_regs[mi][ni]; - fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x); - fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y); - fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z); - fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w); - } + } + } + + float p_sum[2 * M]; + softmax.reduce_sum(p_sum); + + const float scalef = reinterpret_cast(params.scale_softmax); +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ii = 0; ii < 2; ii++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int jj = 0; jj < 4; jj++) { + softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]); + softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef; + } } - - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - #pragma unroll - for( int jj = 0; jj < 4; jj++ ) { - float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj]; - const bool drop = reinterpret_cast(s_dmask) & 0x80000000; - const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout; - s_dmask = fabsf(s_dmask); - softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask); - } - } - } + } + } + typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M]; + smem_s.load(frag_s); + for (int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++) { + for (int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++) { + for (int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++) { + frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout); + frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii)); } + } + } - float p_sum[2 * M]; - softmax.reduce_sum(p_sum); - - const float scalef = reinterpret_cast(params.scale_softmax); - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - #pragma unroll - for( int jj = 0; jj < 4; jj++ ) { - softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ; - softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef; - } - } - } - } - typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M]; - smem_s.load(frag_s); - for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) { - for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) { - for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) { - frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout); - frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii)); - } - } - } + gmem_s.store(softmax.elt_, mask); + gmem_s.move(); - gmem_s.store(softmax.elt_, mask); - gmem_s.move(); +#pragma unroll + for (int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + smem_qt.load(frag_qt[ki & 1], ki); + // Do the math for the values already in registers. 
+ fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + + // Do the final stage of math. + { + int ki = Mma_tile_dv::MMAS_K; + fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + // Commit the values for Q into shared memory. + if (l < STEPS - 1) { + gmem_q.commit(smem_q); + } + + // Make sure we are reading from the correct buffer. + smem_q.move_to_next_read_buffer(); + smem_qt.move_to_next_read_buffer(); - #pragma unroll - for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } + // Make sure the data is in shared memory. + __syncthreads(); - // Do the final stage of math. - { - int ki = Mma_tile_dv::MMAS_K; - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - // Commit the values for Q into shared memory. - if(l < STEPS - 1) { - gmem_q.commit(smem_q); - } + // Trigger the loads for the values of Q for the next iteration. + smem_q.load(frag_q[0], 0); + smem_k.load(frag_k[0], 0); + smem_qt.load(frag_qt[0], 0); - // Make sure we are reading from the correct buffer. - smem_q.move_to_next_read_buffer(); - smem_qt.move_to_next_read_buffer(); + } // Outer loop over the sequence length. + + // Epilogue swizzle for dV + Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx); + smem_dv.store(acc_dv); + + __syncthreads(); + uint4 dv_out[Smem_tile_dv::NUM_LDS]; + smem_dv.load(dv_out); + Qkv_params dv_params; + dv_params.qkv_ptr = params.dqkv_ptr; + dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; + dv_params.h = params.h; + Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx); + gmem_dv.store(dv_out); +} - // Make sure the data is in shared memory. +template +inline __device__ void compute_dq_dk_1xN(const Params ¶ms) { + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_dk = + fmha::Cta_tile_extd; + static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128); + static_assert(Cta_tile_dk::N == 64); + static_assert(Cta_tile_dk::K == 16); + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + using Mma_tile_o = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_dk = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + // The shared memory tile to swizzle Q. + using Smem_tile_q = typename Kernel_traits::Smem_tile_q; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle K. + using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store O. + // using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + using Gmem_tile_o = fmha::Gmem_tile_dq; + // The shared memory tile to swizzle O. + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + + // The global memory tile to store dK. 
+ using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle dK. + using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; + static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); + static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); + + // The shared memory tile to reload Q transposed. + using Smem_tile_qt = fmha::Smem_tile_b; + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + + using Smem_tile_st = typename Kernel_traits::Smem_tile_st; + + enum { M = Mma_tile_p::MMAS_M }; + enum { N = Mma_tile_p::MMAS_N }; + static_assert(M == Mma_tile_o::MMAS_M); + static_assert(N == Mma_tile_o::MMAS_K); + // Shared memory. + extern __shared__ char smem_[]; + + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.x; + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + if (binfo.stop_early()) return; + + Mask mask(params, binfo, tidx); + + // Allocate the global memory tile loader for Q. + Gmem_tile_q gmem_q(params, 0, binfo, tidx); + // Allocate the shared memory tile loader for Q. + Smem_tile_q smem_q(&smem_[0], tidx); + Smem_tile_qt smem_qt(&smem_[0], tidx); + Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], + tidx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k(params, 1, binfo, tidx); + // Allocate the shared memory tile loader for K. + Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); + + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o(params, binfo, tidx); + // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! + Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); + + // Trigger the loads for Q. + gmem_q.load(smem_q); + // Trigger the loads for K. + gmem_k.load(smem_k); + + Gmem_tile_s gmem_s(params, binfo, tidx); + // Load dP + uint4 s_regs[M][N]; + gmem_s.load(s_regs, mask); + gmem_s.move(); + + // Commit the data for Q and K to shared memory. + gmem_q.commit(smem_q); + gmem_k.commit(smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + + typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N]; + smem_qt.load(frag_qt[0], 0); + typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N]; + smem_k.load(frag_k[0], 0); + + enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; + + enum { THREADS_PER_ROW = 32 }; + enum { STEPS = Cta_tile_p::N / Cta_tile_p::M }; + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N]; + fmha::Clear_accumulator::apply(acc_dk); + + // Load over the entire sequence length. + for (int l = 0; l < STEPS; l++) { + const int loop = l * Cta_tile_p::M; + if (loop >= binfo.actual_seqlen) break; + + // Pack dP as Fragment_a + fmha::Fragment_a frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + uint4 &dst = s_regs[mi][ni]; + frag_p[ni][mi].reg(0) = dst.x; // row 0, cols 0,1 + frag_p[ni][mi].reg(1) = dst.z; // row 8, cols 0,1 + frag_p[ni][mi].reg(2) = dst.y; // row 0, cols 8,9 + frag_p[ni][mi].reg(3) = dst.w; // row 8, cols 8,9 + } + } + + // Declare the accumulators for the 1st gemm. 
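    // The packing just above rearranges each loaded uint4 of dP into the register order an
    // MMA A-operand fragment expects (per the comments: x/z cover rows 0 and 8 of columns
    // 0-1, y/w the same rows of columns 8-9). The accumulators declared next hold this
    // CTA's dQ tile: as the gemm comment below indicates, the fprop O = P * V pipeline is
    // reused with dP in place of P and K streamed through the V tile types, so acc_o
    // effectively accumulates dP multiplied by K (a reading of the surrounding comments,
    // not wording from the upstream source).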
+ fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; + fmha::Clear_accumulator::apply(acc_o); + +// Do this part of O = P^T * V^T. dQ = dP x dK +#pragma unroll + for (int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + smem_k.load(frag_k[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); + } + + // Do the final stage of math. + { + int ki = Mma_tile_o::MMAS_K; + fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); + } + + // Store dP to smem for transpose + smem_s.store(s_regs); + if (l < STEPS - 1) { + // Load next part of S + gmem_s.load(s_regs, mask); + gmem_s.move(); + smem_q.move_to_next_write_buffer(); + gmem_q.move(); + gmem_q.load(smem_q); + } +// Loop over MMAS_M. +#pragma unroll + for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) { + // Swizzle the elements and do the final reduction. + smem_o.store(acc_o, ii); + + // Make sure the data is in shared memory. + __syncthreads(); + + // Load from shared memory. + uint4 out[Gmem_tile_o::STGS_PER_LOOP]; + smem_o.load(out); + + // Make sure the data was read from shared memory. + if (ii < Gmem_tile_o::LOOPS - 1) { __syncthreads(); + } - // Trigger the loads for the values of Q for the next iteration. - smem_q.load(frag_q[0], 0); - smem_k.load(frag_k[0], 0); - smem_qt.load(frag_qt[0], 0); + // Output the values. + gmem_o.store(out, ii); + } - } // Outer loop over the sequence length. + // Move to the next part of the output. + gmem_o.move(); - // Epilogue swizzle for dV - Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx); - smem_dv.store(acc_dv); + typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M]; + smem_s.load(frag_s); - __syncthreads(); - uint4 dv_out[Smem_tile_dv::NUM_LDS]; - smem_dv.load(dv_out); - Qkv_params dv_params; - dv_params.qkv_ptr = params.dqkv_ptr; - dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; - dv_params.h = params.h; - Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx); - gmem_dv.store(dv_out); -} +#pragma unroll + for (int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + smem_qt.load(frag_qt[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } -template -inline __device__ void compute_dq_dk_1xN(const Params ¶ms) { + // Do the final stage of math. + { + int ki = Mma_tile_dk::MMAS_K; + fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - using Cta_tile_o = typename Kernel_traits::Cta_tile_o; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dk = - fmha::Cta_tile_extd; - static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128); - static_assert(Cta_tile_dk::N == 64); - static_assert(Cta_tile_dk::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - using Mma_tile_o = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dk = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to swizzle Q. 
- using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle K. - using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - // using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; - using Gmem_tile_o = fmha::Gmem_tile_dq; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - // The global memory tile to store dK. - using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle dK. - using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); - static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); - - // The shared memory tile to reload Q transposed. - using Smem_tile_qt = fmha::Smem_tile_b; - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - - using Smem_tile_st = typename Kernel_traits::Smem_tile_st; - - - enum { M = Mma_tile_p::MMAS_M }; - enum { N = Mma_tile_p::MMAS_N }; - static_assert(M == Mma_tile_o::MMAS_M); - static_assert(N == Mma_tile_o::MMAS_K); - // Shared memory. - extern __shared__ char smem_[]; - - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.x; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) - return; - - Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); - // Allocate the shared memory tile loader for Q. - Smem_tile_q smem_q(&smem_[0], tidx); - Smem_tile_qt smem_qt(&smem_[0], tidx); - Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); - // Allocate the shared memory tile loader for K. - Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params, binfo, tidx); - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); - - // Trigger the loads for Q. - gmem_q.load(smem_q); - // Trigger the loads for K. - gmem_k.load(smem_k); - - Gmem_tile_s gmem_s(params, binfo, tidx); - // Load dP - uint4 s_regs[M][N]; - gmem_s.load(s_regs, mask); - gmem_s.move(); - - // Commit the data for Q and K to shared memory. - gmem_q.commit(smem_q); - gmem_k.commit(smem_k); + // Commit the values for Q into shared memory. + if (l < STEPS - 1) { + gmem_q.commit(smem_q); + } // Make sure the data is in shared memory. __syncthreads(); - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N]; + // Trigger the loads for the values of Q for the next iteration. 
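    // Once the __syncthreads() above has made the freshly committed Q tile visible, the
    // ki = 0 fragments are re-issued here so that the next outer-loop iteration's unrolled
    // GEMM loops (which start at ki = 1) find their first operands already in registers.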
smem_qt.load(frag_qt[0], 0); - typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N]; smem_k.load(frag_k[0], 0); - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - enum { THREADS_PER_ROW = 32 }; - enum { STEPS = Cta_tile_p::N / Cta_tile_p::M }; - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dk); - - // Load over the entire sequence length. - for( int l=0;l= binfo.actual_seqlen ) - break; - - // Pack dP as Fragment_a - fmha::Fragment_a frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 &dst = s_regs[mi][ni]; - frag_p[ni][mi].reg(0) = dst.x; // row 0, cols 0,1 - frag_p[ni][mi].reg(1) = dst.z; // row 8, cols 0,1 - frag_p[ni][mi].reg(2) = dst.y; // row 0, cols 8,9 - frag_p[ni][mi].reg(3) = dst.w; // row 8, cols 8,9 - } - } - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; - fmha::Clear_accumulator::apply(acc_o); - - // Do this part of O = P^T * V^T. dQ = dP x dK - #pragma unroll - for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_o::MMAS_K; - fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); - } - - // Store dP to smem for transpose - smem_s.store(s_regs); - if(l < STEPS - 1) { - // Load next part of S - gmem_s.load(s_regs, mask); - gmem_s.move(); - smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(smem_q); - } - // Loop over MMAS_M. - #pragma unroll - for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) { - - // Swizzle the elements and do the final reduction. - smem_o.store(acc_o, ii); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load from shared memory. - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; - smem_o.load(out); - - // Make sure the data was read from shared memory. - if( ii < Gmem_tile_o::LOOPS - 1 ) { - __syncthreads(); - } - - // Output the values. - gmem_o.store(out, ii); - } - - // Move to the next part of the output. - gmem_o.move(); - - typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M]; - smem_s.load(frag_s); - - #pragma unroll - for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dk::MMAS_K; - fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Commit the values for Q into shared memory. - if( l < STEPS - 1) { - gmem_q.commit(smem_q); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // Trigger the loads for the values of Q for the next iteration. - smem_qt.load(frag_qt[0], 0); - smem_k.load(frag_k[0], 0); - - } // Outer loop over the sequence length. 
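  // The dK epilogue that follows does, in outline:
  //
  //   smem_dk.store(acc_dk);   // re-tile the per-thread accumulators through shared memory
  //   __syncthreads();
  //   smem_dk.load(dk_out);    // each thread reads back NUM_LDS packed uint4 vectors
  //   gmem_dk.store(dk_out);   // coalesced store into dqkv_ptr, tensor index 1 (the K slot,
  //                            // by analogy with the dV store at index 2 earlier)
  //
  // i.e. the swizzle through Smem_tile_dk converts the MMA accumulator layout into rows
  // that can be written out with wide, contiguous stores.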
- - // Epilogue swizzle for dK - Smem_tile_dk smem_dk(&smem_[0], tidx); - smem_dk.store(acc_dk); - __syncthreads(); - uint4 dk_out[Smem_tile_dk::NUM_LDS]; - smem_dk.load(dk_out); - Qkv_params dk_params; - dk_params.qkv_ptr = params.dqkv_ptr; - dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; - dk_params.h = params.h; - Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx); - gmem_dk.store(dk_out); + } // Outer loop over the sequence length. + + // Epilogue swizzle for dK + Smem_tile_dk smem_dk(&smem_[0], tidx); + smem_dk.store(acc_dk); + __syncthreads(); + uint4 dk_out[Smem_tile_dk::NUM_LDS]; + smem_dk.load(dk_out); + Qkv_params dk_params; + dk_params.qkv_ptr = params.dqkv_ptr; + dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; + dk_params.h = params.h; + Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx); + gmem_dk.store(dk_out); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h index 26776d484..1c04474d1 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,543 +27,539 @@ #pragma once -#include "fmha_kernel.h" -#include #include +#include + +#include "fmha_kernel.h" namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dv_1xN_nl(const Params ¶ms) { - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dv = fmha::Cta_tile_extd; - - static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128); - static_assert(Cta_tile_dv::N == 64); - static_assert(Cta_tile_dv::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dv = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to swizzle Q. - using Smem_tile_q = fmha::Smem_tile_a; - // The shared memory tile to reload Q as fragment b. - using Smem_tile_qt = fmha::Smem_tile_b; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - // The shared memory tile to swizzle K. - using Smem_tile_k = typename Kernel_traits::Smem_tile_k; - - // The global memory tile to load V. 
- using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store dV. - using Gmem_tile_dv = fmha::Gmem_tile_qkv; - - // The shared memory tile to swizzle dV. - using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS); - static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW); - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - using Smem_tile_st = typename Kernel_traits::Smem_tile_st; - using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; - - // Shared memory. - extern __shared__ char smem_[]; - - // The block index for the chunk. - const int bidc = blockIdx.z; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.x; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) - return; - fmha::Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for Q. - Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q - // Allocate the shared memory tile loader for Q. - Smem_tile_q smem_q(&smem_[0], tidx); - Smem_tile_qt smem_qt(&smem_[0], tidx); - Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K - // Allocate the shared memory tile loader for K. - Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); - - Gmem_tile_s gmem_s(params, binfo, tidx); - - using Noloop = Noloop_traits; - - Noloop nl_traits(bidc, binfo); - nl_traits.move_all(gmem_q, gmem_s); - - // Trigger the loads for Q. - gmem_q.load(smem_q); - // Trigger the loads for K. - gmem_k.load(smem_k); - - // Commit the data for Q and K to shared memory. - gmem_q.commit(smem_q); - gmem_k.commit(smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load the fragments for Q. - typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M]; - smem_q.load(frag_q[0], 0); - - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N]; - static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); - static_assert(Mma_tile_dv::MMAS_K == 1); - smem_qt.load(frag_qt[0], 0); - - // Load the fragments for K. We keep the data in registers during the entire kernel. - typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N]; - smem_k.load(frag_k[0], 0); - - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - // Create the object to do the softmax. - using Softmax = fmha::Softmax; - Softmax softmax( - params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx); - - enum { THREADS_PER_ROW = 32 }; - enum { M = Mma_tile_p::MMAS_M }; - enum { N = Mma_tile_p::MMAS_N }; - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dv); - - // Load over the entire sequence length. - for(int l = 0; l < nl_traits.num_steps_;l++) { - - uint4 s_regs[M][N]; - gmem_s.load(s_regs, mask); - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - // Do this part of P^T = (Q * K^T)^T. 
- #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_q.load(frag_q[ki & 1], ki); - smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - - smem_s.store(s_regs); - - // Declare the accumulators for the 1st gemm. - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe - if(l < nl_traits.num_steps_ - 1) { - smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(smem_q); - } - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack(acc_p); - - float s_mat[2 * M][4 * N]; - - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 &dst = s_regs[mi][ni]; - fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x); - fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y); - fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z); - fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w); - } - } - - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - #pragma unroll - for( int jj = 0; jj < 4; jj++ ) { - float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj]; - const bool drop = reinterpret_cast(s_dmask) & 0x80000000; - const float d_s= drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout; - s_dmask = fabsf(s_dmask); - softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask); - } - } - } - } - - float p_sum[2 * M]; - softmax.reduce_sum(p_sum); - - const float scalef = reinterpret_cast(params.scale_softmax); - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - #pragma unroll - for( int jj = 0; jj < 4; jj++ ) { - softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ; - softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef; - } - } - } - } - - typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M]; - smem_s.load(frag_s); - for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) { - for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) { - for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) { - frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout); - frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii)); - } - } - } - - gmem_s.store(softmax.elt_, mask); - gmem_s.move(); - - static_assert(Mma_tile_dv::MMAS_K == 1); // DEBUG - #pragma unroll - for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. 
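  // Cta_tile_dv below describes the second GEMM of this kernel, the dV accumulation: its
  // M dimension is the sequence-length tile handled by the CTA (hence the static_asserts
  // just below pinning it to 128/256/384/512), N = 64 presumably matches the head size of
  // these *_64 kernels, and K = 16 is the per-MMA K slice.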
+ using Cta_tile_dv = + fmha::Cta_tile_extd; + + static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128); + static_assert(Cta_tile_dv::N == 64); + static_assert(Cta_tile_dv::K == 16); + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_dv = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + // The shared memory tile to swizzle Q. + using Smem_tile_q = fmha::Smem_tile_a; + // The shared memory tile to reload Q as fragment b. + using Smem_tile_qt = fmha::Smem_tile_b; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; + // The shared memory tile to swizzle K. + using Smem_tile_k = typename Kernel_traits::Smem_tile_k; + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store dV. + using Gmem_tile_dv = fmha::Gmem_tile_qkv; + + // The shared memory tile to swizzle dV. + using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; + static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS); + static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW); + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + using Smem_tile_st = typename Kernel_traits::Smem_tile_st; + using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; + + // Shared memory. + extern __shared__ char smem_[]; + + // The block index for the chunk. + const int bidc = blockIdx.z; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.x; + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + if (binfo.stop_early()) return; + fmha::Mask mask(params, binfo, tidx); + + // Allocate the global memory tile loader for Q. + Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q + // Allocate the shared memory tile loader for Q. + Smem_tile_q smem_q(&smem_[0], tidx); + Smem_tile_qt smem_qt(&smem_[0], tidx); + Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K + // Allocate the shared memory tile loader for K. + Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); + + Gmem_tile_s gmem_s(params, binfo, tidx); + + using Noloop = Noloop_traits; + + Noloop nl_traits(bidc, binfo); + nl_traits.move_all(gmem_q, gmem_s); + + // Trigger the loads for Q. + gmem_q.load(smem_q); + // Trigger the loads for K. + gmem_k.load(smem_k); + + // Commit the data for Q and K to shared memory. + gmem_q.commit(smem_q); + gmem_k.commit(smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + + // Load the fragments for Q. + typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M]; + smem_q.load(frag_q[0], 0); + + typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N]; + static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); + static_assert(Mma_tile_dv::MMAS_K == 1); + smem_qt.load(frag_qt[0], 0); + + // Load the fragments for K. We keep the data in registers during the entire kernel. 
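    // The K operand here is actually the V tensor of fprop (gmem_k above is constructed
    // with index 2, "treating V as K"). It is needed by every iteration of the sequence
    // loop below, so it is read from shared memory once into frag_k and then kept in
    // registers for the whole kernel, saving repeated shared-memory reads at the cost of
    // register pressure.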
+ typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N]; + smem_k.load(frag_k[0], 0); + + enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; + + // Create the object to do the softmax. + using Softmax = fmha::Softmax; + Softmax softmax(params, + &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], + bidb, tidx); + + enum { THREADS_PER_ROW = 32 }; + enum { M = Mma_tile_p::MMAS_M }; + enum { N = Mma_tile_p::MMAS_N }; + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N]; + fmha::Clear_accumulator::apply(acc_dv); + + // Load over the entire sequence length. + for (int l = 0; l < nl_traits.num_steps_; l++) { + uint4 s_regs[M][N]; + gmem_s.load(s_regs, mask); + fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::Clear_accumulator::apply(acc_p); +// Do this part of P^T = (Q * K^T)^T. +#pragma unroll + for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + smem_q.load(frag_q[ki & 1], ki); + smem_k.load(frag_k[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + } + + smem_s.store(s_regs); + + // Declare the accumulators for the 1st gemm. + // Do the final stage of math. + { + int ki = Mma_tile_p::MMAS_K; + fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + } + // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe + if (l < nl_traits.num_steps_ - 1) { + smem_q.move_to_next_write_buffer(); + gmem_q.move(); + gmem_q.load(smem_q); + } + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack(acc_p); + + float s_mat[2 * M][4 * N]; + +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + uint4 &dst = s_regs[mi][ni]; + fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x); + fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y); + fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z); + fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w); + } + } + +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ii = 0; ii < 2; ii++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int jj = 0; jj < 4; jj++) { + float &s_dmask = s_mat[2 * mi + ii][4 * ni + jj]; + const bool drop = reinterpret_cast(s_dmask) & 0x80000000; + const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout; + s_dmask = fabsf(s_dmask); + softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask); + } } - - // Do the final stage of math. - { - int ki = Mma_tile_dv::MMAS_K; - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + } + + float p_sum[2 * M]; + softmax.reduce_sum(p_sum); + + const float scalef = reinterpret_cast(params.scale_softmax); +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ii = 0; ii < 2; ii++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int jj = 0; jj < 4; jj++) { + softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]); + softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef; + } } - // Commit the values for Q into shared memory. 
- if(l < nl_traits.num_steps_ - 1) { - gmem_q.commit(smem_q); + } + } + + typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M]; + smem_s.load(frag_s); + for (int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++) { + for (int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++) { + for (int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++) { + frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout); + frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii)); } - - // Make sure we are reading from the correct buffer. - smem_q.move_to_next_read_buffer(); - smem_qt.move_to_next_read_buffer(); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Trigger the loads for the values of Q for the next iteration. - smem_q.load(frag_q[0], 0); - smem_k.load(frag_k[0], 0); - smem_qt.load(frag_qt[0], 0); - - } // Outer loop over the sequence length. - - // Epilogue for dV = (S * D)' * dout'. We're fully exposed to this! - - // Epilogue swizzle for dV - Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx); - smem_dv.store(acc_dv); - - __syncthreads(); - - uint4 dv_out[Smem_tile_dv::NUM_LDS]; - smem_dv.load(dv_out); - Qkv_params dv_params; - dv_params.qkv_ptr = params.dkv_ptr; - dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half); - dv_params.h = params.h; - Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx); - gmem_dv.store(dv_out); -} - -template -inline __device__ void compute_dq_dk_1xN_nl(const Params ¶ms) { - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - using Cta_tile_o = typename Kernel_traits::Cta_tile_o; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dk = fmha::Cta_tile_extd; - - static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128); - static_assert(Cta_tile_dk::N == 64); - static_assert(Cta_tile_dk::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - using Mma_tile_o = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dk = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to swizzle Q. - using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle K. - using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = Gmem_tile_dq; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - // The global memory tile to store dK. - using Gmem_tile_dk = fmha::Gmem_tile_qkv; - - // The shared memory tile to swizzle dK. - using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); - static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); - - // The shared memory tile to reload Q transposed. 
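    // "Reload Q transposed": the dK accumulation later multiplies the transposed dP
    // fragments by Q, so the same shared-memory bytes written for Q are simply re-read
    // through a B-operand view (Smem_tile_qt below) instead of storing a second,
    // explicitly transposed copy of Q. (This is an interpretation of the tile types,
    // not wording from the upstream source.)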
- using Smem_tile_qt = fmha::Smem_tile_b; - - // The global memory tile to load dP, stored in S - using Gmem_tile_s = Gmem_tile_mma_s; - // The shared memory tile to transpose dP. - using Smem_tile_st = Smem_tile_mma_transposed; - - using Noloop = Noloop_traits; - - enum { M = Mma_tile_p::MMAS_M }; - enum { N = Mma_tile_p::MMAS_N }; - static_assert(M == Mma_tile_o::MMAS_M); - static_assert(N == Mma_tile_o::MMAS_K); - // Shared memory. - extern __shared__ char smem_[]; - - const int bidc = blockIdx.z; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.x; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) - return; - - fmha::Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); - // Allocate the shared memory tile loader for Q (as B). - Smem_tile_qt smem_qt(&smem_[0], tidx); - // Allocate the global memory tile loader for dP. - Gmem_tile_s gmem_s(params, binfo, tidx); - // Allocate the shared memory tile loader for dP. - Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); - // Allocate the shared memory tile loader for K. - Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params, binfo, tidx); - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); - - Noloop nl_traits(bidc, binfo); - - nl_traits.move_all(gmem_q, gmem_o, gmem_s); - - // Trigger the loads for Q. - gmem_q.load(smem_qt); - // Trigger the loads for K. - gmem_k.load(smem_k); - - uint4 s_regs[M][N]; - gmem_s.load(s_regs, mask); - - // Commit the data for Q and K to shared memory. - gmem_q.commit(smem_qt); - gmem_k.commit(smem_k); + } + } + + gmem_s.store(softmax.elt_, mask); + gmem_s.move(); + + static_assert(Mma_tile_dv::MMAS_K == 1); // DEBUG +#pragma unroll + for (int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + smem_qt.load(frag_qt[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + + // Do the final stage of math. + { + int ki = Mma_tile_dv::MMAS_K; + fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + // Commit the values for Q into shared memory. + if (l < nl_traits.num_steps_ - 1) { + gmem_q.commit(smem_q); + } + + // Make sure we are reading from the correct buffer. + smem_q.move_to_next_read_buffer(); + smem_qt.move_to_next_read_buffer(); // Make sure the data is in shared memory. __syncthreads(); - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N]; - smem_qt.load(frag_qt[0], 0); - typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N]; + // Trigger the loads for the values of Q for the next iteration. + smem_q.load(frag_q[0], 0); smem_k.load(frag_k[0], 0); + smem_qt.load(frag_qt[0], 0); - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - enum { THREADS_PER_ROW = 32 }; - - // Declare the accumulators for the 2nd gemm. 
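    // acc_dk is the one accumulator with a whole-kernel lifetime in this function: it is
    // cleared once here, accumulated on every step of the sequence loop, and only written
    // out in the dK epilogue after the loop, whereas the dQ result (acc_o) is stored to
    // global memory once per step through gmem_o.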
- fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dk); - - // Load over the entire sequence length. - for(int l=0;l < nl_traits.num_steps_; l++) { - - // Pack dP as Fragment_a - fmha::Fragment_a frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 &dst = s_regs[mi][ni]; - frag_p[ni][mi].reg(0) = dst.x; - frag_p[ni][mi].reg(1) = dst.z; - frag_p[ni][mi].reg(2) = dst.y; - frag_p[ni][mi].reg(3) = dst.w; - } - } - smem_s.store(s_regs); - if(l < nl_traits.num_steps_- 1) { - // Load next part of S - gmem_s.move(); - gmem_s.load(s_regs, mask); - // Trigger the load for the next Q values. - smem_qt.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(smem_qt); - } - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; - fmha::Clear_accumulator::apply(acc_o); - - // Do this part of O = P^T * V^T. dQ = dP x dK - #pragma unroll - for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_o::MMAS_K; - fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); - } - - static_assert(Gmem_tile_o::LOOPS == 1); //DEBUG - // Loop over MMAS_M. - #pragma unroll - for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) { - - // Swizzle the elements and do the final reduction. - smem_o.store(acc_o, ii); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load from shared memory. - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; - smem_o.load(out); - - // Make sure the data was read from shared memory. - if( ii < Gmem_tile_o::LOOPS - 1 ) { - __syncthreads(); - } - - // Output the values. - gmem_o.store(out, ii); - } - - // Move to the next part of the output. - gmem_o.move(); - - typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M]; - smem_s.load(frag_s); - - static_assert(Mma_tile_dk::MMAS_K == 1); // DEBUG + } // Outer loop over the sequence length. - #pragma unroll - for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } + // Epilogue for dV = (S * D)' * dout'. We're fully exposed to this! - // Do the final stage of math. - { - int ki = Mma_tile_dk::MMAS_K; - fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Commit the values for Q into shared memory. - if(l < nl_traits.num_steps_- 1) { - gmem_q.commit(smem_qt); - __syncthreads(); - // Trigger the loads for the values of Q for the next iteration. - smem_qt.load(frag_qt[0], 0); - smem_k.load(frag_k[0], 0); - } + // Epilogue swizzle for dV + Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx); + smem_dv.store(acc_dv); - } // Outer loop over the sequence length. + __syncthreads(); - // Epilogue for dK = dP' * dq. We're fully exposed to this! 
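    // "Fully exposed" presumably means that, unlike the per-step loads and stores inside
    // the loop, nothing remains to overlap these post-loop epilogues with: the shared-memory
    // round-trip and the final global store (dK here, dV in the companion kernel above) sit
    // on the critical path after the last iteration.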
+ uint4 dv_out[Smem_tile_dv::NUM_LDS]; + smem_dv.load(dv_out); + Qkv_params dv_params; + dv_params.qkv_ptr = params.dkv_ptr; + dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half); + dv_params.h = params.h; + Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx); + gmem_dv.store(dv_out); +} - // Epilogue swizzle for dK - Smem_tile_dk smem_dk(&smem_[0], tidx); - smem_dk.store(acc_dk); - - __syncthreads(); - - uint4 dk_out[Smem_tile_dk::NUM_LDS]; - smem_dk.load(dk_out); - Qkv_params dk_params; - dk_params.qkv_ptr = params.dkv_ptr; - dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half); - dk_params.h = params.h; - Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx); - gmem_dk.store(dk_out); +template +inline __device__ void compute_dq_dk_1xN_nl(const Params ¶ms) { + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_dk = + fmha::Cta_tile_extd; + + static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128); + static_assert(Cta_tile_dk::N == 64); + static_assert(Cta_tile_dk::K == 16); + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + using Mma_tile_o = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_dk = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + // The shared memory tile to swizzle Q. + using Smem_tile_q = typename Kernel_traits::Smem_tile_q; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle K. + using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = Gmem_tile_dq; + // The shared memory tile to swizzle O. + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + + // The global memory tile to store dK. + using Gmem_tile_dk = fmha::Gmem_tile_qkv; + + // The shared memory tile to swizzle dK. + using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; + static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); + static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); + + // The shared memory tile to reload Q transposed. + using Smem_tile_qt = fmha::Smem_tile_b; + + // The global memory tile to load dP, stored in S + using Gmem_tile_s = Gmem_tile_mma_s; + // The shared memory tile to transpose dP. + using Smem_tile_st = Smem_tile_mma_transposed; + + using Noloop = Noloop_traits; + + enum { M = Mma_tile_p::MMAS_M }; + enum { N = Mma_tile_p::MMAS_N }; + static_assert(M == Mma_tile_o::MMAS_M); + static_assert(N == Mma_tile_o::MMAS_K); + // Shared memory. + extern __shared__ char smem_[]; + + const int bidc = blockIdx.z; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.x; + // The thread index. 
+ const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + if (binfo.stop_early()) return; + + fmha::Mask mask(params, binfo, tidx); + + // Allocate the global memory tile loader for Q. + Gmem_tile_q gmem_q(params, 0, binfo, tidx); + // Allocate the shared memory tile loader for Q (as B). + Smem_tile_qt smem_qt(&smem_[0], tidx); + // Allocate the global memory tile loader for dP. + Gmem_tile_s gmem_s(params, binfo, tidx); + // Allocate the shared memory tile loader for dP. + Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], + tidx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k(params, 1, binfo, tidx); + // Allocate the shared memory tile loader for K. + Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); + + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o(params, binfo, tidx); + // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! + Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); + + Noloop nl_traits(bidc, binfo); + + nl_traits.move_all(gmem_q, gmem_o, gmem_s); + + // Trigger the loads for Q. + gmem_q.load(smem_qt); + // Trigger the loads for K. + gmem_k.load(smem_k); + + uint4 s_regs[M][N]; + gmem_s.load(s_regs, mask); + + // Commit the data for Q and K to shared memory. + gmem_q.commit(smem_qt); + gmem_k.commit(smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + + typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N]; + smem_qt.load(frag_qt[0], 0); + typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N]; + smem_k.load(frag_k[0], 0); + + enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; + + enum { THREADS_PER_ROW = 32 }; + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N]; + fmha::Clear_accumulator::apply(acc_dk); + + // Load over the entire sequence length. + for (int l = 0; l < nl_traits.num_steps_; l++) { + // Pack dP as Fragment_a + fmha::Fragment_a frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + uint4 &dst = s_regs[mi][ni]; + frag_p[ni][mi].reg(0) = dst.x; + frag_p[ni][mi].reg(1) = dst.z; + frag_p[ni][mi].reg(2) = dst.y; + frag_p[ni][mi].reg(3) = dst.w; + } + } + smem_s.store(s_regs); + if (l < nl_traits.num_steps_ - 1) { + // Load next part of S + gmem_s.move(); + gmem_s.load(s_regs, mask); + // Trigger the load for the next Q values. + smem_qt.move_to_next_write_buffer(); + gmem_q.move(); + gmem_q.load(smem_qt); + } + // Declare the accumulators for the 1st gemm. + fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; + fmha::Clear_accumulator::apply(acc_o); + +// Do this part of O = P^T * V^T. dQ = dP x dK +#pragma unroll + for (int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + smem_k.load(frag_k[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); + } + + // Do the final stage of math. + { + int ki = Mma_tile_o::MMAS_K; + fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); + } + + static_assert(Gmem_tile_o::LOOPS == 1); // DEBUG +// Loop over MMAS_M. 
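    // Each pass of the loop below handles one output slice: acc_o is staged into the O
    // shared-memory tile (the comment's "final reduction", apparently combining the warps'
    // partial results), a barrier makes it visible, the slice is reloaded as packed uint4
    // vectors and written out through the dQ gmem tile. The static_assert above pins
    // Gmem_tile_o::LOOPS to 1 in this _nl variant, so the loop runs exactly once.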
+#pragma unroll + for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) { + // Swizzle the elements and do the final reduction. + smem_o.store(acc_o, ii); + + // Make sure the data is in shared memory. + __syncthreads(); + + // Load from shared memory. + uint4 out[Gmem_tile_o::STGS_PER_LOOP]; + smem_o.load(out); + + // Make sure the data was read from shared memory. + if (ii < Gmem_tile_o::LOOPS - 1) { + __syncthreads(); + } + + // Output the values. + gmem_o.store(out, ii); + } + + // Move to the next part of the output. + gmem_o.move(); + + typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M]; + smem_s.load(frag_s); + + static_assert(Mma_tile_dk::MMAS_K == 1); // DEBUG + +#pragma unroll + for (int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + smem_qt.load(frag_qt[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + + // Do the final stage of math. + { + int ki = Mma_tile_dk::MMAS_K; + fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + + // Commit the values for Q into shared memory. + if (l < nl_traits.num_steps_ - 1) { + gmem_q.commit(smem_qt); + __syncthreads(); + // Trigger the loads for the values of Q for the next iteration. + smem_qt.load(frag_qt[0], 0); + smem_k.load(frag_k[0], 0); + } + + } // Outer loop over the sequence length. + + // Epilogue for dK = dP' * dq. We're fully exposed to this! + + // Epilogue swizzle for dK + Smem_tile_dk smem_dk(&smem_[0], tidx); + smem_dk.store(acc_dk); + + __syncthreads(); + + uint4 dk_out[Smem_tile_dk::NUM_LDS]; + smem_dk.load(dk_out); + Qkv_params dk_params; + dk_params.qkv_ptr = params.dkv_ptr; + dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half); + dk_params.h = params.h; + Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx); + gmem_dk.store(dk_out); } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace fmha +} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha_fill.cu b/apex/contrib/csrc/fmha/src/fmha_fill.cu index 0ee2adbea..f2e0d925d 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fill.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fill.cu @@ -25,45 +25,42 @@ * ******************************************************************************/ -#include -#include #include +#include +#include constexpr int block_size = 512; constexpr int ctas_per_sm = 4; template -__global__ void -__launch_bounds__(block_size) -mha_fill_kernel(scalar_t* out_tensor, - const int32_t* const start_row, - const size_t num_rows) { - size_t row_stride = gridDim.y * blockDim.x; - size_t row_index = blockIdx.x + (size_t)start_row[0]; - size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; - while (row_index < num_rows) { - out_tensor[row_index*row_stride + col_index] = 0; - row_index += gridDim.x; - } +__global__ void __launch_bounds__(block_size) + mha_fill_kernel(scalar_t* out_tensor, const int32_t* const start_row, const size_t num_rows) { + size_t row_stride = gridDim.y * blockDim.x; + size_t row_index = blockIdx.x + (size_t)start_row[0]; + size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; + while (row_index < num_rows) { + out_tensor[row_index * row_stride + col_index] = 0; + row_index += gridDim.x; + } } -at::Tensor & mha_fill(at::Tensor &self, const at::Tensor &start_index) { - auto max_tokens = self.size(0); - auto self_2d = 
self.view({max_tokens, -1}); - auto fcd_size = self_2d.size(1); - TORCH_CHECK (self.is_contiguous(), "input not contiguous"); - TORCH_CHECK (fcd_size % block_size == 0, "input size not aligned to block size"); - const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); - uint64_t num_blk_x = (uint64_t)std::ceil(num_mp * ctas_per_sm / num_blk_y); - dim3 dim_grid(num_blk_x, num_blk_y); - dim3 dim_block(block_size); +at::Tensor& mha_fill(at::Tensor& self, const at::Tensor& start_index) { + auto max_tokens = self.size(0); + auto self_2d = self.view({max_tokens, -1}); + auto fcd_size = self_2d.size(1); + TORCH_CHECK(self.is_contiguous(), "input not contiguous"); + TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); + const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); + uint64_t num_blk_x = (uint64_t)std::ceil(num_mp * ctas_per_sm / num_blk_y); + dim3 dim_grid(num_blk_x, num_blk_y); + dim3 dim_block(block_size); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), "mha_padding_fill_", [&]() { - mha_fill_kernel<<>>( - self_2d.data_ptr(), start_index.data_ptr(), max_tokens); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - return self; + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), "mha_padding_fill_", [&]() { + mha_fill_kernel<<>>( + self_2d.data_ptr(), start_index.data_ptr(), max_tokens); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + return self; } diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu index 4664e46ad..29ae36103 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -30,50 +30,44 @@ using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; -template -__global__ -void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params, - const int total_heads) { - - fmha::device_1xN(params, total_heads); +template +__global__ void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params, + const int total_heads) { + fmha::device_1xN(params, total_heads); } -void run_fmha_fp16_128_64_sm80(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? 
&fmha_fprop_fp16_128_64_sm80_kernel : &fmha_fprop_fp16_128_64_sm80_kernel; +void run_fmha_fp16_128_64_sm80(Launch_params &launch_params, + const bool configure) { + auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel + : &fmha_fprop_fp16_128_64_sm80_kernel; - constexpr int smem_size = fmha::get_dynamic_smem_size(); + constexpr int smem_size = fmha::get_dynamic_smem_size(); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; + const int sm_count = launch_params.props->multiProcessorCount; + int ctas_per_sm; + FMHA_CHECK_CUDA( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); + int total_ctas = sm_count * ctas_per_sm; - const int heads_total = launch_params.params.b * launch_params.params.h; - if(configure) { + const int heads_total = launch_params.params.b * launch_params.params.h; + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; - using Mma_tile_p = fmha::Hmma_tile; - constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; - constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas); + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; + launch_params.elts_per_thread = heads_per_cta * elts_per_head; + return; + } - size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas); - size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; - launch_params.elts_per_thread = heads_per_cta * elts_per_head; - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - heads_total); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + dim3 grid(total_ctas); + kernel<<>>(launch_params.params, heads_total); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); } - - diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu index 34df228e9..b4e028bef 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
- * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -30,50 +30,44 @@ using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; -template -__global__ -void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params, - const int total_heads) { - - fmha::device_1xN(params, total_heads); +template +__global__ void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params, + const int total_heads) { + fmha::device_1xN(params, total_heads); } -void run_fmha_fp16_256_64_sm80(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel : &fmha_fprop_fp16_256_64_sm80_kernel; +void run_fmha_fp16_256_64_sm80(Launch_params &launch_params, + const bool configure) { + auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel + : &fmha_fprop_fp16_256_64_sm80_kernel; - constexpr int smem_size = fmha::get_dynamic_smem_size(); + constexpr int smem_size = fmha::get_dynamic_smem_size(); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; + const int sm_count = launch_params.props->multiProcessorCount; + int ctas_per_sm; + FMHA_CHECK_CUDA( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); + int total_ctas = sm_count * ctas_per_sm; - const int heads_total = launch_params.params.b * launch_params.params.h; - if(configure) { + const int heads_total = launch_params.params.b * launch_params.params.h; + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; - using Mma_tile_p = fmha::Hmma_tile; - constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; - constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas); + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; + launch_params.elts_per_thread = heads_per_cta * elts_per_head; + return; + } - size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas); - size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; - launch_params.elts_per_thread = heads_per_cta * elts_per_head; - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - heads_total); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + dim3 grid(total_ctas); + kernel<<>>(launch_params.params, heads_total); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); } - - diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu index 5e30452af..402728635 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu +++ 
b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -30,50 +30,44 @@ using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>; -template -__global__ -void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params, - const int total_heads) { - - fmha::device_1xN(params, total_heads); +template +__global__ void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params, + const int total_heads) { + fmha::device_1xN(params, total_heads); } -void run_fmha_fp16_384_64_sm80(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel : &fmha_fprop_fp16_384_64_sm80_kernel; +void run_fmha_fp16_384_64_sm80(Launch_params &launch_params, + const bool configure) { + auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel + : &fmha_fprop_fp16_384_64_sm80_kernel; - constexpr int smem_size = fmha::get_dynamic_smem_size(); + constexpr int smem_size = fmha::get_dynamic_smem_size(); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; + const int sm_count = launch_params.props->multiProcessorCount; + int ctas_per_sm; + FMHA_CHECK_CUDA( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); + int total_ctas = sm_count * ctas_per_sm; - const int heads_total = launch_params.params.b * launch_params.params.h; - if(configure) { + const int heads_total = launch_params.params.b * launch_params.params.h; + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; - using Mma_tile_p = fmha::Hmma_tile; - constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; - constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas); + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; + launch_params.elts_per_thread = heads_per_cta * elts_per_head; + return; + } - size_t heads_per_cta = ((heads_total 
+ total_ctas - 1) / total_ctas); - size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; - launch_params.elts_per_thread = heads_per_cta * elts_per_head; - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - heads_total); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + dim3 grid(total_ctas); + kernel<<>>(launch_params.params, heads_total); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); } - - diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu index e37689e8c..e7067c6eb 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -30,108 +30,95 @@ using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>; -template -__global__ -void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params, - const int total_heads) { - - fmha::device_1xN(params, total_heads); +template +__global__ void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params, + const int total_heads) { + fmha::device_1xN(params, total_heads); } -template -__global__ -void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params, - const int num_full_heads, - const int num_main_groups, - const int main_group_size, - const int main_steps, - const int rest_steps) { - - fmha::device_1xN( - params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps); +template +__global__ void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params, + const int num_full_heads, const int num_main_groups, + const int main_group_size, const int main_steps, + const int rest_steps) { + fmha::device_1xN(params, num_full_heads, num_main_groups, main_group_size, main_steps, + rest_steps); } -void run_fmha_fp16_512_64_sm80_(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel : &fmha_fprop_fp16_512_64_sm80_kernel; - - constexpr int smem_size = fmha::get_dynamic_smem_size(); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } +void run_fmha_fp16_512_64_sm80_(Launch_params &launch_params, + const bool configure) { + auto kernel = launch_params.is_training ? 
&fmha_fprop_fp16_512_64_sm80_kernel + : &fmha_fprop_fp16_512_64_sm80_kernel; - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; + constexpr int smem_size = fmha::get_dynamic_smem_size(); - const int heads_total = launch_params.params.b * launch_params.params.h; - if(configure) { + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } - using Mma_tile_p = fmha::Hmma_tile; - constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; - constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + const int sm_count = launch_params.props->multiProcessorCount; + int ctas_per_sm; + FMHA_CHECK_CUDA( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); + int total_ctas = sm_count * ctas_per_sm; - size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas); - size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; - launch_params.elts_per_thread = heads_per_cta * elts_per_head; - return; - } + const int heads_total = launch_params.params.b * launch_params.params.h; + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - heads_total); + size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas); + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; + launch_params.elts_per_thread = heads_per_cta * elts_per_head; + return; + } - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + dim3 grid(total_ctas); + kernel<<>>(launch_params.params, heads_total); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); } -void run_fmha_fp16_512_64_sm80_nl_(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl : &fmha_fprop_fp16_512_64_sm80_kernel_nl; - - constexpr int smem_size = fmha::get_dynamic_smem_size(); +void run_fmha_fp16_512_64_sm80_nl_(Launch_params &launch_params, + const bool configure) { + auto kernel = launch_params.is_training ? 
&fmha_fprop_fp16_512_64_sm80_kernel_nl + : &fmha_fprop_fp16_512_64_sm80_kernel_nl; - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } + constexpr int smem_size = fmha::get_dynamic_smem_size(); - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } - if(configure) { - const int heads_total = launch_params.params.b * launch_params.params.h; - std::tie(launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps, - launch_params.elts_per_thread) = fmha::work_dist(total_ctas, heads_total); - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + const int sm_count = launch_params.props->multiProcessorCount; + int ctas_per_sm; + FMHA_CHECK_CUDA( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); + int total_ctas = sm_count * ctas_per_sm; + if (configure) { + const int heads_total = launch_params.params.b * launch_params.params.h; + std::tie(launch_params.num_full_heads, launch_params.num_main_groups, launch_params.heads_last_wave, + launch_params.main_steps, launch_params.rest_steps, launch_params.elts_per_thread) = + fmha::work_dist(total_ctas, heads_total); + return; + } + + dim3 grid(total_ctas); + kernel<<>>( + launch_params.params, launch_params.num_full_heads, launch_params.num_main_groups, launch_params.heads_last_wave, + launch_params.main_steps, launch_params.rest_steps); + + FMHA_CHECK_CUDA(cudaPeekAtLastError()); } -void run_fmha_fp16_512_64_sm80(Launch_params &launch_params, const bool configure) { - if( launch_params.is_nl ) { - run_fmha_fp16_512_64_sm80_nl_(launch_params, configure); - } else { - run_fmha_fp16_512_64_sm80_(launch_params, configure); - } +void run_fmha_fp16_512_64_sm80(Launch_params &launch_params, + const bool configure) { + if (launch_params.is_nl) { + run_fmha_fp16_512_64_sm80_nl_(launch_params, configure); + } else { + run_fmha_fp16_512_64_sm80_(launch_params, configure); + } } diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h b/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h index 3c3e8ba14..0e270fc20 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h @@ -1,6 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
- * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,505 +27,476 @@ #pragma once -#include "fmha_kernel.h" -#include #include +#include + +#include "fmha_kernel.h" namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Gemm_Q_K_base { - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - using Smem_tile_k = typename Kernel_traits::Smem_tile_k; - using Fragment_q = typename Smem_tile_q::Fragment; - using Fragment_k = typename Smem_tile_k::Fragment; - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + using Smem_tile_q = typename Kernel_traits::Smem_tile_q; + using Smem_tile_k = typename Kernel_traits::Smem_tile_k; + using Fragment_q = typename Smem_tile_q::Fragment; + using Fragment_k = typename Smem_tile_k::Fragment; - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; - __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx) - : smem_q(smem_ptr_q, tidx) - , smem_k(smem_ptr_k, tidx) { + static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; - } + __device__ inline Gemm_Q_K_base(char *smem_ptr_q, char *smem_ptr_k, const int tidx) + : smem_q(smem_ptr_q, tidx), smem_k(smem_ptr_k, tidx) {} - __device__ inline void load_q() { - smem_q.load(frag_q[0], 0); - } + __device__ inline void load_q() { smem_q.load(frag_q[0], 0); } - __device__ inline void reload_q() { - smem_q.load(frag_q[0], 0); - } + __device__ inline void reload_q() { smem_q.load(frag_q[0], 0); } - Fragment_q frag_q[2][Mma_tile_p::MMAS_M]; - Smem_tile_q smem_q; - Smem_tile_k smem_k; + Fragment_q frag_q[2][Mma_tile_p::MMAS_M]; + Smem_tile_q smem_q; + Smem_tile_k smem_k; }; -template +template struct Gemm_Q_K : public Gemm_Q_K_base { - - using Base = Gemm_Q_K_base; - using Smem_tile_o = typename Base::Smem_tile_o; - using Smem_tile_q = typename Base::Smem_tile_q; - using Smem_tile_k = typename Base::Smem_tile_k; - using Fragment_k = typename Base::Fragment_k; - using Mma_tile_p = typename Base::Mma_tile_p; - - enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V }; - - enum { SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE }; - enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) }; - - // Q | K / V - // | O | SOFTMAX - static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE - + std::max((SHARE_SMEM_FOR_K_AND_V ? 
1 : 2) * Smem_tile_k::BYTES_PER_TILE, - Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX); - - __device__ inline Gemm_Q_K(char * smem_, const int tidx) - : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { + using Base = Gemm_Q_K_base; + using Smem_tile_o = typename Base::Smem_tile_o; + using Smem_tile_q = typename Base::Smem_tile_q; + using Smem_tile_k = typename Base::Smem_tile_k; + using Fragment_k = typename Base::Fragment_k; + using Mma_tile_p = typename Base::Mma_tile_p; + + enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V }; + + enum { SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE }; + enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) }; + + // Q | K / V + // | O | SOFTMAX + static constexpr int SMEM_BYTES = + Smem_tile_q::BYTES_PER_TILE + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE, + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX); + + __device__ inline Gemm_Q_K(char *smem_, const int tidx) : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {} + + __device__ inline void load_k() { +#pragma unroll + for (int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) { + Base::smem_k.load(frag_k[ki], ki); } - - __device__ inline void load_k(){ - #pragma unroll - for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) { - Base::smem_k.load(frag_k[ki], ki); - } + } + + template + __device__ inline void operator()(Acc (&acc_p)[M][N]) { +// Do this part of P^T = (Q * K^T)^T. +#pragma unroll + for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + Base::smem_q.load(Base::frag_q[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); } - - template - __device__ inline void operator()(Acc (&acc_p)[M][N]){ - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - Base::smem_q.load(Base::frag_q[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); - } - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); - } + // Do the final stage of math. + { + int ki = Mma_tile_p::MMAS_K; + fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); } + } - __device__ inline void reload_k(){ - // Noop. - } + __device__ inline void reload_k() { + // Noop. + } - Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N]; + Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N]; }; - -template +template struct Gemm_Q_K : public Gemm_Q_K_base { - using Base = Gemm_Q_K_base; - using Smem_tile_o = typename Base::Smem_tile_o; - using Smem_tile_q = typename Base::Smem_tile_q; - using Smem_tile_k = typename Base::Smem_tile_k; - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - using Fragment_k = typename Base::Fragment_k; - using Mma_tile_p = typename Base::Mma_tile_p; - Fragment_k frag_k[2][Mma_tile_p::MMAS_N]; - - enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V }; - - enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 
0 : Smem_tile_k::BYTES_PER_TILE) }; - static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE); - enum { SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE }; - - // Q | K/V + O + SOFTMAX - static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE - + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE - + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX; - - __device__ inline Gemm_Q_K(char * smem_, const int tidx) - : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { + using Base = Gemm_Q_K_base; + using Smem_tile_o = typename Base::Smem_tile_o; + using Smem_tile_q = typename Base::Smem_tile_q; + using Smem_tile_k = typename Base::Smem_tile_k; + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + using Fragment_k = typename Base::Fragment_k; + using Mma_tile_p = typename Base::Mma_tile_p; + Fragment_k frag_k[2][Mma_tile_p::MMAS_N]; + + enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V }; + + enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) }; + static_assert(Smem_tile_v::BYTES_PER_TILE == (int)Smem_tile_k::BYTES_PER_TILE); + enum { SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE }; + + // Q | K/V + O + SOFTMAX + static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE + + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE + + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX; + + __device__ inline Gemm_Q_K(char *smem_, const int tidx) : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {} + + __device__ inline void load_k() { Base::smem_k.load(frag_k[0], 0); } + + template + __device__ inline void operator()(Acc (&acc_p)[M][N]) { +// Do this part of P^T = (Q * K^T)^T. +#pragma unroll + for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + Base::smem_q.load(Base::frag_q[ki & 1], ki); + Base::smem_k.load(frag_k[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); } - - __device__ inline void load_k(){ - Base::smem_k.load(frag_k[0], 0); - } - - template - __device__ inline void operator()(Acc (&acc_p)[M][N]){ - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - Base::smem_q.load(Base::frag_q[ki & 1], ki); - Base::smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } + // Do the final stage of math. + { + int ki = Mma_tile_p::MMAS_K; + fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); } + } - __device__ inline void reload_k(){ - Base::smem_k.load(frag_k[0], 0); - } + __device__ inline void reload_k() { Base::smem_k.load(frag_k[0], 0); } }; -template -constexpr size_t get_dynamic_smem_size(){ - return Gemm_Q_K::SMEM_BYTES; +template +constexpr size_t get_dynamic_smem_size() { + return Gemm_Q_K::SMEM_BYTES; } -template -inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, const int begin, const int steps, Prng & ph) { - - - // The description of the CTA tile for the 1st batched GEMM. 
- using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_o = typename Kernel_traits::Cta_tile_o; - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_o = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - - using Gemm1 = Gemm_Q_K; +template +inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, const int begin, + const int steps, Prng &ph) { + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + // The shared memory tile to swizzle O. + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + + using Gemm1 = Gemm_Q_K; + + using Softmax = fmha::Softmax; + + // The number of threads per row. + enum { THREADS_PER_ROW = 32 }; + + enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + if (binfo.stop_early()) return; + + Gemm1 gemm_q_k(smem_, tidx); + // Allocate the global memory tile loader for Q. + Gmem_tile_q gmem_q(params, 0, binfo, tidx); + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o(params, binfo, tidx); + // Allocate the global memory tile loader for S. + Gmem_tile_s gmem_s(params, binfo, tidx); + // Wind gmem tiles to the correct position. + for (int it = 0; it < begin; it++) { + gmem_q.move(); + gmem_s.move(); + gmem_o.move(); + } + + fmha::Mask mask(params, binfo, tidx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k(params, 1, binfo, tidx); + // Allocate the global memory tile loader for V. + Gmem_tile_v gmem_v(params, 2, binfo, tidx); + // The base pointer of smem_v; + char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; + + // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! 
+ Smem_tile_v smem_v(smem_v_, tidx); + + // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! + Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); + + // Trigger the loads for K. + gmem_k.load(gemm_q_k.smem_k); + // Trigger the loads for Q. + gmem_q.load(gemm_q_k.smem_q); + // Trigger the loads for V. + gmem_v.load(smem_v); + + const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); +#pragma unroll + for (int it = 0; it < Gmem_tile_k::LDGS; it++) { + gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); + } + + // Commit the data for Q and V to shared memory. + gmem_q.commit(gemm_q_k.smem_q); + gmem_v.commit(smem_v); + + // Commit the data for K to shared memory. + if (!Kernel_traits::SHARE_SMEM_FOR_K_AND_V) { + gmem_k.commit(gemm_q_k.smem_k); + } + + __syncthreads(); + + // Load the fragments for Q. + gemm_q_k.load_q(); + + // Load the fragments for V. We keep the data in registers during the entire kernel. + typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) { + smem_v.load(frag_v[ki], ki); + } + + // Commit the data for V to shared memory if it has not been done already. + if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) { + // Make sure we are done loading the fragments for K. + __syncthreads(); - using Softmax = fmha::Softmax; + // Commit the data to shared memory for V. + gmem_k.commit(gemm_q_k.smem_k); + // Make sure the data is in shared memory. + __syncthreads(); + } - // The number of threads per row. - enum { THREADS_PER_ROW = 32 }; + // Load the fragments for K. + gemm_q_k.load_k(); + uint32_t p_scaled = (uint32_t)256.0 * params.p_dropout; - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; + // Create the object to do the softmax. + Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx); - // Shared memory. - extern __shared__ char smem_[]; + // Load over the entire sequence length. + for (int l = 0; l < steps; l++) { + if (begin + l * Cta_tile_p::M >= binfo.actual_seqlen) break; - // The thread index. - const int tidx = threadIdx.x; + // Declare the accumulators for the 1st gemm. + fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::Clear_accumulator::apply(acc_p); - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) return; + // Do this part of P^T = (Q * K^T)^T. + gemm_q_k(acc_p); - Gemm1 gemm_q_k(smem_, tidx); - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params, binfo, tidx); - // Allocate the global memory tile loader for S. - Gmem_tile_s gmem_s(params, binfo, tidx); - // Wind gmem tiles to the correct position. - for( int it = 0; it < begin; it++ ) { - gmem_q.move(); - gmem_s.move(); - gmem_o.move(); + // Trigger the load for the next Q values. + if (l < steps - 1) { + gemm_q_k.smem_q.move_to_next_write_buffer(); + gmem_q.move(); + gmem_q.load(gemm_q_k.smem_q); } - fmha::Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); - // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params, 2, binfo, tidx); - // The base pointer of smem_v; - char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; - - // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! 
- Smem_tile_v smem_v(smem_v_, tidx); - - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); - - // Trigger the loads for K. - gmem_k.load(gemm_q_k.smem_k); - // Trigger the loads for Q. - gmem_q.load(gemm_q_k.smem_q); - // Trigger the loads for V. - gmem_v.load(smem_v); - - const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); - #pragma unroll - for(int it=0;it < Gmem_tile_k::LDGS;it++){ - gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); - } + // Load the mask for that iteration. + mask.load(begin + l); + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack_noscale(acc_p); + // Apply the mask. + softmax.apply_mask(mask); - // Commit the data for Q and V to shared memory. - gmem_q.commit(gemm_q_k.smem_q); - gmem_v.commit(smem_v); - - // Commit the data for K to shared memory. - if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - gmem_k.commit(gemm_q_k.smem_k); + if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0) { + // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction + __syncthreads(); } - - __syncthreads(); - - // Load the fragments for Q. - gemm_q_k.load_q(); - - // Load the fragments for V. We keep the data in registers during the entire kernel. - typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - smem_v.load(frag_v[ki], ki); + // Compute the max. + float p_max[Mma_tile_p::MMAS_M * 2]; + // softmax.template reduce(p_max); + softmax.reduce_max(p_max); + + // Compute the exponential value. + softmax.apply_exp(p_max); + + // Compute the sum. + float p_sum[Mma_tile_p::MMAS_M * 2]; + softmax.reduce_sum(p_sum); + + // Finalize softmax on the accumulators of P^T. + softmax.scale(p_sum); + + using Frag_p = fmha::Fragment_a; + Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + if (Is_training) { + auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; }; +#pragma unroll + for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) { +#pragma unroll + for (int ii = 0; ii < 2; ii++) { +#pragma unroll + for (int ni = 0; ni < Mma_tile_p::MMAS_N / 4; ni++) { + uint8_t *rand_arr = (uint8_t *)&ph(); + // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from + // pre-existing zeros + for (int ind = 0; ind < 16; ind++) { + softmax.elt_[2 * mi + ii][16 * ni + ind] = + encode_dropout(rand_arr[ind] <= p_scaled, softmax.elt_[2 * mi + ii][16 * ni + ind]); + } + } + } + } + softmax.pack(frag_p); + gmem_s.store(frag_p, mask); + gmem_s.move(); + } else { + softmax.pack(frag_p); } - // Commit the data for V to shared memory if it has not been done already. - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - // Make sure we are done loading the fragments for K. - __syncthreads(); - - // Commit the data to shared memory for V. - gmem_k.commit(gemm_q_k.smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); + // Commit the values for Q into shared memory. + if (l < steps - 1) { + gmem_q.commit(gemm_q_k.smem_q); } - // Load the fragments for K. - gemm_q_k.load_k(); - uint32_t p_scaled = (uint32_t) 256.0 * params.p_dropout; - - // Create the object to do the softmax. - Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx); - - // Load over the entire sequence length. 
- for( int l = 0; l < steps; l++ ) { - if(begin + l * Cta_tile_p::M >= binfo.actual_seqlen) break; - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - - // Do this part of P^T = (Q * K^T)^T. - gemm_q_k(acc_p); - - // Trigger the load for the next Q values. - if( l < steps - 1) { - gemm_q_k.smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(gemm_q_k.smem_q); + if (Is_training) { +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ki++) { +#pragma unroll + for (int mi = 0; mi < Mma_tile_o::MMAS_M; mi++) { +#pragma unroll + for (int ii = 0; ii < Frag_p::NUM_REGS; ii++) { + //"Apply" the dropout. + frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout); + frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii)); + } } + } + } - // Load the mask for that iteration. - mask.load(begin + l); - - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack_noscale(acc_p); - - // Apply the mask. - softmax.apply_mask(mask); - - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { - // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction - __syncthreads(); - } - // Compute the max. - float p_max[Mma_tile_p::MMAS_M * 2]; - //softmax.template reduce(p_max); - softmax.reduce_max(p_max); - - // Compute the exponential value. - softmax.apply_exp(p_max); - - // Compute the sum. - float p_sum[Mma_tile_p::MMAS_M * 2]; - softmax.reduce_sum(p_sum); - - // Finalize softmax on the accumulators of P^T. - softmax.scale(p_sum); - - using Frag_p = fmha::Fragment_a; - Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - if( Is_training ) { - auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; }; - #pragma unroll - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < Mma_tile_p::MMAS_N/4; ni++ ) { - uint8_t * rand_arr = (uint8_t*) &ph(); - // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros - for (int ind=0; ind<16; ind++) - { - softmax.elt_[2 * mi + ii][16 * ni + ind] = - encode_dropout(rand_arr[ind] <= p_scaled, softmax.elt_[2 * mi + ii][16 * ni + ind]); - } - } - } - } - softmax.pack(frag_p); - gmem_s.store(frag_p, mask); - gmem_s.move(); - } else { - softmax.pack(frag_p); - } - - // Commit the values for Q into shared memory. - if(l < steps - 1) { - gmem_q.commit(gemm_q_k.smem_q); - } - - if( Is_training ) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) { - //"Apply" the dropout. - frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout); - frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii)); - } - } - } - } - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; - fmha::Clear_accumulator::apply(acc_o); - - // Do this part of O = P^T * V^T. - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - fmha::gemm(acc_o, frag_p[ki], frag_v[ki]); - } + // Declare the accumulators for the 1st gemm. + fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; + fmha::Clear_accumulator::apply(acc_o); - // Loop over MMAS_M. 
- #pragma unroll - for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) { +// Do this part of O = P^T * V^T. +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) { + fmha::gemm(acc_o, frag_p[ki], frag_v[ki]); + } - // Swizzle the elements and do the final reduction. - smem_o.store(acc_o, ii); +// Loop over MMAS_M. +#pragma unroll + for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) { + // Swizzle the elements and do the final reduction. + smem_o.store(acc_o, ii); - // Make sure the data is in shared memory. - __syncthreads(); + // Make sure the data is in shared memory. + __syncthreads(); - // Load from shared memory. - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; - smem_o.load(out); + // Load from shared memory. + uint4 out[Gmem_tile_o::STGS_PER_LOOP]; + smem_o.load(out); - // Make sure the data was read from shared memory. - if( ii < Gmem_tile_o::LOOPS - 1 ) { - __syncthreads(); - } + // Make sure the data was read from shared memory. + if (ii < Gmem_tile_o::LOOPS - 1) { + __syncthreads(); + } - // Output the values. - gmem_o.store(out, ii); - } + // Output the values. + gmem_o.store(out, ii); + } - // Move to the next part of the output. - gmem_o.move(); - gemm_q_k.reload_k(); + // Move to the next part of the output. + gmem_o.move(); + gemm_q_k.reload_k(); - // Commit the values for Q into shared memory. - if(l < steps - 1) { - gemm_q_k.reload_q(); - } + // Commit the values for Q into shared memory. + if (l < steps - 1) { + gemm_q_k.reload_q(); + } - } // Outer loop over the sequence length. + } // Outer loop over the sequence length. } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void device_1xN(const Params ¶ms, - const int num_full_heads, - const int num_main_groups, - const int main_group_size, - const int main_steps, - const int rest_steps) { - - constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x; - auto seeds = at::cuda::philox::unpack(params.philox_args); - Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); - for( int it = 0; it < num_full_heads; it++ ) { - const int bidx = it * gridDim.x + blockIdx.x; - const int bidh = bidx % params.h; - const int bidb = bidx / params.h; - fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph); - __syncthreads(); - } - if( main_group_size == 0 ) - return; - const int head_offset = num_full_heads * gridDim.x; - - if( blockIdx.x < main_group_size * num_main_groups ) { - // process within heads - const int group = blockIdx.x % num_main_groups; - const int bidx = blockIdx.x / num_main_groups; - const int bidh = (head_offset + bidx) % params.h; - const int bidb = (head_offset + bidx) / params.h; - const int offset = group * main_steps; - fmha::device_1xN_(params, bidb, bidh, offset, main_steps, ph); - } else { - if(rest_steps == 0 ) return; - // process across heads - const int bidx = blockIdx.x - main_group_size * num_main_groups; - const int offset = num_main_groups * main_steps; - const int total_heads = params.b * params.h; - const int rest_ctas = gridDim.x - main_group_size * num_main_groups; - for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) { - const int bidh = it % params.h; - const int bidb = it / params.h; - fmha::device_1xN_(params, bidb, bidh, offset, rest_steps, ph); - __syncthreads(); - } +template +inline __device__ void device_1xN(const Params ¶ms, const int num_full_heads, const int num_main_groups, + const int 
main_group_size, const int main_steps, const int rest_steps) { + constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; + const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x; + auto seeds = at::cuda::philox::unpack(params.philox_args); + Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); + for (int it = 0; it < num_full_heads; it++) { + const int bidx = it * gridDim.x + blockIdx.x; + const int bidh = bidx % params.h; + const int bidb = bidx / params.h; + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph); + __syncthreads(); + } + if (main_group_size == 0) return; + const int head_offset = num_full_heads * gridDim.x; + + if (blockIdx.x < main_group_size * num_main_groups) { + // process within heads + const int group = blockIdx.x % num_main_groups; + const int bidx = blockIdx.x / num_main_groups; + const int bidh = (head_offset + bidx) % params.h; + const int bidb = (head_offset + bidx) / params.h; + const int offset = group * main_steps; + fmha::device_1xN_(params, bidb, bidh, offset, main_steps, ph); + } else { + if (rest_steps == 0) return; + // process across heads + const int bidx = blockIdx.x - main_group_size * num_main_groups; + const int offset = num_main_groups * main_steps; + const int total_heads = params.b * params.h; + const int rest_ctas = gridDim.x - main_group_size * num_main_groups; + for (int it = head_offset + bidx; it < total_heads; it += rest_ctas) { + const int bidh = it % params.h; + const int bidb = it / params.h; + fmha::device_1xN_(params, bidb, bidh, offset, rest_steps, ph); + __syncthreads(); } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void device_1xN(const Params ¶ms, const int total_heads) { - - const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x; - auto seeds = at::cuda::philox::unpack(params.philox_args); - Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); - constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - - for(int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x){ - const int bidh = bidx % params.h; - const int bidb = bidx / params.h; - fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph); - __syncthreads(); - } + const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x; + auto seeds = at::cuda::philox::unpack(params.philox_args); + Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); + constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; + + for (int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x) { + const int bidh = bidx % params.h; + const int bidb = bidx / params.h; + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph); + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace fmha - +} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha_kernel.h b/apex/contrib/csrc/fmha/src/fmha_kernel.h index 63180b087..3cdb08148 100644 --- a/apex/contrib/csrc/fmha/src/fmha_kernel.h +++ b/apex/contrib/csrc/fmha/src/fmha_kernel.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
- * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,151 +27,141 @@ #pragma once -#include - #include -#include -#include #include #include +#include #include +#include + +#include namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct BlockInfoPadded { - - template - __device__ BlockInfoPadded(const Params ¶ms, - const int bidb, - const int bidh, - const int tidx) - : bidb(bidb), bidh(bidh), h(params.h) { - - // The block index. - sum_s = params.cu_seqlens[bidb]; - actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s; - bidx = sum_s * params.h + bidh; - - tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; - } - - __device__ bool stop_early() const { - return actual_seqlen == 0; - } - - int actual_seqlen; - int bidx; - int sum_s; - int bidh; - int bidb; - int tidx_global; - int h; + template + __device__ BlockInfoPadded(const Params& params, const int bidb, const int bidh, const int tidx) + : bidb(bidb), bidh(bidh), h(params.h) { + // The block index. + sum_s = params.cu_seqlens[bidb]; + actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s; + bidx = sum_s * params.h + bidh; + + tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; + } + + __device__ bool stop_early() const { return actual_seqlen == 0; } + + int actual_seqlen; + int bidx; + int sum_s; + int bidh; + int bidb; + int tidx_global; + int h; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Noloop_traits{ - // Interpretation of Cta_tile dims, i.e. Cta_tile_p: - enum{ STEP = Cta_tile::M }; - enum{ SEQLEN = Cta_tile::N }; - - template - inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) - : bidc_(bidc) { - const int seqlen = binfo.actual_seqlen; - const int steps = (seqlen + STEP - 1) / STEP; - const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS; - - const int step_begin = bidc_ * steps_per_chunk; - const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk); - const int actual_steps = max(0, step_end - step_begin); - loop_offset_ = step_begin; - num_steps_ = actual_steps; - +template +struct Noloop_traits { + // Interpretation of Cta_tile dims, i.e. Cta_tile_p: + enum { STEP = Cta_tile::M }; + enum { SEQLEN = Cta_tile::N }; + + template + inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) : bidc_(bidc) { + const int seqlen = binfo.actual_seqlen; + const int steps = (seqlen + STEP - 1) / STEP; + const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS; + + const int step_begin = bidc_ * steps_per_chunk; + const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk); + const int actual_steps = max(0, step_end - step_begin); + loop_offset_ = step_begin; + num_steps_ = actual_steps; + } + + template + inline __device__ void move_all(Tiles&... 
tiles) const { + using expand_type = int[]; + for (int s = 0; s < loop_offset_; s++) { + expand_type{(tiles.move(), 0)...}; } - - template - inline __device__ void move_all(Tiles & ... tiles) const { - using expand_type = int[]; - for( int s = 0; s < loop_offset_; s++ ) { - expand_type{ (tiles.move(), 0)... }; - } - } - - inline __device__ int get_idx_dk() const { - //return bidc_; - return bidc_ * 2 + 0; - } - - inline __device__ int get_idx_dv() const { - //return CHUNKS + bidc_; - return bidc_ * 2 + 1; - } - - inline __device__ int offset_loop_count(const int l) { - // convert loop counter to position in the outer sequence - return (loop_offset_ + l) * STEP; - } - - const uint32_t bidc_; - int loop_offset_; - int num_steps_; + } + + inline __device__ int get_idx_dk() const { + // return bidc_; + return bidc_ * 2 + 0; + } + + inline __device__ int get_idx_dv() const { + // return CHUNKS + bidc_; + return bidc_ * 2 + 1; + } + + inline __device__ int offset_loop_count(const int l) { + // convert loop counter to position in the outer sequence + return (loop_offset_ + l) * STEP; + } + + const uint32_t bidc_; + int loop_offset_; + int num_steps_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -std::tuple work_dist(const int total_ctas, const int heads_total) { - - constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - - const int num_full_heads = heads_total / total_ctas; - const int heads_last_wave = heads_total % total_ctas; - - int num_main_groups = 0; - int main_steps = 0; - int rest_steps = 0; - if( heads_last_wave > 0 ) { - // Number of CTA groups that process within heads. - num_main_groups = total_ctas / heads_last_wave; - // Remaining CTAs that process between heads. - const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups); - if(rest_ctas == 0) { - // We have exactly "num_main_groups" CTAs to process each of the remaining heads. - main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups; - num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_step > 0 - rest_steps = STEPS_PER_HEAD % main_steps; - - } else { - // Ideal number of steps if we could load-balance as evenly as possible. - const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas; - // Iterations that a "rest" CTA has to do at most. - const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas; - // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs. - main_steps = steps_ideal; - rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; - for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) { - rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; - const int max_rest_total_steps = rest_steps * max_rest_iters; - if( max_rest_total_steps < main_steps ) - break; - } - rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; - } +template +std::tuple work_dist(const int total_ctas, const int heads_total) { + constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; + + const int num_full_heads = heads_total / total_ctas; + const int heads_last_wave = heads_total % total_ctas; + + int num_main_groups = 0; + int main_steps = 0; + int rest_steps = 0; + if (heads_last_wave > 0) { + // Number of CTA groups that process within heads. 
+ num_main_groups = total_ctas / heads_last_wave; + // Remaining CTAs that process between heads. + const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups); + if (rest_ctas == 0) { + // We have exactly "num_main_groups" CTAs to process each of the remaining heads. + main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups; + num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_step > 0 + rest_steps = STEPS_PER_HEAD % main_steps; + + } else { + // Ideal number of steps if we could load-balance as evenly as possible. + const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas; + // Iterations that a "rest" CTA has to do at most. + const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas; + // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main + // CTAs. + main_steps = steps_ideal; + rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; + for (; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++) { + rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; + const int max_rest_total_steps = rest_steps * max_rest_iters; + if (max_rest_total_steps < main_steps) break; + } + rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; } + } - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - using Mma_tile_p = fmha::Hmma_tile; + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + using Mma_tile_p = fmha::Hmma_tile; - const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps); - const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8; - const int elts_per_thread = max_steps * elts_per_thread_per_step; + const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps); + const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8; + const int elts_per_thread = max_steps * elts_per_thread_per_step; - return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread}; + return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread}; } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu b/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu index 8e4b9efc3..c9a8d4204 100644 --- a/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu +++ b/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
- * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,151 +27,140 @@ #include "fmha.h" -inline __device__ float4 ldg128(const void *ptr) { - return *static_cast(ptr); -} +inline __device__ float4 ldg128(const void *ptr) { return *static_cast(ptr); } -inline __device__ void stg128(void *ptr, const float4 &data) { - *static_cast(ptr) = data; -} +inline __device__ void stg128(void *ptr, const float4 &data) { *static_cast(ptr) = data; } -template +template __global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out, const void *__restrict__ in, const int *__restrict__ cu_seqlens, const int batch_size) { + enum { BYTES_PER_LDG = 16 }; + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) }; - enum { BYTES_PER_LDG = 16 }; - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) }; + // One CTA hidden vector for K and V + enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 }; + // The stride in bytes in dQKV + enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) }; + // The offset in bytes in dQKV to the dKV part for non-interleaved heads + enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) }; - // One CTA hidden vector for K and V - enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 }; - // The stride in bytes in dQKV - enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) }; - // The offset in bytes in dQKV to the dKV part for non-interleaved heads - enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) }; + static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T)); - static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T)); + // Size in bytes of the input tile + enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW }; - // Size in bytes of the input tile - enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW }; + enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG }; - enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG }; + enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA }; + static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW); - enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA }; - static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW); + union Vec_t { + float4 raw; + T elt[NUM_ELTS]; + }; - union Vec_t { - float4 raw; - T elt[NUM_ELTS]; - }; + // ZERO-OUT invalid positions in dQKV + const int total = cu_seqlens[batch_size]; + if (blockIdx.x >= total) { + enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) }; + enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG }; - // ZERO-OUT invalid positions in dQKV - const int total = cu_seqlens[batch_size]; - if(blockIdx.x >= total){ - enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) }; - enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG }; + const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f); - const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f); + char *base_ptr = static_cast(out) + blockIdx.x * OUT_STRIDE_BYTES; - char *base_ptr = static_cast(out) + blockIdx.x * OUT_STRIDE_BYTES; - - for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){ - stg128(base_ptr + tidx * BYTES_PER_LDG, zeros); - } - - return; + for (int tidx = threadIdx.x; tidx < STGS; tidx += THREADS) { + stg128(base_ptr + tidx * BYTES_PER_LDG, zeros); } - // SETUP - const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG; - const char *ptr_in = static_cast(in) + offset_in; + return; + } - const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG; - 
char *ptr_out = static_cast(out) + OUT_OFFSET_KV_BYTES + offset_out; + // SETUP + const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG; + const char *ptr_in = static_cast(in) + offset_in; - // LOAD + const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG; + char *ptr_out = static_cast(out) + OUT_OFFSET_KV_BYTES + offset_out; - Vec_t local_in[CHUNKS][LDGS]; + // LOAD - #pragma unroll - for( int c = 0; c < CHUNKS; c++ ) { - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA; - local_in[c][l].raw = ldg128(ptr_in + offset); - } + Vec_t local_in[CHUNKS][LDGS]; + +#pragma unroll + for (int c = 0; c < CHUNKS; c++) { +#pragma unroll + for (int l = 0; l < LDGS; l++) { + int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA; + local_in[c][l].raw = ldg128(ptr_in + offset); } + } - // UNPACK - float acc[LDGS][NUM_ELTS]; + // UNPACK + float acc[LDGS][NUM_ELTS]; - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - #pragma unroll - for( int e = 0; e < NUM_ELTS; e++ ) { - acc[l][e] = float(local_in[0][l].elt[e]); - } +#pragma unroll + for (int l = 0; l < LDGS; l++) { +#pragma unroll + for (int e = 0; e < NUM_ELTS; e++) { + acc[l][e] = float(local_in[0][l].elt[e]); } - - // COMPUTE - #pragma unroll - for( int c = 1; c < CHUNKS; c++ ) { - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - #pragma unroll - for( int e = 0; e < NUM_ELTS; e++ ) { - acc[l][e] += float(local_in[c][l].elt[e]); - } - } + } + +// COMPUTE +#pragma unroll + for (int c = 1; c < CHUNKS; c++) { +#pragma unroll + for (int l = 0; l < LDGS; l++) { +#pragma unroll + for (int e = 0; e < NUM_ELTS; e++) { + acc[l][e] += float(local_in[c][l].elt[e]); + } } + } - // PACK - Vec_t local_out[LDGS]; + // PACK + Vec_t local_out[LDGS]; - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - #pragma unroll - for( int e = 0; e < NUM_ELTS; e++ ) { - local_out[l].elt[e] = T(acc[l][e]); - } - } - - // STORE - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - const int offset = l * BYTES_PER_CTA; - stg128(ptr_out + offset, local_out[l].raw); +#pragma unroll + for (int l = 0; l < LDGS; l++) { +#pragma unroll + for (int e = 0; e < NUM_ELTS; e++) { + local_out[l].elt[e] = T(acc[l][e]); } + } + +// STORE +#pragma unroll + for (int l = 0; l < LDGS; l++) { + const int offset = l * BYTES_PER_CTA; + stg128(ptr_out + offset, local_out[l].raw); + } } -void fmha_run_noloop_reduce(void *out, - const void *in, - const int *cu_seqlens, - const int hidden_size, - const int batch_size, - const int total, - const int num_chunks, - cudaStream_t stream) { - - const int blocks = total; - - if(hidden_size == 1024){ - - constexpr int HIDDEN_SIZE = 1024; - constexpr int THREADS = 256; - - if( num_chunks == 2 ) { - fmha_noloop_reduce_kernel<<>>(out, in, cu_seqlens, batch_size); - } else if( num_chunks == 3 ) { - fmha_noloop_reduce_kernel<<>>(out, in, cu_seqlens, batch_size); - } else { - assert(false && "Unsupported num_chunks"); - } - - }else{ - assert(false && "Unsupported hidden_size"); +void fmha_run_noloop_reduce(void *out, const void *in, const int *cu_seqlens, const int hidden_size, + const int batch_size, const int total, const int num_chunks, cudaStream_t stream) { + const int blocks = total; + + if (hidden_size == 1024) { + constexpr int HIDDEN_SIZE = 1024; + constexpr int THREADS = 256; + + if (num_chunks == 2) { + fmha_noloop_reduce_kernel + <<>>(out, in, cu_seqlens, batch_size); + } else if (num_chunks == 3) { + fmha_noloop_reduce_kernel + <<>>(out, in, 
cu_seqlens, batch_size); + } else { + assert(false && "Unsupported num_chunks"); } - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + } else { + assert(false && "Unsupported hidden_size"); + } + + FMHA_CHECK_CUDA(cudaPeekAtLastError()); } diff --git a/apex/contrib/csrc/fmha/src/fmha_utils.h b/apex/contrib/csrc/fmha/src/fmha_utils.h index de07cc78e..e18faa1de 100644 --- a/apex/contrib/csrc/fmha/src/fmha_utils.h +++ b/apex/contrib/csrc/fmha/src/fmha_utils.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -28,25 +28,21 @@ #pragma once #include +#include +#include #include #include -#include -#include //////////////////////////////////////////////////////////////////////////////////////////////////// -#define FMHA_CHECK_CUDA( call ) \ - do { \ - cudaError_t status_ = call; \ - if( status_ != cudaSuccess ) { \ - fprintf( stderr, \ - "CUDA error (%s:%d): %s\n", \ - __FILE__, \ - __LINE__, \ - cudaGetErrorString( status_ ) ); \ - exit( 1 ); \ - } \ - } while( 0 ) +#define FMHA_CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -54,39 +50,38 @@ enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) { - if( dtype == DATA_TYPE_FP16 ) { - half x = __float2half_rn( norm ); - uint16_t h = reinterpret_cast( x ); - ushort2 h2 = { h, h }; - alpha = reinterpret_cast( h2 ); - } else if( dtype == DATA_TYPE_FP32 ) { - alpha = reinterpret_cast( norm ); - } else if( dtype == DATA_TYPE_INT32 ) { - int32_t inorm = static_cast( norm ); - alpha = reinterpret_cast( inorm ); - } else { - assert( false ); - } +static inline void set_alpha(uint32_t &alpha, float norm, Data_type dtype) { + if (dtype == DATA_TYPE_FP16) { + half x = __float2half_rn(norm); + uint16_t h = reinterpret_cast(x); + ushort2 h2 = {h, h}; + alpha = reinterpret_cast(h2); + } else if (dtype == DATA_TYPE_FP32) { + alpha = reinterpret_cast(norm); + } else if (dtype == DATA_TYPE_INT32) { + int32_t inorm = static_cast(norm); + alpha = reinterpret_cast(inorm); + } else { + assert(false); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) { - switch( dtype ) { +static inline size_t get_size_in_bytes(size_t n, Data_type dtype) { + switch (dtype) { case DATA_TYPE_FP32: - return n * 
4; + return n * 4; case DATA_TYPE_FP16: - return n * 2; + return n * 2; case DATA_TYPE_INT32: - return n * 4; + return n * 4; case DATA_TYPE_INT8: - return n; + return n; default: - assert( false ); - return 0; - } + assert(false); + return 0; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp index f32b0131b..fe6483fef 100644 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp +++ b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp @@ -1,61 +1,38 @@ #include -#include #include +#include // CUDA forward declarations -std::vector focal_loss_forward_cuda( - const at::Tensor &cls_output, - const at::Tensor &cls_targets_at_level, - const at::Tensor &num_positives_sum, - const int64_t num_real_classes, - const float alpha, - const float gamma, - const float smoothing_factor); +std::vector focal_loss_forward_cuda(const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level, + const at::Tensor &num_positives_sum, const int64_t num_real_classes, + const float alpha, const float gamma, const float smoothing_factor); -at::Tensor focal_loss_backward_cuda( - const at::Tensor &grad_output, - const at::Tensor &partial_grad, - const at::Tensor &num_positives_sum); +at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output, const at::Tensor &partial_grad, + const at::Tensor &num_positives_sum); // C++ interface #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) -std::vector focal_loss_forward( - const at::Tensor &cls_output, - const at::Tensor &cls_targets_at_level, - const at::Tensor &num_positives_sum, - const int64_t num_real_classes, - const float alpha, - const float gamma, - const float smoothing_factor -) { +std::vector focal_loss_forward(const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level, + const at::Tensor &num_positives_sum, const int64_t num_real_classes, + const float alpha, const float gamma, const float smoothing_factor) { CHECK_INPUT(cls_output); CHECK_INPUT(cls_targets_at_level); CHECK_INPUT(num_positives_sum); - return focal_loss_forward_cuda( - cls_output, - cls_targets_at_level, - num_positives_sum, - num_real_classes, - alpha, - gamma, - smoothing_factor); + return focal_loss_forward_cuda(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, alpha, gamma, + smoothing_factor); } -at::Tensor focal_loss_backward( - const at::Tensor &grad_output, - const at::Tensor &partial_grad, - const at::Tensor &num_positives_sum -) { +at::Tensor focal_loss_backward(const at::Tensor &grad_output, const at::Tensor &partial_grad, + const at::Tensor &num_positives_sum) { CHECK_INPUT(grad_output); CHECK_INPUT(partial_grad); @@ -63,10 +40,8 @@ at::Tensor focal_loss_backward( } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &focal_loss_forward, - "Focal loss calculation forward (CUDA)", + m.def("forward", &focal_loss_forward, "Focal loss calculation forward (CUDA)", py::call_guard()); - m.def("backward", &focal_loss_backward, - "Focal loss calculation backward (CUDA)", + m.def("backward", &focal_loss_backward, "Focal loss calculation backward (CUDA)", 
py::call_guard()); } diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu b/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu index c4776f8f7..8e3d82926 100644 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu +++ b/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu @@ -5,23 +5,23 @@ // Use 128-bit vectorization typedef uint4 vector_t; -#define ASSERT_ALIGNED(DTYPE, PTR) \ +#define ASSERT_ALIGNED(DTYPE, PTR) \ TORCH_INTERNAL_ASSERT(is_aligned(PTR), "Tensor " #PTR " is not " #DTYPE " aligned") -template bool is_aligned(const void *ptr) noexcept { +template +bool is_aligned(const void *ptr) noexcept { auto iptr = reinterpret_cast(ptr); return !(iptr % alignof(T)); } -template -__global__ void focal_loss_forward_cuda_kernel( - outscalar_t *loss, scalar_t *partial_grad, - const scalar_t *__restrict__ cls_output, - const labelscalar_t *__restrict__ cls_targets_at_level, - const float *__restrict__ num_positives_sum, const int64_t num_examples, - const int64_t num_classes, const int64_t num_real_classes, - const float alpha, const float gamma, const float smoothing_factor) { +template +__global__ void focal_loss_forward_cuda_kernel(outscalar_t *loss, scalar_t *partial_grad, + const scalar_t *__restrict__ cls_output, + const labelscalar_t *__restrict__ cls_targets_at_level, + const float *__restrict__ num_positives_sum, const int64_t num_examples, + const int64_t num_classes, const int64_t num_real_classes, + const float alpha, const float gamma, const float smoothing_factor) { extern __shared__ unsigned char shm[]; accscalar_t *loss_shm = reinterpret_cast(shm); loss_shm[threadIdx.x] = 0; @@ -43,14 +43,14 @@ __global__ void focal_loss_forward_cuda_kernel( vector_t p_vec, grad_vec; // Accumulate loss on each thread - for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; - i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) { + for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; i < num_examples * num_classes; + i += gridDim.x * blockDim.x * ILP) { int64_t idy = i / num_classes; labelscalar_t y = cls_targets_at_level[idy]; int64_t base_yid = i % num_classes; int64_t pos_idx = idy * num_classes + y; - p_vec = *(vector_t *)&cls_output[i]; // Vectorized load + p_vec = *(vector_t *)&cls_output[i]; // Vectorized load // Skip ignored matches if (y == -2) { @@ -131,20 +131,16 @@ __global__ void focal_loss_forward_cuda_kernel( } } -template -__global__ void focal_loss_backward_cuda_kernel( - scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output, - const float *__restrict__ num_positives_sum, const uint64_t numel) { +template +__global__ void focal_loss_backward_cuda_kernel(scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output, + const float *__restrict__ num_positives_sum, const uint64_t numel) { int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; - accscalar_t normalizer = static_cast(grad_output[0]) / - static_cast(num_positives_sum[0]); + accscalar_t normalizer = static_cast(grad_output[0]) / static_cast(num_positives_sum[0]); // The input is enforced to pad to use vector load, thus there's no need to // check whether the last element of ILP can out of bound. 
- if (idx >= numel) - return; + if (idx >= numel) return; vector_t grad_vec; grad_vec = *(vector_t *)&partial_grad[idx]; @@ -157,30 +153,25 @@ __global__ void focal_loss_backward_cuda_kernel( *(vector_t *)&partial_grad[idx] = grad_vec; } -std::vector focal_loss_forward_cuda( - const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level, - const at::Tensor &num_positives_sum, const int64_t num_real_classes, - const float alpha, const float gamma, const float smoothing_factor) { +std::vector focal_loss_forward_cuda(const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level, + const at::Tensor &num_positives_sum, const int64_t num_real_classes, + const float alpha, const float gamma, const float smoothing_factor) { // Checks required for correctness - TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes, - "Incorrect number of real classes."); - TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong, - "Invalid label type."); - TORCH_INTERNAL_ASSERT( - (num_positives_sum.numel() == 1) && - (num_positives_sum.scalar_type() == at::kFloat), - "Expect num_positives_sum to be a float32 tensor with only one element."); + TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes, "Incorrect number of real classes."); + TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong, "Invalid label type."); + TORCH_INTERNAL_ASSERT((num_positives_sum.numel() == 1) && (num_positives_sum.scalar_type() == at::kFloat), + "Expect num_positives_sum to be a float32 tensor with only one element."); TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1, - "Mis-matched dimensions between class output and label."); + "Mis-matched dimensions between class output and label."); for (int64_t i = 0; i < cls_targets_at_level.dim(); i++) TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i), - "Mis-matched shape between class output and label."); + "Mis-matched shape between class output and label."); // Checks required for better performance const int ILP = sizeof(vector_t) / cls_output.element_size(); ASSERT_ALIGNED(vector_t, cls_output.data_ptr()); TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0, - "Pad number of classes first to take advantage of vectorized load."); + "Pad number of classes first to take advantage of vectorized load."); TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes."); int64_t num_classes = cls_output.size(-1); @@ -206,49 +197,36 @@ std::vector focal_loss_forward_cuda( // Specialize on label smoothing or not to reduce redundant operations cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (smoothing_factor == 0.0f) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - cls_output.scalar_type(), "focal_loss_fprop", [&] { - using accscalar_t = at::acc_type; - using labelscalar_t = int64_t; - using outscalar_t = float; - const int ILP = sizeof(vector_t) / sizeof(scalar_t); - focal_loss_forward_cuda_kernel - <<>>( - loss.data_ptr(), - partial_grad.data_ptr(), - cls_output.data_ptr(), - cls_targets_at_level.data_ptr(), - num_positives_sum.data_ptr(), num_examples, - num_classes, num_real_classes, alpha, gamma, - smoothing_factor); - }); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(cls_output.scalar_type(), "focal_loss_fprop", [&] { + using accscalar_t = at::acc_type; + using labelscalar_t = int64_t; + using outscalar_t = float; + const int ILP = sizeof(vector_t) / sizeof(scalar_t); + focal_loss_forward_cuda_kernel + <<>>( + loss.data_ptr(), partial_grad.data_ptr(), cls_output.data_ptr(), + 
cls_targets_at_level.data_ptr(), num_positives_sum.data_ptr(), num_examples, + num_classes, num_real_classes, alpha, gamma, smoothing_factor); + }); } else { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - cls_output.scalar_type(), "focal_loss_fprop", [&] { - using accscalar_t = at::acc_type; - using labelscalar_t = int64_t; - using outscalar_t = float; - const int ILP = sizeof(vector_t) / sizeof(scalar_t); - focal_loss_forward_cuda_kernel - <<>>( - loss.data_ptr(), - partial_grad.data_ptr(), - cls_output.data_ptr(), - cls_targets_at_level.data_ptr(), - num_positives_sum.data_ptr(), num_examples, - num_classes, num_real_classes, alpha, gamma, - smoothing_factor); - }); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(cls_output.scalar_type(), "focal_loss_fprop", [&] { + using accscalar_t = at::acc_type; + using labelscalar_t = int64_t; + using outscalar_t = float; + const int ILP = sizeof(vector_t) / sizeof(scalar_t); + focal_loss_forward_cuda_kernel + <<>>( + loss.data_ptr(), partial_grad.data_ptr(), cls_output.data_ptr(), + cls_targets_at_level.data_ptr(), num_positives_sum.data_ptr(), num_examples, + num_classes, num_real_classes, alpha, gamma, smoothing_factor); + }); } AT_CUDA_CHECK(cudaGetLastError()); return {loss, partial_grad}; } -at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output, - const at::Tensor &partial_grad, +at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output, const at::Tensor &partial_grad, const at::Tensor &num_positives_sum) { // Each thread process ILP elements const int ILP = sizeof(vector_t) / partial_grad.element_size(); @@ -256,17 +234,14 @@ at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output, dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP)); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - partial_grad.scalar_type(), "focal_loss_bprop", [&] { - using accscalar_t = at::acc_type; - using outscalar_t = float; - const int ILP = sizeof(vector_t) / sizeof(scalar_t); - focal_loss_backward_cuda_kernel - <<>>(partial_grad.data_ptr(), - grad_output.data_ptr(), - num_positives_sum.data_ptr(), - partial_grad.numel()); - }); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(partial_grad.scalar_type(), "focal_loss_bprop", [&] { + using accscalar_t = at::acc_type; + using outscalar_t = float; + const int ILP = sizeof(vector_t) / sizeof(scalar_t); + focal_loss_backward_cuda_kernel + <<>>(partial_grad.data_ptr(), grad_output.data_ptr(), + num_positives_sum.data_ptr(), partial_grad.numel()); + }); AT_CUDA_CHECK(cudaGetLastError()); return partial_grad; diff --git a/apex/contrib/csrc/gpu_direct_storage/gds.cpp b/apex/contrib/csrc/gpu_direct_storage/gds.cpp index 7bb5ce95b..dbd1afc60 100644 --- a/apex/contrib/csrc/gpu_direct_storage/gds.cpp +++ b/apex/contrib/csrc/gpu_direct_storage/gds.cpp @@ -16,29 +16,21 @@ namespace apex::contrib::gds { // POSIX -template < - class T, - typename std::enable_if::value, std::nullptr_t>::type = - nullptr> +template ::value, std::nullptr_t>::type = nullptr> std::string cuFileGetErrorString(T status) { status = std::abs(status); - return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status)) - : std::string(std::strerror(errno)); + return IS_CUFILE_ERR(status) ? 
std::string(CUFILE_ERRSTR(status)) : std::string(std::strerror(errno)); } // CUfileError_t -template < - class T, - typename std::enable_if::value, std::nullptr_t>::type = - nullptr> +template ::value, std::nullptr_t>::type = nullptr> std::string cuFileGetErrorString(T status) { std::string errStr = cuFileGetErrorString(static_cast(status.err)); - if (IS_CUDA_ERR(status)) - errStr.append(".").append(cudaGetErrorString(static_cast(status.cu_err))); + if (IS_CUDA_ERR(status)) errStr.append(".").append(cudaGetErrorString(static_cast(status.cu_err))); return errStr; } -File::File() : is_open(false) {}; +File::File() : is_open(false){}; File::File(const std::string& filename, const std::string& mode) : filename(filename), mode(mode), is_open(false) { open(filename, mode); @@ -61,7 +53,7 @@ void File::open(const std::string& other_filename, const std::string& other_mode maybe_register = true; // Open the binary file - if(mode == "r") { + if (mode == "r") { // for reading fd = ::open(filename.c_str(), O_RDONLY | O_DIRECT); } else if (mode == "w") { @@ -81,26 +73,26 @@ void File::open(const std::string& other_filename, const std::string& other_mode TORCH_CHECK(fd >= 0, "fcntl cannot open file: ", filename); // Register cuFile handle - if(maybe_register) { - memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t)); - cf_descr.handle.fd = fd; - cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; - status = cuFileHandleRegister(&cf_handle, &cf_descr); - if (status.err != CU_FILE_SUCCESS) { - TORCH_CHECK(false, "cuFileHandleRegister failed: ", cuFileGetErrorString(status)); - } + if (maybe_register) { + memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t)); + cf_descr.handle.fd = fd; + cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + status = cuFileHandleRegister(&cf_handle, &cf_descr); + if (status.err != CU_FILE_SUCCESS) { + TORCH_CHECK(false, "cuFileHandleRegister failed: ", cuFileGetErrorString(status)); + } } is_open = true; } void File::close() { // Deregister cuFile handle and close the file - if(is_open) { - if(maybe_register) { - cuFileHandleDeregister(cf_handle); - } - ::close(fd); - fd = -1; + if (is_open) { + if (maybe_register) { + cuFileHandleDeregister(cf_handle); + } + ::close(fd); + fd = -1; } is_open = false; } @@ -136,7 +128,6 @@ void File::save_data(const torch::Tensor& tensor) { TORCH_CHECK(status.err == CU_FILE_SUCCESS, "cuFileBufDeregister failed:", cuFileGetErrorString(status)); } - // Just for benchmarking purposes void File::load_data_no_gds(const torch::Tensor& tensor) { @@ -171,4 +162,4 @@ void File::save_data_no_gds(const torch::Tensor& tensor) { free(dataPtrCPU); } -} // namespace torch_gds +} // namespace apex::contrib::gds diff --git a/apex/contrib/csrc/gpu_direct_storage/gds.h b/apex/contrib/csrc/gpu_direct_storage/gds.h index efa821e38..b1cd448e8 100644 --- a/apex/contrib/csrc/gpu_direct_storage/gds.h +++ b/apex/contrib/csrc/gpu_direct_storage/gds.h @@ -2,35 +2,36 @@ #pragma once -#include #include #include +#include + namespace apex::contrib::gds { - class File { - public: - File(); - File(const std::string& filename, const std::string& mode); - ~File(); - - void open(const std::string& filename, const std::string& mode); - void close(); - - void load_data(const torch::Tensor& tensor); - void save_data(const torch::Tensor& tensor); - void load_data_no_gds(const torch::Tensor& tensor); - void save_data_no_gds(const torch::Tensor& tensor); - - private: - std::string filename; - std::string mode; - - CUfileDescr_t cf_descr; - CUfileHandle_t cf_handle; - CUfileError_t status; - - int 
fd = -1; - bool is_open = false; - bool maybe_register = true; - }; -} +class File { + public: + File(); + File(const std::string& filename, const std::string& mode); + ~File(); + + void open(const std::string& filename, const std::string& mode); + void close(); + + void load_data(const torch::Tensor& tensor); + void save_data(const torch::Tensor& tensor); + void load_data_no_gds(const torch::Tensor& tensor); + void save_data_no_gds(const torch::Tensor& tensor); + + private: + std::string filename; + std::string mode; + + CUfileDescr_t cf_descr; + CUfileHandle_t cf_handle; + CUfileError_t status; + + int fd = -1; + bool is_open = false; + bool maybe_register = true; +}; +} // namespace apex::contrib::gds diff --git a/apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp b/apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp index e57ccc6a3..f3f72d65a 100644 --- a/apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp +++ b/apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp @@ -3,20 +3,18 @@ #include #include #include + #include -//python bindings +// python bindings PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - py::class_< - apex::contrib::gds::File, - std::shared_ptr>( - m, "_GDSFile") - .def(py::init<>()) - .def(py::init()) - .def("open", &apex::contrib::gds::File::open) - .def("close", &apex::contrib::gds::File::close) - .def("load_data", &apex::contrib::gds::File::load_data) - .def("save_data", &apex::contrib::gds::File::save_data) - .def("load_data_no_gds", &apex::contrib::gds::File::load_data_no_gds) - .def("save_data_no_gds", &apex::contrib::gds::File::save_data_no_gds); + py::class_>(m, "_GDSFile") + .def(py::init<>()) + .def(py::init()) + .def("open", &apex::contrib::gds::File::open) + .def("close", &apex::contrib::gds::File::close) + .def("load_data", &apex::contrib::gds::File::load_data) + .def("save_data", &apex::contrib::gds::File::save_data) + .def("load_data_no_gds", &apex::contrib::gds::File::load_data_no_gds) + .def("save_data_no_gds", &apex::contrib::gds::File::save_data_no_gds); } diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc.cpp b/apex/contrib/csrc/group_norm/group_norm_nhwc.cpp index 6a024f370..f75c1ea92 100644 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc.cpp +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc.cpp @@ -2,45 +2,40 @@ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ -#include -#include -#include -#include #include #include +#include +#include +#include #include +#include + #include template -float inline unpack(const T& x) { +float inline unpack(const T &x) { return {}; } template <> -float inline unpack(const __half& x) { +float inline unpack(const __half &x) { return __half2float(x); } - template <> -float inline unpack(const __nv_bfloat16& x) { +float inline unpack(const __nv_bfloat16 &x) { return __bfloat162float(x); } template <> -float inline unpack(const float& x) { +float inline unpack(const float &x) { return x; } //////////////////////////////////////////////////////////////////////////////////////////////////// template -void check_results(const char *name, - const T *out, - const T *ref, - size_t elts, - float tol) { - +void check_results(const char *name, const T *out, const T *ref, size_t elts, float tol) { // The number of errors. int failed = 0; // The number of infinite value. @@ -55,7 +50,7 @@ void check_results(const char *name, fflush(stdout); // Iterate over the different values. 
- for( size_t ii = 0; ii < elts; ++ii ) { + for (size_t ii = 0; ii < elts; ++ii) { float a = unpack(out[ii]); float b = unpack(ref[ii]); @@ -68,21 +63,20 @@ void check_results(const char *name, // Is one of the quantities very small? bool is_small = abs_a <= tol || abs_b <= tol || den <= tol; // The error. - float err = is_small ? fabsf(a-b) : fabsf(a-b) / den; + float err = is_small ? fabsf(a - b) : fabsf(a - b) / den; // Is the result ok? bool ok = !isnan(a) && !isnan(b) && err <= tol; // Print the error. - if( !ok && (failed < 10 || err > max_err) ) { - + if (!ok && (failed < 10 || err > max_err)) { fprintf(stderr, ">> invalid result for ii=%lu:\n", ii); if (std::is_same::value || std::is_same::value) { // The data. - fprintf(stderr, ">> found...: 0x%04x (%10.6f)\n", reinterpret_cast(out[ii]), a); - fprintf(stderr, ">> expected: 0x%04x (%10.6f)\n", reinterpret_cast(ref[ii]), b); + fprintf(stderr, ">> found...: 0x%04x (%10.6f)\n", reinterpret_cast(out[ii]), a); + fprintf(stderr, ">> expected: 0x%04x (%10.6f)\n", reinterpret_cast(ref[ii]), b); } else if (std::is_same::value) { - fprintf(stderr, ">> found...: 0x%08x (%10.6f)\n", reinterpret_cast(a), a); - fprintf(stderr, ">> expected: 0x%08x (%10.6f)\n", reinterpret_cast(b), b); + fprintf(stderr, ">> found...: 0x%08x (%10.6f)\n", reinterpret_cast(a), a); + fprintf(stderr, ">> expected: 0x%08x (%10.6f)\n", reinterpret_cast(b), b); } else { fprintf(stderr, "\e[1;34mUnknown type of check_results\e[0m\n"); exit(1); @@ -99,13 +93,13 @@ void check_results(const char *name, max_err = fmaxf(max_err, err); // Accumulate the sum. - sum_err = sum_err + (double) err; + sum_err = sum_err + (double)err; infs += !isfinite(a); infs += !isfinite(b); } - if( !failed && infs < 10 ) { + if (!failed && infs < 10) { printf("\e[1;32mcheck........................: OK\e[0m\n"); } else { printf("\e[1;31mcheck........................: FAILED\e[0m\n"); @@ -113,7 +107,7 @@ void check_results(const char *name, printf("tested.......................: %lu\n", elts); printf("failures.....................: %d\n", failed); - printf("failure rate.................: %.2lf%%\n", (double) failed * 100.0 / (double) elts); + printf("failure rate.................: %.2lf%%\n", (double)failed * 100.0 / (double)elts); printf("infs.........................: %d\n", infs); printf("tolerance....................: %.8f\n", tol); printf("\n"); @@ -122,56 +116,41 @@ void check_results(const char *name, printf("max. value...................: %.6f\n", max_val); printf("max. error...................: %.6f\n", max_err); printf("sum. error...................: %.6lf\n", sum_err); - printf("avg. error...................: %.6lf\n", sum_err / (double) elts); + printf("avg. 
error...................: %.6lf\n", sum_err / (double)elts); printf("\n"); } template void check_results(const char *name, const __half *out, const __half *ref, size_t elts, float tol); -template void check_results(const char *name, const __nv_bfloat16 *out, const __nv_bfloat16 *ref, size_t elts, float tol); +template void check_results(const char *name, const __nv_bfloat16 *out, const __nv_bfloat16 *ref, size_t elts, + float tol); template void check_results(const char *name, const float *out, const float *ref, size_t elts, float tol); //////////////////////////////////////////////////////////////////////////////////////////////////// -static void group_norm_nhwc_bwd_(void *dx_h, - float *dgamma_h, - float *dbeta_h, - const void *dy_h, - const void *x_h, - const float *gamma_h, - const float *beta_h, - const float2 *sums_h, - float epsilon, - int n, - int h, - int w, - int c, - int groups, - bool with_swish, - bool use_fp32, - bool use_bf16) { - +static void group_norm_nhwc_bwd_(void *dx_h, float *dgamma_h, float *dbeta_h, const void *dy_h, const void *x_h, + const float *gamma_h, const float *beta_h, const float2 *sums_h, float epsilon, int n, + int h, int w, int c, int groups, bool with_swish, bool use_fp32, bool use_bf16) { // The number of channels in each group. int channels_per_group = c / groups; // The normalization term to compute the means. - float rcp_hwc_per_group = 1.f / (float) (h * w * channels_per_group); + float rcp_hwc_per_group = 1.f / (float)(h * w * channels_per_group); // The array to compute gamma. - float *dgamma = (float*) malloc(c * sizeof(float)); + float *dgamma = (float *)malloc(c * sizeof(float)); // The array to compute beta. - float *dbeta = (float*) malloc(c * sizeof(float)); + float *dbeta = (float *)malloc(c * sizeof(float)); // Set gamma/beta to 0. memset(dgamma, 0, c * sizeof(float)); - memset(dbeta, 0, c * sizeof(float)); + memset(dbeta, 0, c * sizeof(float)); // Normalize the activations. - for( int ni = 0; ni < n; ++ni ) { - for( int gi = 0; gi < groups; ++gi ) { - + for (int ni = 0; ni < n; ++ni) { + for (int gi = 0; gi < groups; ++gi) { // The sums from the fwd pass. - float2 sums = sums_h[ni*groups + gi]; + float2 sums = sums_h[ni * groups + gi]; // The mean of X (computed during the fwd pass -- one value per batch*group). float x_mean = sums.x; // The mean of squares of X (computed during the fwd pass -- one value per batch*group). @@ -187,31 +166,30 @@ static void group_norm_nhwc_bwd_(void *dx_h, float mean_1 = 0.f, mean_2 = 0.f; // Iterate over the activations in the group. - for( int hi = 0; hi < h; ++hi ) { - for( int wi = 0; wi < w; ++wi ) { - for( int ii = 0; ii < channels_per_group; ++ii ) { - + for (int hi = 0; hi < h; ++hi) { + for (int wi = 0; wi < w; ++wi) { + for (int ii = 0; ii < channels_per_group; ++ii) { // The channel. int ci = gi * channels_per_group + ii; // Compute the src/dst offset. - size_t offset = (size_t) ni*h*w*c + (size_t) hi*w*c + (size_t) wi*c + (size_t) ci; + size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci; // Convert the element at that position to float. float x; if (use_fp32) { - x = reinterpret_cast(x_h)[offset]; + x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { - x = __bfloat162float(reinterpret_cast(x_h)[offset]); + x = __bfloat162float(reinterpret_cast(x_h)[offset]); } else { - x = __half2float(reinterpret_cast(x_h)[offset]); + x = __half2float(reinterpret_cast(x_h)[offset]); } // The output. 
float dy; if (use_fp32) { - dy = reinterpret_cast(dy_h)[offset]; + dy = reinterpret_cast(dy_h)[offset]; } else if (use_bf16) { - dy = __bfloat162float(reinterpret_cast(dy_h)[offset]); + dy = __bfloat162float(reinterpret_cast(dy_h)[offset]); } else { - dy = __half2float(reinterpret_cast(dy_h)[offset]); + dy = __half2float(reinterpret_cast(dy_h)[offset]); } // Gamma. @@ -222,7 +200,7 @@ static void group_norm_nhwc_bwd_(void *dx_h, // Normalize X. float x_norm = x_minus_x_mean * rcp_x_stddev; - if( with_swish ) { + if (with_swish) { // Beta float beta = beta_h[ci]; @@ -244,38 +222,37 @@ static void group_norm_nhwc_bwd_(void *dx_h, mean_1 += x_norm * dx_norm; mean_2 += dx_norm; - } // ii - } // wi - } // hi + } // ii + } // wi + } // hi mean_1 *= rcp_hwc_per_group; mean_2 *= rcp_hwc_per_group; // Iterate over the activations in the group. - for( int hi = 0; hi < h; ++hi ) { - for( int wi = 0; wi < w; ++wi ) { - for( int ii = 0; ii < channels_per_group; ++ii ) { - + for (int hi = 0; hi < h; ++hi) { + for (int wi = 0; wi < w; ++wi) { + for (int ii = 0; ii < channels_per_group; ++ii) { // The channel. int ci = gi * channels_per_group + ii; // Compute the src/dst offset. - size_t offset = (size_t) ni*h*w*c + (size_t) hi*w*c + (size_t) wi*c + (size_t) ci; + size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci; float x; if (use_fp32) { - x = reinterpret_cast(x_h)[offset]; + x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { - x = __bfloat162float(reinterpret_cast(x_h)[offset]); + x = __bfloat162float(reinterpret_cast(x_h)[offset]); } else { - x = __half2float(reinterpret_cast(x_h)[offset]); + x = __half2float(reinterpret_cast(x_h)[offset]); } // The output. float dy; if (use_fp32) { - dy = reinterpret_cast(dy_h)[offset]; + dy = reinterpret_cast(dy_h)[offset]; } else if (use_bf16) { - dy = __bfloat162float(reinterpret_cast(dy_h)[offset]); + dy = __bfloat162float(reinterpret_cast(dy_h)[offset]); } else { - dy = __half2float(reinterpret_cast(dy_h)[offset]); + dy = __half2float(reinterpret_cast(dy_h)[offset]); } // Gamma. @@ -286,7 +263,7 @@ static void group_norm_nhwc_bwd_(void *dx_h, // Normalize X. float x_norm = x_minus_x_mean * rcp_x_stddev; - if( with_swish ) { + if (with_swish) { // Beta float beta = beta_h[ci]; @@ -303,24 +280,24 @@ static void group_norm_nhwc_bwd_(void *dx_h, // Set the output gradient. if (use_fp32) { - reinterpret_cast(dx_h)[offset] = dx; + reinterpret_cast(dx_h)[offset] = dx; } else if (use_bf16) { - reinterpret_cast<__nv_bfloat16*>(dx_h)[offset] = __float2bfloat16_rn(dx); + reinterpret_cast<__nv_bfloat16 *>(dx_h)[offset] = __float2bfloat16_rn(dx); } else { - reinterpret_cast<__half*>(dx_h)[offset] = __float2half_rn(dx); + reinterpret_cast<__half *>(dx_h)[offset] = __float2half_rn(dx); } - } // ii - } // wi - } // hi + } // ii + } // wi + } // hi - } // gi - } // ni + } // gi + } // ni // Store gamma/beta. - for( int ci = 0; ci < c; ++ci ) { + for (int ci = 0; ci < c; ++ci) { dgamma_h[ci] = dgamma[ci]; - dbeta_h [ci] = dbeta [ci]; + dbeta_h[ci] = dbeta[ci]; } // Release temporary memory. 
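// ----------------------------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the patch): the reference kernels around this hunk
// (group_norm_nhwc_fwd_ / group_norm_nhwc_bwd_) normalize each group using a mean and a
// mean-of-squares accumulated over H*W*channels_per_group elements, with
// var = E[x^2] - mean^2 and inv_stddev = 1/sqrt(var + epsilon). The sketch below shows that
// same arithmetic for a single group flattened to a 1-D buffer, with scalar gamma/beta for
// brevity (the real code uses per-channel gamma/beta and an optional swish). The function name
// group_norm_ref_1d is hypothetical and exists only for this illustration.
#include <cmath>
#include <vector>

static void group_norm_ref_1d(std::vector<float>& x, float gamma, float beta, float epsilon) {
  // Accumulate sum and sum of squares over the group.
  double sum = 0.0, sum_sq = 0.0;
  for (float v : x) {
    sum += v;
    sum_sq += double(v) * v;
  }
  // Mean, variance (E[x^2] - mean^2) and inverse standard deviation, as in the reference.
  const float mean = float(sum / x.size());
  const float var = float(sum_sq / x.size()) - mean * mean;
  const float inv_stddev = var <= 0.f ? 1.f : 1.f / std::sqrt(var + epsilon);
  // Normalize and apply the affine transform.
  for (float& v : x) {
    v = (v - mean) * inv_stddev * gamma + beta;
  }
}
// ----------------------------------------------------------------------------------------------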
@@ -330,59 +307,46 @@ static void group_norm_nhwc_bwd_(void *dx_h, //////////////////////////////////////////////////////////////////////////////////////////////////// -static void group_norm_nhwc_fwd_(void *y_h, - const void *x_h, - const float *gamma_h, - const float *beta_h, - float epsilon, - int n, - int h, - int w, - int c, - int groups, - bool with_swish, - bool use_fp32, +static void group_norm_nhwc_fwd_(void *y_h, const void *x_h, const float *gamma_h, const float *beta_h, float epsilon, + int n, int h, int w, int c, int groups, bool with_swish, bool use_fp32, bool use_bf16) { - // The number of channels in each group. int channels_per_group = c / groups; // The normalization term to compute the means. - float inv_hwcg = 1.f / (float) (h * w * channels_per_group); + float inv_hwcg = 1.f / (float)(h * w * channels_per_group); // Normalize the activations. - for( int ni = 0; ni < n; ++ni ) { - for( int gi = 0; gi < groups; ++gi ) { - + for (int ni = 0; ni < n; ++ni) { + for (int gi = 0; gi < groups; ++gi) { // The sums to compute the mean/variance for that group. float sum = 0.f, sum_sq = 0.f; // Iterate over the activations in the group. - for( int hi = 0; hi < h; ++hi ) { - for( int wi = 0; wi < w; ++wi ) { - for( int ii = 0; ii < channels_per_group; ++ii ) { - + for (int hi = 0; hi < h; ++hi) { + for (int wi = 0; wi < w; ++wi) { + for (int ii = 0; ii < channels_per_group; ++ii) { // The channel. int ci = gi * channels_per_group + ii; // Compute the src/dst offset. - size_t offset = (size_t) ni*h*w*c + (size_t) hi*w*c + (size_t) wi*c + (size_t) ci; + size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci; // Convert the element at that position to float. float x; if (use_fp32) { - x = reinterpret_cast(x_h)[offset]; + x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { - x = __bfloat162float(reinterpret_cast(x_h)[offset]); + x = __bfloat162float(reinterpret_cast(x_h)[offset]); } else { - x = __half2float(reinterpret_cast(x_h)[offset]); + x = __half2float(reinterpret_cast(x_h)[offset]); } // Update the sums. sum += x; sum_sq += x * x; - } // ii - } // wi - } // hi + } // ii + } // wi + } // hi // Compute the mean. float mean = sum * inv_hwcg; @@ -394,54 +358,53 @@ static void group_norm_nhwc_fwd_(void *y_h, float inv_stddev = var <= 0.f ? 1.f : (1.f / sqrtf(var + epsilon)); // Iterate over the data to normalize the output. - for( int hi = 0; hi < h; ++hi ) { - for( int wi = 0; wi < w; ++wi ) { - for( int ii = 0; ii < channels_per_group; ++ii ) { - + for (int hi = 0; hi < h; ++hi) { + for (int wi = 0; wi < w; ++wi) { + for (int ii = 0; ii < channels_per_group; ++ii) { // The channel. int ci = gi * channels_per_group + ii; // Compute the src/dst offset. - size_t offset = (size_t) ni*h*w*c + (size_t) hi*w*c + (size_t) wi*c + (size_t) ci; + size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci; // Normalize. float x; if (use_fp32) { - x = reinterpret_cast(x_h)[offset]; + x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { - x = __bfloat162float(reinterpret_cast(x_h)[offset]); + x = __bfloat162float(reinterpret_cast(x_h)[offset]); } else { - x = __half2float(reinterpret_cast(x_h)[offset]); + x = __half2float(reinterpret_cast(x_h)[offset]); } float y = (x - mean) * inv_stddev; // Scale with gamma and add beta. y = y * gamma_h[ci] + beta_h[ci]; // Apply swish (if needed). - if( with_swish ) { + if (with_swish) { y = y * sigmoid(y); } // Store the result. 
if (use_fp32) { - reinterpret_cast(y_h)[offset] = y; + reinterpret_cast(y_h)[offset] = y; } else if (use_bf16) { - reinterpret_cast<__nv_bfloat16*>(y_h)[offset] = __float2bfloat16_rn(y); + reinterpret_cast<__nv_bfloat16 *>(y_h)[offset] = __float2bfloat16_rn(y); } else { - reinterpret_cast<__half*>(y_h)[offset] = __float2half_rn(y); + reinterpret_cast<__half *>(y_h)[offset] = __float2half_rn(y); } - } // ii - } // wi - } // hi - } // gi - } // ni + } // ii + } // wi + } // hi + } // gi + } // ni } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void random_data(T *dst_h, size_t n, bool use_1s, int range = 3) { - for( size_t ii = 0; ii < n; ++ii ) { + for (size_t ii = 0; ii < n; ++ii) { float x = 1.f; - if( !use_1s ) { - x = (float) (rand() % range - (range / 2)); + if (!use_1s) { + x = (float)(rand() % range - (range / 2)); } if (std::is_same::value) { dst_h[ii] = __float2half_rn(x); @@ -469,7 +432,6 @@ enum class Mode { FWD_INFERENCE, FWD_TRAINING, BWD }; //////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char **argv) { - // The tensor size. int n = 2, h = 64, w = 64, c = 320, groups = 32; // The default mode is inference. @@ -496,48 +458,48 @@ int main(int argc, char **argv) { bool use_bf16 = false; // Parse the parameters. - for( int ii = 1; ii < argc; ++ii ) { - if( !strcmp(argv[ii], "-1s") ) { + for (int ii = 1; ii < argc; ++ii) { + if (!strcmp(argv[ii], "-1s")) { use_1s = true; - } else if( !strcmp(argv[ii], "-bwd") ) { + } else if (!strcmp(argv[ii], "-bwd")) { mode = Mode::BWD; - } else if( !strcmp(argv[ii], "-c") && ++ii < argc ) { + } else if (!strcmp(argv[ii], "-c") && ++ii < argc) { c = strtol(argv[ii], nullptr, 10); - } else if( !strcmp(argv[ii], "-epsilon") && ++ii < argc ) { - epsilon = (float) strtod(argv[ii], nullptr); - } else if( !strcmp(argv[ii], "-fwd") ) { + } else if (!strcmp(argv[ii], "-epsilon") && ++ii < argc) { + epsilon = (float)strtod(argv[ii], nullptr); + } else if (!strcmp(argv[ii], "-fwd")) { mode = Mode::FWD_INFERENCE; - } else if( !strcmp(argv[ii], "-fwd-tr") ) { + } else if (!strcmp(argv[ii], "-fwd-tr")) { mode = Mode::FWD_TRAINING; - } else if( !strcmp(argv[ii], "-groups") && ++ii < argc ) { + } else if (!strcmp(argv[ii], "-groups") && ++ii < argc) { groups = strtol(argv[ii], nullptr, 10); - } else if( !strcmp(argv[ii], "-h") && ++ii < argc ) { + } else if (!strcmp(argv[ii], "-h") && ++ii < argc) { h = strtol(argv[ii], nullptr, 10); - } else if( !strcmp(argv[ii], "-n") && ++ii < argc ) { + } else if (!strcmp(argv[ii], "-n") && ++ii < argc) { n = strtol(argv[ii], nullptr, 10); - } else if( !strcmp(argv[ii], "-one-pass") ) { + } else if (!strcmp(argv[ii], "-one-pass")) { use_one_pass = true; - } else if( !strcmp(argv[ii], "-runs") && ++ii < argc ) { + } else if (!strcmp(argv[ii], "-runs") && ++ii < argc) { runs = strtol(argv[ii], nullptr, 10); - } else if( !strcmp(argv[ii], "-skip-checks") ) { + } else if (!strcmp(argv[ii], "-skip-checks")) { skip_checks = true; - } else if( !strcmp(argv[ii], "-tol") && ++ii < argc ) { - tol = (float) strtod(argv[ii], nullptr); - } else if( !strcmp(argv[ii], "-w") && ++ii < argc ) { + } else if (!strcmp(argv[ii], "-tol") && ++ii < argc) { + tol = (float)strtod(argv[ii], nullptr); + } else if (!strcmp(argv[ii], "-w") && ++ii < argc) { w = strtol(argv[ii], nullptr, 10); - } else if( !strcmp(argv[ii], "-with-swish") ) { + } else if (!strcmp(argv[ii], "-with-swish")) { with_swish = true; - 
} else if( !strcmp(argv[ii], "-csv") ) { + } else if (!strcmp(argv[ii], "-csv")) { csv_output = true; - } else if( !strcmp(argv[ii], "-fp32") ) { + } else if (!strcmp(argv[ii], "-fp32")) { use_fp32 = true; - } else if( !strcmp(argv[ii], "-bf16") ) { + } else if (!strcmp(argv[ii], "-bf16")) { use_bf16 = true; - } else if( ii < argc ) { + } else if (ii < argc) { fprintf(stderr, "Unknown argument: %s\n", argv[ii]); return 1; } else { - fprintf(stderr, "Argument %s requires a value\n", argv[ii-1]); + fprintf(stderr, "Argument %s requires a value\n", argv[ii - 1]); return 1; } } @@ -586,11 +548,11 @@ int main(int argc, char **argv) { printf("epsilon......................: %f\n", epsilon); printf("with swish...................: %s\n", with_swish ? "true" : "false"); printf("channels per group...........: %d\n", c / groups); - if( mode == Mode::BWD ) { + if (mode == Mode::BWD) { printf("mode.........................: bwd\n"); - } else if( mode == Mode::FWD_INFERENCE ) { + } else if (mode == Mode::FWD_INFERENCE) { printf("mode.........................: fwd inference\n"); - } else if( mode == Mode::FWD_TRAINING ) { + } else if (mode == Mode::FWD_TRAINING) { printf("mode.........................: fwd training\n"); } else { assert(false); @@ -601,16 +563,16 @@ int main(int argc, char **argv) { // Compute the SOL. double bytes = 0; int32_t io_bytes = use_fp32 ? sizeof(float) : sizeof(__half); - if( mode != Mode::BWD ) { - bytes = (double) n * h * w * c * io_bytes + // src - (double) c * 4 + // gamma - (double) c * 4 + // beta - (double) n * h * w * c * io_bytes; // out + if (mode != Mode::BWD) { + bytes = (double)n * h * w * c * io_bytes + // src + (double)c * 4 + // gamma + (double)c * 4 + // beta + (double)n * h * w * c * io_bytes; // out } else { - bytes = (double) n * h * w * c * io_bytes * 2 + // src, dsrc - (double) c * 4 * 2 + // gamma, dgamma - (double) c * 4 * 2 + // beta, dbeta - (double) n * h * w * c * io_bytes * 1; // dout + bytes = (double)n * h * w * c * io_bytes * 2 + // src, dsrc + (double)c * 4 * 2 + // gamma, dgamma + (double)c * 4 * 2 + // beta, dbeta + (double)n * h * w * c * io_bytes * 1; // dout } double gbytes = bytes * 1.e-9; double dram_sol = gbytes / dram_peak * 1.e3; @@ -624,7 +586,7 @@ int main(int argc, char **argv) { } // The number of elements in the x tensor. The layout is N x H x W x C. - size_t x_elts = (size_t) n * h * w * c; + size_t x_elts = (size_t)n * h * w * c; // The size of the src in bytes. size_t x_sz = x_elts * io_bytes; @@ -634,57 +596,57 @@ int main(int argc, char **argv) { // Allocate src/dst on the device. void *x_d, *y_d; - CHECK_CUDA(cudaMalloc((void**) &x_d, x_sz)); - CHECK_CUDA(cudaMalloc((void**) &y_d, x_sz)); + CHECK_CUDA(cudaMalloc((void **)&x_d, x_sz)); + CHECK_CUDA(cudaMalloc((void **)&y_d, x_sz)); // The number of elements in the gamma/beta array. - size_t gamma_elts = (size_t) c; + size_t gamma_elts = (size_t)c; // The size of the gamma/beta array in bytes. size_t gamma_sz = gamma_elts * sizeof(float); // Allocate gamma/beta on the host. - float *gamma_h = (float*) malloc(gamma_sz); + float *gamma_h = (float *)malloc(gamma_sz); // Allocate gamma/beta on the device. float *gamma_d; - CHECK_CUDA(cudaMalloc((void**) &gamma_d, gamma_sz)); + CHECK_CUDA(cudaMalloc((void **)&gamma_d, gamma_sz)); // Allocate gamma/beta on the host. - float *beta_h = (float*) malloc(gamma_sz); + float *beta_h = (float *)malloc(gamma_sz); // Allocate gamma/beta on the device. 
float *beta_d; - CHECK_CUDA(cudaMalloc((void**) &beta_d, gamma_sz)); + CHECK_CUDA(cudaMalloc((void **)&beta_d, gamma_sz)); // Allocate the reference on the host (to be computed on the host). void *y_ref_h = nullptr; - if( !skip_checks ) { + if (!skip_checks) { y_ref_h = malloc(x_sz); } // Allocate the src/dst on the host for the gradients (bwd). void *dx_h = nullptr, *dy_h = nullptr; - if( mode == Mode::BWD ) { + if (mode == Mode::BWD) { dx_h = malloc(x_sz); dy_h = malloc(x_sz); } // Allocate src/dst on the device. void *dx_d = nullptr, *dy_d = nullptr; - if( mode == Mode::BWD ) { - CHECK_CUDA(cudaMalloc((void**) &dx_d, x_sz)); - CHECK_CUDA(cudaMalloc((void**) &dy_d, x_sz)); + if (mode == Mode::BWD) { + CHECK_CUDA(cudaMalloc((void **)&dx_d, x_sz)); + CHECK_CUDA(cudaMalloc((void **)&dy_d, x_sz)); } // The gradients for gamma and beta on the host. float *dgamma_h = nullptr, *dbeta_h = nullptr; - if( mode == Mode::BWD ) { - dgamma_h = (float*) malloc(gamma_sz); - dbeta_h = (float*) malloc(gamma_sz); + if (mode == Mode::BWD) { + dgamma_h = (float *)malloc(gamma_sz); + dbeta_h = (float *)malloc(gamma_sz); } // The gradients for gamma and beta on the device. float *dgamma_d = nullptr, *dbeta_d = nullptr; - if( mode == Mode::BWD ) { - CHECK_CUDA(cudaMalloc((void**) &dgamma_d, gamma_sz)); - CHECK_CUDA(cudaMalloc((void**) &dbeta_d, gamma_sz)); + if (mode == Mode::BWD) { + CHECK_CUDA(cudaMalloc((void **)&dgamma_d, gamma_sz)); + CHECK_CUDA(cudaMalloc((void **)&dbeta_d, gamma_sz)); } // The number of sums for the bwd pass. @@ -694,81 +656,81 @@ int main(int argc, char **argv) { // The sums for the bwd pass on the host. float2 *sums_h = nullptr; - if( sums_sz > 0 ) { - sums_h = (float2*) malloc(sums_sz); + if (sums_sz > 0) { + sums_h = (float2 *)malloc(sums_sz); } // The sums for the bwd pass on the device. float2 *sums_d = nullptr; - if( sums_sz > 0 ) { - CHECK_CUDA(cudaMalloc((void**) &sums_d, sums_sz)); + if (sums_sz > 0) { + CHECK_CUDA(cudaMalloc((void **)&sums_d, sums_sz)); } // Allocate the reference on the host (to be computed on the host). void *dx_ref_h = nullptr; - if( mode == Mode::BWD && !skip_checks ) { + if (mode == Mode::BWD && !skip_checks) { dx_ref_h = malloc(x_sz); } // Allocate the reference on the host (to be computed on the host). float *dgamma_ref_h = nullptr, *dbeta_ref_h = nullptr; - if( mode == Mode::BWD && !skip_checks ) { - dgamma_ref_h = (float*) malloc(gamma_sz); - dbeta_ref_h = (float*) malloc(gamma_sz); + if (mode == Mode::BWD && !skip_checks) { + dgamma_ref_h = (float *)malloc(gamma_sz); + dbeta_ref_h = (float *)malloc(gamma_sz); } // Generate random input data for the forward pass. if (use_fp32) { - random_data(reinterpret_cast(x_h), x_elts, use_1s); + random_data(reinterpret_cast(x_h), x_elts, use_1s); } else if (use_bf16) { - random_data<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16*>(x_h), x_elts, use_1s); + random_data<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 *>(x_h), x_elts, use_1s); } else { - random_data<__half>(reinterpret_cast<__half*>(x_h), x_elts, use_1s); + random_data<__half>(reinterpret_cast<__half *>(x_h), x_elts, use_1s); } - random_data (gamma_h, gamma_elts, use_1s); - random_data (beta_h, gamma_elts, use_1s); + random_data(gamma_h, gamma_elts, use_1s); + random_data(beta_h, gamma_elts, use_1s); // Generate the gradients for the bwd pass. 
- if( mode == Mode::BWD ) { + if (mode == Mode::BWD) { if (use_fp32) { - random_data(reinterpret_cast(dy_h), x_elts, use_1s); + random_data(reinterpret_cast(dy_h), x_elts, use_1s); } else if (use_bf16) { - random_data<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16*>(dy_h), x_elts, use_1s); + random_data<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 *>(dy_h), x_elts, use_1s); } else { - random_data<__half>(reinterpret_cast<__half*>(dy_h), x_elts, use_1s); + random_data<__half>(reinterpret_cast<__half *>(dy_h), x_elts, use_1s); } } // Precompute the sums (from the fwd pass) for bwd. - if( mode == Mode::BWD ) { + if (mode == Mode::BWD) { // Clear the array of sums (all the elements are set to 0.f). memset(sums_h, 0, sums_sz); // The number of channels in each group. int channels_per_group = c / groups; // Iterate over the different groups. - for( int ni = 0; ni < n; ++ni ) { - for( int gi = 0; gi < groups; ++gi ) { - for( int hi = 0; hi < h; ++hi ) { - for( int wi = 0; wi < w; ++wi ) { - for( int ii = 0; ii < channels_per_group; ++ii ) { + for (int ni = 0; ni < n; ++ni) { + for (int gi = 0; gi < groups; ++gi) { + for (int hi = 0; hi < h; ++hi) { + for (int wi = 0; wi < w; ++wi) { + for (int ii = 0; ii < channels_per_group; ++ii) { // The position of the channel. - int ci = gi*channels_per_group + ii; + int ci = gi * channels_per_group + ii; // The offset to the element. - int64_t offset = (int64_t) ni*h*w*c + hi*w*c + wi*c + ci; + int64_t offset = (int64_t)ni * h * w * c + hi * w * c + wi * c + ci; // The element in float. float x; if (use_fp32) { - x = reinterpret_cast(x_h)[offset]; + x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { - x = __bfloat162float(reinterpret_cast<__nv_bfloat16*>(x_h)[offset]); + x = __bfloat162float(reinterpret_cast<__nv_bfloat16 *>(x_h)[offset]); } else { - x = __half2float(reinterpret_cast<__half*>(x_h)[offset]); + x = __half2float(reinterpret_cast<__half *>(x_h)[offset]); } // Update the sums (sum of X and sum of squares). - sums_h[ni*groups + gi].x += x; - sums_h[ni*groups + gi].y += x * x; + sums_h[ni * groups + gi].x += x; + sums_h[ni * groups + gi].y += x * x; } } } @@ -776,9 +738,9 @@ int main(int argc, char **argv) { } // The normalization term to compute the means. - float rcp_hwc_per_group = 1.f / (float) (h * w * channels_per_group); + float rcp_hwc_per_group = 1.f / (float)(h * w * channels_per_group); // Normalize the sums. - for( int ngi = 0; ngi < n * groups; ++ngi ) { + for (int ngi = 0; ngi < n * groups; ++ngi) { sums_h[ngi].x *= rcp_hwc_per_group; sums_h[ngi].y *= rcp_hwc_per_group; } @@ -786,24 +748,9 @@ int main(int argc, char **argv) { // Compute the golden reference on the host. if (!skip_checks) { - if( mode == Mode::BWD ) { - group_norm_nhwc_bwd_(dx_ref_h, - dgamma_ref_h, - dbeta_ref_h, - dy_h, - x_h, - gamma_h, - beta_h, - sums_h, - epsilon, - n, - h, - w, - c, - groups, - with_swish, - use_fp32, - use_bf16); + if (mode == Mode::BWD) { + group_norm_nhwc_bwd_(dx_ref_h, dgamma_ref_h, dbeta_ref_h, dy_h, x_h, gamma_h, beta_h, sums_h, epsilon, n, h, w, c, + groups, with_swish, use_fp32, use_bf16); } else { group_norm_nhwc_fwd_(y_ref_h, x_h, gamma_h, beta_h, epsilon, n, h, w, c, groups, with_swish, use_fp32, use_bf16); } @@ -811,42 +758,26 @@ int main(int argc, char **argv) { // Copy to the device. 
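  // (All copies below are cudaMemcpyAsync on the default stream; the host buffers come from
  //  plain malloc, i.e. pageable memory, so the transfers are typically staged through a pinned
  //  bounce buffer. That only affects setup cost: the timed region is bracketed by CUDA events
  //  further down.)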
CHECK_CUDA(cudaMemcpyAsync(x_d, x_h, x_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); - CHECK_CUDA(cudaMemcpyAsync(gamma_d, - gamma_h, - gamma_sz, - cudaMemcpyHostToDevice, - cudaStreamDefault)); - CHECK_CUDA(cudaMemcpyAsync(beta_d, - beta_h, - gamma_sz, - cudaMemcpyHostToDevice, - cudaStreamDefault)); - - if( mode == Mode::BWD ) { - CHECK_CUDA(cudaMemcpyAsync(dy_d, - dy_h, - x_sz, - cudaMemcpyHostToDevice, - cudaStreamDefault)); + CHECK_CUDA(cudaMemcpyAsync(gamma_d, gamma_h, gamma_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); + CHECK_CUDA(cudaMemcpyAsync(beta_d, beta_h, gamma_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); + + if (mode == Mode::BWD) { + CHECK_CUDA(cudaMemcpyAsync(dy_d, dy_h, x_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); // // DEBUG. // printf("sums_h[0] = %8.3f, %8.3f\n", sums_h[0].x, sums_h[0].y); // // END OF DEBUG. - CHECK_CUDA(cudaMemcpyAsync(sums_d, - sums_h, - sums_sz, - cudaMemcpyHostToDevice, - cudaStreamDefault)); + CHECK_CUDA(cudaMemcpyAsync(sums_d, sums_h, sums_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); } // Reset the output buffer with garbage to detect invalid results. - if( mode == Mode::BWD ) { - CHECK_CUDA(cudaMemsetAsync(dx_d, 0xdc, x_sz, cudaStreamDefault)); + if (mode == Mode::BWD) { + CHECK_CUDA(cudaMemsetAsync(dx_d, 0xdc, x_sz, cudaStreamDefault)); CHECK_CUDA(cudaMemsetAsync(dgamma_d, 0xdc, gamma_sz, cudaStreamDefault)); - CHECK_CUDA(cudaMemsetAsync(dbeta_d, 0xdc, gamma_sz, cudaStreamDefault)); + CHECK_CUDA(cudaMemsetAsync(dbeta_d, 0xdc, gamma_sz, cudaStreamDefault)); } else { - CHECK_CUDA(cudaMemsetAsync(y_d, 0xdc, x_sz, cudaStreamDefault)); + CHECK_CUDA(cudaMemsetAsync(y_d, 0xdc, x_sz, cudaStreamDefault)); } // Declare the parameters. @@ -866,7 +797,7 @@ int main(int argc, char **argv) { }(); // Initialize the parameters. - if( mode == Mode::BWD ) { + if (mode == Mode::BWD) { params_bwd.dx = dx_d; params_bwd.dgamma = dgamma_d; params_bwd.dbeta = dbeta_d; @@ -908,33 +839,22 @@ int main(int argc, char **argv) { // Finalize the parameters. dim3 grid; - if( mode == Mode::BWD && use_one_pass ) { - group_norm_nhwc_bwd_one_pass_setup(params_bwd, - barriers_elts, - red_buffer_elts, - zeroed_red_buffer_elts, - grid, - props); - } else if( mode == Mode::BWD ) { - group_norm_nhwc_bwd_two_passes_setup(params_bwd, - zeroed_red_buffer_elts); - } else if( use_one_pass ) { - group_norm_nhwc_fwd_one_pass_setup(params_fwd, - barriers_elts, - red_buffer_elts, - grid, - props); + if (mode == Mode::BWD && use_one_pass) { + group_norm_nhwc_bwd_one_pass_setup(params_bwd, barriers_elts, red_buffer_elts, zeroed_red_buffer_elts, grid, props); + } else if (mode == Mode::BWD) { + group_norm_nhwc_bwd_two_passes_setup(params_bwd, zeroed_red_buffer_elts); + } else if (use_one_pass) { + group_norm_nhwc_fwd_one_pass_setup(params_fwd, barriers_elts, red_buffer_elts, grid, props); } else { - group_norm_nhwc_fwd_two_passes_setup(params_fwd, - zeroed_red_buffer_elts); + group_norm_nhwc_fwd_two_passes_setup(params_fwd, zeroed_red_buffer_elts); } // The size in bytes for the reduction buffer. size_t red_buffer_sz = red_buffer_elts * sizeof(float); // Allocate on the device. - if( red_buffer_sz > 0 ) { + if (red_buffer_sz > 0) { float **ptr = mode == Mode::BWD ? ¶ms_bwd.red_buffer : ¶ms_fwd.red_buffer; - CHECK_CUDA(cudaMalloc((void**) ptr, red_buffer_sz)); + CHECK_CUDA(cudaMalloc((void **)ptr, red_buffer_sz)); } // The size of the array of barriers. @@ -944,19 +864,19 @@ int main(int argc, char **argv) { // Allocate the buffer if needed. 
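  // (One allocation backs two regions: the int barriers used for inter-block synchronization
  //  come first and the float "zeroed" reduction buffer starts immediately after them; see the
  //  two reinterpret_casts below. The whole allocation is cleared before every timed run, and
  //  the two derived pointers must never be cudaFree'd individually.)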
void *zeroed_red_buffer_d_ = nullptr; - if( zeroed_red_buffer_sz > 0 ) { - CHECK_CUDA(cudaMalloc((void**) &zeroed_red_buffer_d_, zeroed_red_buffer_sz)); + if (zeroed_red_buffer_sz > 0) { + CHECK_CUDA(cudaMalloc((void **)&zeroed_red_buffer_d_, zeroed_red_buffer_sz)); } // The buffer of barriers. DO NOT CALL cudaFree on it!!! - int *barriers_d = reinterpret_cast(zeroed_red_buffer_d_); + int *barriers_d = reinterpret_cast(zeroed_red_buffer_d_); // The zeroed red buffer. DO NOT CALL cudaFree on it!!! - float *zeroed_red_buffer_d = reinterpret_cast(&barriers_d[barriers_elts]); + float *zeroed_red_buffer_d = reinterpret_cast(&barriers_d[barriers_elts]); // Must be aligned on 4B for floats. It obviously is (unless someone changes the code ;)). - assert(reinterpret_cast(zeroed_red_buffer_d) % sizeof(float) == 0); + assert(reinterpret_cast(zeroed_red_buffer_d) % sizeof(float) == 0); // Set the barriers if needed. - if( mode == Mode::BWD ) { + if (mode == Mode::BWD) { params_bwd.barriers = barriers_d; params_bwd.zeroed_red_buffer = zeroed_red_buffer_d; } else { @@ -971,24 +891,20 @@ int main(int argc, char **argv) { // Time the reference code. CHECK_CUDA(cudaEventRecord(start)); - for( int ii = 0; ii < runs; ++ii ) { - + for (int ii = 0; ii < runs; ++ii) { // Clear the zeroed buffer if needed. - if( zeroed_red_buffer_sz > 0 ) { - CHECK_CUDA(cudaMemsetAsync(zeroed_red_buffer_d_, - 0, - zeroed_red_buffer_sz, - cudaStreamDefault)); + if (zeroed_red_buffer_sz > 0) { + CHECK_CUDA(cudaMemsetAsync(zeroed_red_buffer_d_, 0, zeroed_red_buffer_sz, cudaStreamDefault)); } - if( use_one_pass && mode == Mode::BWD ) { + if (use_one_pass && mode == Mode::BWD) { group_norm_nhwc_bwd_one_pass_run(params_bwd, grid, cudaStreamDefault); - } else if( use_one_pass ) { + } else if (use_one_pass) { group_norm_nhwc_fwd_one_pass_run(params_fwd, grid, cudaStreamDefault); - } else if( mode == Mode::BWD ) { - group_norm_nhwc_bwd_two_passes_sum (params_bwd, cudaStreamDefault); + } else if (mode == Mode::BWD) { + group_norm_nhwc_bwd_two_passes_sum(params_bwd, cudaStreamDefault); group_norm_nhwc_bwd_two_passes_scale(params_bwd, cudaStreamDefault); } else { - group_norm_nhwc_fwd_two_passes_sum (params_fwd, cudaStreamDefault); + group_norm_nhwc_fwd_two_passes_sum(params_fwd, cudaStreamDefault); group_norm_nhwc_fwd_two_passes_scale(params_fwd, cudaStreamDefault); } } @@ -1000,24 +916,16 @@ int main(int argc, char **argv) { CHECK_CUDA(cudaEventElapsedTime(&elapsed, start, stop)); if (!csv_output) { printf("elapsed......................: %.3fms\n", elapsed); - printf("elapsed per run..............: %.3fms\n", elapsed / (float) runs); + printf("elapsed per run..............: %.3fms\n", elapsed / (float)runs); printf("efficiency...................: %.3lf%%\n", dram_sol * runs / elapsed * 100.0); printf("\n"); } // Copy the results to the host. 
- if( mode == Mode::BWD ) { + if (mode == Mode::BWD) { CHECK_CUDA(cudaMemcpyAsync(dx_h, dx_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault)); - CHECK_CUDA(cudaMemcpyAsync(dgamma_h, - dgamma_d, - gamma_sz, - cudaMemcpyDeviceToHost, - cudaStreamDefault)); - CHECK_CUDA(cudaMemcpyAsync(dbeta_h, - dbeta_d, - gamma_sz, - cudaMemcpyDeviceToHost, - cudaStreamDefault)); + CHECK_CUDA(cudaMemcpyAsync(dgamma_h, dgamma_d, gamma_sz, cudaMemcpyDeviceToHost, cudaStreamDefault)); + CHECK_CUDA(cudaMemcpyAsync(dbeta_h, dbeta_d, gamma_sz, cudaMemcpyDeviceToHost, cudaStreamDefault)); } else { CHECK_CUDA(cudaMemcpyAsync(y_h, y_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault)); } @@ -1027,33 +935,31 @@ int main(int argc, char **argv) { // Check the results. if (!csv_output) { - if( mode == Mode::BWD && !skip_checks ) { + if (mode == Mode::BWD && !skip_checks) { if (use_fp32) { - check_results("dx", reinterpret_cast(dx_h), - reinterpret_cast(dx_ref_h), x_elts, tol); + check_results("dx", reinterpret_cast(dx_h), reinterpret_cast(dx_ref_h), x_elts, tol); } else if (use_bf16) { - check_results<__nv_bfloat16>("dx", reinterpret_cast<__nv_bfloat16*>(dx_h), - reinterpret_cast<__nv_bfloat16*>(dx_ref_h), x_elts, tol); + check_results<__nv_bfloat16>("dx", reinterpret_cast<__nv_bfloat16 *>(dx_h), + reinterpret_cast<__nv_bfloat16 *>(dx_ref_h), x_elts, tol); } else { - check_results<__half>("dx", reinterpret_cast<__half*>(dx_h), - reinterpret_cast<__half*>(dx_ref_h), x_elts, tol); + check_results<__half>("dx", reinterpret_cast<__half *>(dx_h), reinterpret_cast<__half *>(dx_ref_h), x_elts, + tol); } - check_results ("dgamma", dgamma_h, dgamma_ref_h, gamma_elts, tol); - check_results ("dbeta", dbeta_h, dbeta_ref_h, gamma_elts, tol); - } else if( !skip_checks ) { + check_results("dgamma", dgamma_h, dgamma_ref_h, gamma_elts, tol); + check_results("dbeta", dbeta_h, dbeta_ref_h, gamma_elts, tol); + } else if (!skip_checks) { if (use_fp32) { - check_results("y", reinterpret_cast(y_h), - reinterpret_cast(y_ref_h), x_elts, tol); + check_results("y", reinterpret_cast(y_h), reinterpret_cast(y_ref_h), x_elts, tol); } else if (use_bf16) { - check_results<__nv_bfloat16>("y", reinterpret_cast<__nv_bfloat16*>(y_h), - reinterpret_cast<__nv_bfloat16*>(y_ref_h), x_elts, tol); + check_results<__nv_bfloat16>("y", reinterpret_cast<__nv_bfloat16 *>(y_h), + reinterpret_cast<__nv_bfloat16 *>(y_ref_h), x_elts, tol); } else { - check_results<__half>("y", reinterpret_cast<__half*>(y_h), - reinterpret_cast<__half*>(y_ref_h), x_elts, tol); + check_results<__half>("y", reinterpret_cast<__half *>(y_h), reinterpret_cast<__half *>(y_ref_h), x_elts, tol); } } } else { - printf("%d,%d,%d,%d,%d,%d,%d,%f\n", n, h, w, c, groups, (uint32_t)use_one_pass, (uint32_t)mode, elapsed / (float) runs); + printf("%d,%d,%d,%d,%d,%d,%d,%f\n", n, h, w, c, groups, (uint32_t)use_one_pass, (uint32_t)mode, + elapsed / (float)runs); } // Destroy the cuda events. 
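A note on the efficiency figure printed by the benchmark above: the speed-of-light time is simply the estimated DRAM traffic divided by peak bandwidth, and the reported efficiency is dram_sol * runs / elapsed, expressed as a percentage. The short standalone program below reproduces the fwd-mode byte accounting; the problem size and the dram_peak value are hypothetical placeholders chosen only to make the arithmetic concrete.

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical problem size and peak bandwidth, for illustration only.
  int64_t n = 2, h = 64, w = 64, c = 320;
  double io_bytes = 2.0;      // fp16/bf16 elements
  double dram_peak = 2000.0;  // GB/s, hypothetical device peak

  // Same accounting as the fwd branch of the benchmark: read src, read gamma/beta, write out.
  double bytes = (double)n * h * w * c * io_bytes     // src
                 + (double)c * 4                      // gamma
                 + (double)c * 4                      // beta
                 + (double)n * h * w * c * io_bytes;  // out
  double gbytes = bytes * 1.e-9;
  double dram_sol_ms = gbytes / dram_peak * 1.e3;

  // With these numbers: ~0.0105 GB of traffic and a speed-of-light time of ~5.2 microseconds,
  // so an elapsed-per-run close to that figure would correspond to near-100% efficiency.
  printf("traffic: %.4f GB, SOL: %.6f ms\n", gbytes, dram_sol_ms);
  return 0;
}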
diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc.h b/apex/contrib/csrc/group_norm/group_norm_nhwc.h old mode 100755 new mode 100644 index eb33c7a83..0b85b3d8c --- a/apex/contrib/csrc/group_norm/group_norm_nhwc.h +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc.h @@ -4,47 +4,43 @@ */ #pragma once +#include +#include +#include #include #include #include #include -#include -#include -#include //////////////////////////////////////////////////////////////////////////////////////////////////// -#define CHECK_CUDA(call) do { \ - cudaError_t status_ = call; \ - if( status_ != cudaSuccess ) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ -} while(0) +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ __host__ int div_up(int m, int n) { - return (m + n-1) / n; -} +static inline __device__ __host__ int div_up(int m, int n) { return (m + n - 1) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ __host__ float sigmoid(float x) { - return 1.f / (1.f + expf(-x)); -} +static inline __device__ __host__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ void spin_wait_(int *barrier, int step, int expected) { - // THE FOLLOWING CODE MUST BE EXECUTED BY A SINGLE THREAD IN THE CTA. // Update the global counter. Make sure prior writes are visible. - asm volatile("red.release.gpu.global.add.s32 [%0], %1;" :: "l"(barrier), "r"(step)); + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); // Busy wait. We could use found = old + step with old = atomicAdd(...) but it's not faster. - for( volatile int found = -1; found != expected; ) { + for (volatile int found = -1; found != expected;) { asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); } } @@ -80,9 +76,9 @@ struct Group_sums { struct Group_sums_op { inline __device__ Group_sums operator()(const Group_sums &a, const Group_sums &b) { Group_sums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sum = b.flag ? b.sum : (a.sum + b.sum); dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); - dst.flag = a.flag + b.flag; + dst.flag = a.flag + b.flag; return dst; } }; @@ -90,7 +86,6 @@ struct Group_sums_op { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Group_norm_nhwc_fwd_params { - // The output buffer. Layout NHWC. void *y; // The sums for the bwd pass. Not written if it is a nullptr. 
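Group_sums_op above is the combine operator for a segmented inclusive scan: the flag marks the first channel pair of a group, and whenever the right-hand operand carries that flag the running sums restart instead of accumulating across the group boundary. The two-pass sum kernels later in this patch feed it to a block-wide InclusiveScan; the host-side sketch below uses a plain sequential scan and made-up inputs purely to illustrate the semantics (struct layout and values are illustrative, not taken from the header).

#include <cstdio>
#include <vector>

struct Group_sums { int flag; float sum, sum_sq; };

// Same combine as Group_sums_op above.
static Group_sums combine(const Group_sums &a, const Group_sums &b) {
  Group_sums dst;
  dst.sum    = b.flag ? b.sum    : (a.sum    + b.sum);
  dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);
  dst.flag   = a.flag + b.flag;
  return dst;
}

int main() {
  // Two groups of two channel pairs each; flag == 1 marks the first pair of a group.
  std::vector<Group_sums> in = {
      {1, 1.f, 1.f}, {0, 2.f, 4.f},   // group 0: x = 1, 2
      {1, 3.f, 9.f}, {0, 4.f, 16.f},  // group 1: x = 3, 4
  };
  // Sequential inclusive scan standing in for the block-wide InclusiveScan.
  std::vector<Group_sums> out(in.size());
  out[0] = in[0];
  for (size_t i = 1; i < in.size(); ++i) out[i] = combine(out[i - 1], in[i]);

  // The last element of each group now holds that group's totals: (3, 5) and (7, 25).
  for (const auto &g : out) printf("flag=%d sum=%4.1f sum_sq=%5.1f\n", g.flag, g.sum, g.sum_sq);
  return 0;
}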
@@ -137,21 +132,19 @@ struct Group_norm_nhwc_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&, - size_t &red_buffer_elts); +void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params &, size_t &red_buffer_elts); //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_fwd_two_passes_sum (const Group_norm_nhwc_fwd_params&, cudaStream_t); +void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params &, cudaStream_t); //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_fwd_two_passes_scale(const Group_norm_nhwc_fwd_params&, cudaStream_t); +void group_norm_nhwc_fwd_two_passes_scale(const Group_norm_nhwc_fwd_params &, cudaStream_t); //////////////////////////////////////////////////////////////////////////////////////////////////// struct Group_norm_nhwc_bwd_params { - // The output buffer. Layout NHWC. void *dx; // The output buffer. Layout NHWC. @@ -204,15 +197,14 @@ struct Group_norm_nhwc_bwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params&, - size_t &red_buffer_elts); +void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params &, size_t &red_buffer_elts); //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_bwd_two_passes_sum (const Group_norm_nhwc_bwd_params&, cudaStream_t); +void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params &, cudaStream_t); //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params&, cudaStream_t); +void group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params &, cudaStream_t); //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass.h b/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass.h old mode 100755 new mode 100644 index 5ce4e3ae5..9db812814 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass.h +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass.h @@ -2,11 +2,13 @@ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ +#include + +#include + #include "group_norm_nhwc.h" #include "macros.h" #include "traits.h" -#include -#include //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -14,71 +16,68 @@ // //////////////////////////////////////////////////////////////////////////////////////////////////// -#define GN_BWD_SELECT(FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(4, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(8, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(10, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(12, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(14, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(16, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(20, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(26, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(24, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(28, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(30, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(32, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(40, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(42, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(48, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(56, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(60, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(64, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(70, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(80, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(84, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(96, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(98, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(112, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(120, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(128, FUNC_POSTFIX, function) \ - GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(160, FUNC_POSTFIX, function) \ - { \ - assert(false && "Not implemented"); \ +#define GN_BWD_SELECT(FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(4, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(8, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(10, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(12, FUNC_POSTFIX, function) \ + 
GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(14, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(16, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(20, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(26, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(24, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(28, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(30, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(32, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(40, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(42, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(48, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(56, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(60, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(64, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(70, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(80, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(84, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(96, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(98, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(112, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(120, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(128, FUNC_POSTFIX, function) \ + GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(160, FUNC_POSTFIX, function) { \ + assert(false && "Not implemented"); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// -#define GN_BWD_RUNNER_SELECT(function) \ - GN_BWD_SELECT(_run, function) +#define GN_BWD_RUNNER_SELECT(function) GN_BWD_SELECT(_run, function) -#define GN_BWD_BLOCKS_PER_SM_SELECT(function) \ - GN_BWD_SELECT(_blocks_per_sm, function) +#define GN_BWD_BLOCKS_PER_SM_SELECT(function) GN_BWD_SELECT(_blocks_per_sm, function) //////////////////////////////////////////////////////////////////////////////////////////////////// -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 4) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 8) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 10) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 12) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 14) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 16) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 20) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 26) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 24) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 28) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 30) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 32) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 40) 
-GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 42) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 48) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 56) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 60) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 64) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 70) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 80) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 84) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 96) -GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 98) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 4) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 8) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 10) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 12) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 14) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 16) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 20) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 26) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 24) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 28) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 30) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 32) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 40) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 42) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 48) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 56) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 60) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 64) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 70) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 80) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 84) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 96) +GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 98) GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 112) GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 120) GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 128) @@ -86,34 +85,29 @@ GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 160) //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_bwd_one_pass_setup(Group_norm_nhwc_bwd_params ¶ms, - size_t &barriers_elts, - size_t &red_buffer_elts, - size_t &zeroed_red_buffer_elts, - dim3 &grid, +void group_norm_nhwc_bwd_one_pass_setup(Group_norm_nhwc_bwd_params ¶ms, size_t &barriers_elts, + size_t &red_buffer_elts, size_t &zeroed_red_buffer_elts, dim3 &grid, const cudaDeviceProp &props) { - // The pre-computed dimensions. - params.hw = params.h * params.w; + params.hw = params.h * params.w; params.hwc = params.c * params.hw; // The number of channels per group. params.channels_per_group = params.c / params.groups; // The inverse to compute the mean/variance. - params.inv_hwc_per_group = 1.f / (float) (params.hw * params.channels_per_group); + params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group); // Define how many activations are computed per block. 
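  // (acts_per_block is the tile of (h, w) positions each block covers within one (batch, group)
  //  slice; the thresholds below trade per-block work against how many blocks must cooperate on
  //  a slice, and therefore against the amount of inter-block reduction traffic.)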
- if( (params.hw >= 1024 && params.channels_per_group >= 80) || - (params.hw >= 256 && params.channels_per_group >= 160) ) - { + if ((params.hw >= 1024 && params.channels_per_group >= 80) || + (params.hw >= 256 && params.channels_per_group >= 160)) { params.acts_per_block = 8 * 16; - } else if( params.hw >= 512 ) { + } else if (params.hw >= 512) { params.acts_per_block = 32 * 16; - } else if( params.hw >= 256 ) { + } else if (params.hw >= 256) { params.acts_per_block = 16 * 16; - } else if( params.hw >= 128 ) { + } else if (params.hw >= 128) { params.acts_per_block = 8 * 16; - } else if ( params.hw > 0 ) { + } else if (params.hw > 0) { params.acts_per_block = 8 * 8; } else { // We should never be here if params are set correctly. @@ -149,7 +143,7 @@ void group_norm_nhwc_bwd_one_pass_setup(Group_norm_nhwc_bwd_params ¶ms, barriers_elts += 1; // The number of elements in the reduction buffer (for the sums and sums of squared). - if( blocks_per_slice == 1 ) { + if (blocks_per_slice == 1) { red_buffer_elts = 0; } else { // The first 2 is for double-buffering. The 2nd one is for the fact that we have two floats. @@ -163,13 +157,9 @@ void group_norm_nhwc_bwd_one_pass_setup(Group_norm_nhwc_bwd_params ¶ms, assert(params.channels_per_block % params.channels_per_group == 0); } -inline void group_norm_nhwc_bwd_one_pass_run(const Group_norm_nhwc_bwd_params ¶ms, - const dim3 &grid, - cudaStream_t stream) { - - using Function_t = void (*)(const Group_norm_nhwc_bwd_params &, - const dim3 &, - cudaStream_t); +inline void group_norm_nhwc_bwd_one_pass_run(const Group_norm_nhwc_bwd_params ¶ms, const dim3 &grid, + cudaStream_t stream) { + using Function_t = void (*)(const Group_norm_nhwc_bwd_params &, const dim3 &, cudaStream_t); Function_t runner; GN_BWD_RUNNER_SELECT(runner); diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass_kernel.cuh b/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass_kernel.cuh old mode 100755 new mode 100644 index 9cc3a0905..e71e834fd --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass_kernel.cuh +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass_kernel.cuh @@ -2,21 +2,22 @@ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc.h" -#include "traits.h" #include + #include +#include "group_norm_nhwc.h" +#include "traits.h" + //////////////////////////////////////////////////////////////////////////////////////////////////// // // B A C K W A R D // //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Traits_, int ACTS_PER_BLOCK_, int CHANNELS_PER_GROUP_, int THREADS_PER_BLOCK_ > -__global__ __launch_bounds__(THREADS_PER_BLOCK_) - void group_norm_nhwc_bwd_one_pass_kernel(Group_norm_nhwc_bwd_params params) { - +template +__global__ __launch_bounds__(THREADS_PER_BLOCK_) void group_norm_nhwc_bwd_one_pass_kernel( + Group_norm_nhwc_bwd_params params) { // The IO traits. using Traits = Traits_; // The IO traits. @@ -48,7 +49,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) // The number of activations that are loaded per loop. constexpr int ACTS_PER_LOOP = THREADS_PER_BLOCK / THREADS_PER_ACT; // The number of rows per thread. - constexpr int ACTS_PER_THREAD = (ACTS_PER_BLOCK + ACTS_PER_LOOP-1) / ACTS_PER_LOOP; + constexpr int ACTS_PER_THREAD = (ACTS_PER_BLOCK + ACTS_PER_LOOP - 1) / ACTS_PER_LOOP; // The number of active threads. 
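  // (THREADS_PER_BLOCK rounded down to a multiple of THREADS_PER_ACT; any leftover threads are
  //  masked out through the is_active predicate and contribute zeros to the reductions.)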
constexpr int ACTIVE_THREADS = THREADS_PER_BLOCK / THREADS_PER_ACT * THREADS_PER_ACT; @@ -73,8 +74,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) const bool is_active = threadIdx.x < ACTIVE_THREADS; // Iterate over the iterms in the batch. - for( int ngi = blockIdx.y, step = 0; ngi < params.n*params.groups; ngi += gridDim.y, ++step ) { - + for (int ngi = blockIdx.y, step = 0; ngi < params.n * params.groups; ngi += gridDim.y, ++step) { // The instance and the group. TODO: Use fast divmod? int ni = ngi / params.groups; int gi = ngi % params.groups; @@ -91,34 +91,34 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) float rcp_x_stddev = x_var <= 0.f ? 1.f : 1.f / sqrtf(x_var + params.epsilon); // The offset to the first activation loaded by that thread. - const int64_t offset = (int64_t) ni*params.hwc + gi*CHANNELS_PER_GROUP + ci; + const int64_t offset = (int64_t)ni * params.hwc + gi * CHANNELS_PER_GROUP + ci; // The pointer to the first activation loaded by that thread. - const IOType *x_ptr = &reinterpret_cast(params.x)[offset]; + const IOType *x_ptr = &reinterpret_cast(params.x)[offset]; // The pointer to the first gradient loaded by that thread. - const IOType *dy_ptr = &reinterpret_cast(params.dy)[offset]; + const IOType *dy_ptr = &reinterpret_cast(params.dy)[offset]; // Load the X and dY into registers. IOType2 x[ACTS_PER_THREAD], dy[ACTS_PER_THREAD]; - #pragma unroll - for( int ii = 0; ii < ACTS_PER_THREAD; ++ii ) { - int hwj = hwi + ii*ACTS_PER_LOOP; - x [ii] = IOTraits::zero(); +#pragma unroll + for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { + int hwj = hwi + ii * ACTS_PER_LOOP; + x[ii] = IOTraits::zero(); dy[ii] = IOTraits::zero(); - if( is_active && hwj < params.hw ) { - x [ii] = *reinterpret_cast(&x_ptr [hwj*params.c]); - dy[ii] = *reinterpret_cast(&dy_ptr[hwj*params.c]); + if (is_active && hwj < params.hw) { + x[ii] = *reinterpret_cast(&x_ptr[hwj * params.c]); + dy[ii] = *reinterpret_cast(&dy_ptr[hwj * params.c]); } } // Load gamma as well. float2 gamma_f2 = make_float2(0.f, 0.f); float2 beta_f2 = make_float2(0.f, 0.f); - if( is_active ) { - gamma_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.gamma)[gi*CHANNELS_PER_GROUP+ci])); + if (is_active) { + gamma_f2 = WTraits::unpack(*reinterpret_cast( + &reinterpret_cast(params.gamma)[gi * CHANNELS_PER_GROUP + ci])); if (params.with_swish) { - beta_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.beta) [gi*CHANNELS_PER_GROUP+ci])); + beta_f2 = WTraits::unpack(*reinterpret_cast( + &reinterpret_cast(params.beta)[gi * CHANNELS_PER_GROUP + ci])); } } @@ -127,11 +127,11 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) // Accumulated gradients for dgrad calculation. float mean_1 = 0.f, mean_2 = 0.f; - // Compute the sum and the sum of squares for each thread. - #pragma unroll - for( int ii = 0; ii < ACTS_PER_THREAD; ++ii ) { +// Compute the sum and the sum of squares for each thread. +#pragma unroll + for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { // Convert x to float. - float2 x_f2 = IOTraits::unpack(x [ii]); + float2 x_f2 = IOTraits::unpack(x[ii]); // Convert dY to float. float2 dy_f2 = IOTraits::unpack(dy[ii]); @@ -175,7 +175,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) // Pack valid gradients. 
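    // (Threads outside the active set pack zeros here so that the block-wide reduction of
    //  (mean_1, mean_2) only ever sees valid per-thread partial sums.)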
float2 sums = make_float2(0.f, 0.f); - if( ACTIVE_THREADS == THREADS_PER_BLOCK || is_active ) { + if (ACTIVE_THREADS == THREADS_PER_BLOCK || is_active) { sums = make_float2(mean_1, mean_2); } @@ -191,9 +191,9 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) __syncthreads(); // Compute gamma/beta for the block. - if( threadIdx.x < THREADS_PER_ACT ) { - for( int ii = 1; ii < ACTS_PER_LOOP; ++ii ) { - float4 other = smem_dgamma_dbeta[threadIdx.x + ii*THREADS_PER_ACT]; + if (threadIdx.x < THREADS_PER_ACT) { + for (int ii = 1; ii < ACTS_PER_LOOP; ++ii) { + float4 other = smem_dgamma_dbeta[threadIdx.x + ii * THREADS_PER_ACT]; dgamma_dbeta.x += other.x; dgamma_dbeta.y += other.y; dgamma_dbeta.z += other.z; @@ -207,42 +207,41 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) float *red_buffer_dgamma_dbeta = ¶ms.zeroed_red_buffer[cj]; // The first threads store their gradients for gamma/beta. - if( threadIdx.x < THREADS_PER_ACT ) { - atomicAdd(&red_buffer_dgamma_dbeta[0*params.c/2], dgamma_dbeta.x); - atomicAdd(&red_buffer_dgamma_dbeta[1*params.c/2], dgamma_dbeta.y); - atomicAdd(&red_buffer_dgamma_dbeta[2*params.c/2], dgamma_dbeta.z); - atomicAdd(&red_buffer_dgamma_dbeta[3*params.c/2], dgamma_dbeta.w); + if (threadIdx.x < THREADS_PER_ACT) { + atomicAdd(&red_buffer_dgamma_dbeta[0 * params.c / 2], dgamma_dbeta.x); + atomicAdd(&red_buffer_dgamma_dbeta[1 * params.c / 2], dgamma_dbeta.y); + atomicAdd(&red_buffer_dgamma_dbeta[2 * params.c / 2], dgamma_dbeta.z); + atomicAdd(&red_buffer_dgamma_dbeta[3 * params.c / 2], dgamma_dbeta.w); } // The block leader stores to global memory, if needed. - if( gridDim.x > 1 ) { - + if (gridDim.x > 1) { // The index of the buffer. int red_buffer_idx = step & 1; // The barrier. - int *barrier = ¶ms.barriers[red_buffer_idx*gridDim.y + blockIdx.y]; + int *barrier = ¶ms.barriers[red_buffer_idx * gridDim.y + blockIdx.y]; // The offset to the reduction buffer. - int red_buffer_offset = red_buffer_idx*gridDim.x*gridDim.y*2; + int red_buffer_offset = red_buffer_idx * gridDim.x * gridDim.y * 2; // The reduction buffer. - float2 *red_buffer = reinterpret_cast(¶ms.red_buffer[red_buffer_offset]); + float2 *red_buffer = reinterpret_cast(¶ms.red_buffer[red_buffer_offset]); // The offset to the reduction buffer for dgamma/dbeta. // The first thread stores its sums. - if( threadIdx.x == 0 ) { - red_buffer[blockIdx.x*gridDim.y + blockIdx.y] = sums; + if (threadIdx.x == 0) { + red_buffer[blockIdx.x * gridDim.y + blockIdx.y] = sums; } // Make sure the data is in memory. - if( threadIdx.x == 0 ) { + if (threadIdx.x == 0) { spin_wait_(barrier, (step & 2) ? -1 : 1, (step & 2) ? 0 : gridDim.x); } __syncthreads(); // Update the sums. - for( int ii = 0; ii < gridDim.x; ++ii ) { - if( ii != blockIdx.x && threadIdx.x == 0 ) { - float2 other_sums = red_buffer[ii*gridDim.y + blockIdx.y]; + for (int ii = 0; ii < gridDim.x; ++ii) { + if (ii != blockIdx.x && threadIdx.x == 0) { + float2 other_sums = red_buffer[ii * gridDim.y + blockIdx.y]; sums.x += other_sums.x; sums.y += other_sums.y; } @@ -250,7 +249,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) } // Store the result for other threads. - if( threadIdx.x == 0 ) { + if (threadIdx.x == 0) { smem_sums = sums; } @@ -266,12 +265,12 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) mean_2 *= params.inv_hwc_per_group; // The pointer to the first activation stored by that thread. 
- IOType *dx_ptr = &reinterpret_cast(params.dx)[offset]; + IOType *dx_ptr = &reinterpret_cast(params.dx)[offset]; // Iterate over the activations to normalize the activations and store the results. - for( int ii = 0; ii < ACTS_PER_THREAD; ++ii ) { + for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { // Convert x to float. - float2 x_f2 = IOTraits::unpack(x [ii]); + float2 x_f2 = IOTraits::unpack(x[ii]); // Convert dY to float. float2 dy_f2 = IOTraits::unpack(dy[ii]); @@ -304,49 +303,49 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) dx.y = (dx_norm.y - (x_norm_f2.y * mean_1 + mean_2)) * rcp_x_stddev; // Store the scaled values. - int hwj = hwi + ii*ACTS_PER_LOOP; - if( is_active && hwj < params.hw ) { - *reinterpret_cast(&dx_ptr[hwj*params.c]) = IOTraits::pack(dx); + int hwj = hwi + ii * ACTS_PER_LOOP; + if (is_active && hwj < params.hw) { + *reinterpret_cast(&dx_ptr[hwj * params.c]) = IOTraits::pack(dx); } } } // The completion barrier. - int *barrier = ¶ms.barriers[gridDim.x == 1 ? 0 : gridDim.y*2]; + int *barrier = ¶ms.barriers[gridDim.x == 1 ? 0 : gridDim.y * 2]; // Mark the completion of the threadblock. - if( threadIdx.x == 0 ) { - asm volatile("red.release.gpu.global.add.s32 [%0], 1;" :: "l"(barrier)); + if (threadIdx.x == 0) { + asm volatile("red.release.gpu.global.add.s32 [%0], 1;" ::"l"(barrier)); } // Exit if that's not the last thread block. - if( blockIdx.x != gridDim.x-1 || blockIdx.y != gridDim.y-1 ) { + if (blockIdx.x != gridDim.x - 1 || blockIdx.y != gridDim.y - 1) { return; } // Busy wait. We could use found = old + step with old = atomicAdd(...) but it's not faster. - if( threadIdx.x == 0 ) { - for( int found = -1; found != gridDim.x * gridDim.y; ) { + if (threadIdx.x == 0) { + for (int found = -1; found != gridDim.x * gridDim.y;) { asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); } } __syncthreads(); // The last block converts dgamma and dbeta to half. - for( int idx = threadIdx.x; idx < params.c/2; idx += THREADS_PER_BLOCK ) { + for (int idx = threadIdx.x; idx < params.c / 2; idx += THREADS_PER_BLOCK) { // Load dgamma. float2 dgamma; - dgamma.x = params.zeroed_red_buffer[idx + 0*params.c/2]; - dgamma.y = params.zeroed_red_buffer[idx + 1*params.c/2]; + dgamma.x = params.zeroed_red_buffer[idx + 0 * params.c / 2]; + dgamma.y = params.zeroed_red_buffer[idx + 1 * params.c / 2]; // Load dbeta. float2 dbeta; - dbeta.x = params.zeroed_red_buffer[idx + 2*params.c/2]; - dbeta.y = params.zeroed_red_buffer[idx + 3*params.c/2]; + dbeta.x = params.zeroed_red_buffer[idx + 2 * params.c / 2]; + dbeta.y = params.zeroed_red_buffer[idx + 3 * params.c / 2]; // Store to global memory. - *reinterpret_cast(&reinterpret_cast(params.dgamma)[idx*2]) = WTraits::pack(dgamma); - *reinterpret_cast(&reinterpret_cast(params.dbeta )[idx*2]) = WTraits::pack(dbeta); + *reinterpret_cast(&reinterpret_cast(params.dgamma)[idx * 2]) = WTraits::pack(dgamma); + *reinterpret_cast(&reinterpret_cast(params.dbeta)[idx * 2]) = WTraits::pack(dbeta); } } diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu old mode 100755 new mode 100644 index 46150a510..56594f1d7 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu @@ -2,11 +2,13 @@ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ +#include + +#include + #include "group_norm_nhwc.h" #include "macros.h" #include "traits.h" -#include -#include //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -14,9 +16,8 @@ // //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Traits_, int THREADS_PER_BLOCK > +template __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params) { - // The IO traits. using Traits = Traits_; // The IO traits. @@ -50,7 +51,7 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params int gi = ci / params.channels_per_group; // The sums from the fwd pass. - float2 fwd = params.sums[ni*params.groups + gi]; + float2 fwd = params.sums[ni * params.groups + gi]; // The mean of X (computed during the fwd pass -- one value per batch*group). float x_mean = fwd.x; // The mean of squares of X (computed during the fwd pass -- one value per batch*group). @@ -63,12 +64,10 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params // Load gamma. float2 gamma_f2 = make_float2(0.f, 0.f); float2 beta_f2 = make_float2(0.f, 0.f); - if( ci < params.c ) { - gamma_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.gamma)[ci])); + if (ci < params.c) { + gamma_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.gamma)[ci])); if (params.with_swish) { - beta_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.beta)[ci])); + beta_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.beta)[ci])); } } @@ -79,7 +78,7 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. - int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw); + int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw); // The gradients for gamma/beta. float2 dgamma = make_float2(0.f, 0.f), dbeta = make_float2(0.f, 0.f); @@ -87,21 +86,20 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params float mean_1 = 0.f, mean_2 = 0.f; // Iterate over the activations to compute the sums. - for( int hwi = hw_begin; hwi < hw_end; ++hwi ) { - + for (int hwi = hw_begin; hwi < hw_end; ++hwi) { // The offset. - int64_t offset = (int64_t) ni*params.hwc + hwi*params.c + ci; + int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. IOType2 x_v2 = IOTraits::zero(); IOType2 dy_v2 = IOTraits::zero(); - if( ci < params.c ) { - x_v2 = *reinterpret_cast(&reinterpret_cast(params.x )[offset]); - dy_v2 = *reinterpret_cast(&reinterpret_cast(params.dy)[offset]); + if (ci < params.c) { + x_v2 = *reinterpret_cast(&reinterpret_cast(params.x)[offset]); + dy_v2 = *reinterpret_cast(&reinterpret_cast(params.dy)[offset]); } // Extract the two half values. - float2 x_f2 = IOTraits::unpack(x_v2); + float2 x_f2 = IOTraits::unpack(x_v2); float2 dy_f2 = IOTraits::unpack(dy_v2); // X - X_mean. @@ -143,14 +141,14 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params } // The data for the summations. - Group_sums inp {cj == 0 ? 1 : 0, mean_1, mean_2}; + Group_sums inp{cj == 0 ? 1 : 0, mean_1, mean_2}; // Do the segmented scan. 
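  // (The flag raised at cj == 0 restarts the running sums at every group boundary, so after the
  //  inclusive scan the thread owning the last two channels of a group holds that group's
  //  partial sums for this block's slice; the combine is the Group_sums_op defined in
  //  group_norm_nhwc.h.)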
Group_sums out; Block_scan(temp_storage).InclusiveScan(inp, out, Group_sums_op()); // Store the results for the groups in shared memory (to produce coalesced stores later). - if( cj == params.channels_per_group - 2 /* 2 channels per thread */ ) { + if (cj == params.channels_per_group - 2 /* 2 channels per thread */) { smem[gj] = make_float2(out.sum, out.sum_sq); } @@ -164,45 +162,43 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params float2 sums = smem[threadIdx.x]; // Store to global memory. - if( threadIdx.x < params.groups_per_block && gk < params.groups ) { - atomicAdd(¶ms.zeroed_red_buffer[(2*ni+0)*params.groups + gk], sums.x); - atomicAdd(¶ms.zeroed_red_buffer[(2*ni+1)*params.groups + gk], sums.y); + if (threadIdx.x < params.groups_per_block && gk < params.groups) { + atomicAdd(¶ms.zeroed_red_buffer[(2 * ni + 0) * params.groups + gk], sums.x); + atomicAdd(¶ms.zeroed_red_buffer[(2 * ni + 1) * params.groups + gk], sums.y); } // The base pointer for the gradients for gamma and beta. - float *dgamma_beta_ptr = ¶ms.zeroed_red_buffer[params.n*params.groups*2]; + float *dgamma_beta_ptr = ¶ms.zeroed_red_buffer[params.n * params.groups * 2]; // The 1st channel in the output tensor. NOTE: Two channels per thread store interleaved. int ck = blockIdx.x * params.channels_per_block + threadIdx.x; // Store dgamma and dbeta as well. - if( ck < params.c ) { - atomicAdd(&dgamma_beta_ptr[0*params.c + 0*blockDim.x + ck], dgamma.x); - atomicAdd(&dgamma_beta_ptr[0*params.c + 1*blockDim.x + ck], dgamma.y); - atomicAdd(&dgamma_beta_ptr[1*params.c + 0*blockDim.x + ck], dbeta .x); - atomicAdd(&dgamma_beta_ptr[1*params.c + 1*blockDim.x + ck], dbeta .y); + if (ck < params.c) { + atomicAdd(&dgamma_beta_ptr[0 * params.c + 0 * blockDim.x + ck], dgamma.x); + atomicAdd(&dgamma_beta_ptr[0 * params.c + 1 * blockDim.x + ck], dgamma.y); + atomicAdd(&dgamma_beta_ptr[1 * params.c + 0 * blockDim.x + ck], dbeta.x); + atomicAdd(&dgamma_beta_ptr[1 * params.c + 1 * blockDim.x + ck], dbeta.y); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params ¶ms, - size_t &zeroed_red_buffer_elts) { - +void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params ¶ms, size_t &zeroed_red_buffer_elts) { // The pre-computed dimensions. - params.hw = params.h * params.w; + params.hw = params.h * params.w; params.hwc = params.c * params.hw; // The number of channels per group. params.channels_per_group = params.c / params.groups; // The inverse to compute the mean/variance. - params.inv_hwc_per_group = 1.f / (float) (params.hw * params.channels_per_group); + params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group); // Define the number of blocks per activation map. That's a simple heuristic. int blocks_per_act_slice = 0; - if( params.c >= 1280 ) { + if (params.c >= 1280) { blocks_per_act_slice = 128 / params.n; - } else if( params.c >= 640 ) { + } else if (params.c >= 640) { blocks_per_act_slice = 256 / params.n; } else { blocks_per_act_slice = 512 / params.n; @@ -217,15 +213,14 @@ void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params ¶ms, // The number of channels per block. params.channels_per_block = 320; // Special case to deal with 30 channels per group. 
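  // (These overrides keep channels_per_block an exact multiple of channels_per_group: 240 for
  //  the 30-channel case and 280 for the 70-channel groups of c == 2240, so a block never
  //  straddles a partial group; the one-pass setup asserts the same invariant explicitly.)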
- if( params.channels_per_block % params.channels_per_group != 0 ) { + if (params.channels_per_block % params.channels_per_group != 0) { params.channels_per_block = 240; } // Special case to deal with 70 channels per group. - if( params.c == 2240 ) { + if (params.c == 2240) { params.channels_per_block = 280; - } - else if (params.c == 832){ + } else if (params.c == 832) { params.channels_per_block = 208; } @@ -261,9 +256,7 @@ void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params ¶ms, //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params ¶ms, - cudaStream_t stream) { - +void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params ¶ms, cudaStream_t stream) { // The dimension of the grid. dim3 grid; @@ -300,9 +293,8 @@ void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params ¶ms //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Traits_, int THREADS_PER_BLOCK > +template __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params params) { - // The IO traits. using Traits = Traits_; // The IO traits. @@ -329,13 +321,13 @@ __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params para // Load the gradients for the group. float mean_1 = 0.f, mean_2 = 0.f; - if( gi < params.groups ) { - mean_1 = params.zeroed_red_buffer[(2*ni+0)*params.groups + gi]; - mean_2 = params.zeroed_red_buffer[(2*ni+1)*params.groups + gi]; + if (gi < params.groups) { + mean_1 = params.zeroed_red_buffer[(2 * ni + 0) * params.groups + gi]; + mean_2 = params.zeroed_red_buffer[(2 * ni + 1) * params.groups + gi]; } // The sums from the fwd pass. - float2 fwd = params.sums[ni*params.groups + gi]; + float2 fwd = params.sums[ni * params.groups + gi]; // The mean of X (computed during the fwd pass -- one value per batch*group). float x_mean = fwd.x; // The mean of squares of X (computed during the fwd pass -- one value per batch*group). @@ -352,36 +344,33 @@ __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params para // Load gamma. float2 gamma_f2 = make_float2(0.f, 0.f); float2 beta_f2 = make_float2(0.f, 0.f); - if( ci < params.c ) { - gamma_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.gamma)[ci])); + if (ci < params.c) { + gamma_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.gamma)[ci])); if (params.with_swish) { - beta_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.beta)[ci])); + beta_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.beta)[ci])); } } // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. - int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw); + int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw); // Iterate over the activations to compute the sums. - for( int hwi = hw_begin; hwi < hw_end; ++hwi ) { - + for (int hwi = hw_begin; hwi < hw_end; ++hwi) { // The src/dst offset. - int64_t offset = (int64_t) ni*params.hwc + hwi*params.c + ci; + int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. 
IOType2 x_v2 = IOTraits::zero(); IOType2 dy_v2 = IOTraits::zero(); - if( ci < params.c ) { - x_v2 = *reinterpret_cast(&reinterpret_cast(params.x )[offset]); - dy_v2 = *reinterpret_cast(&reinterpret_cast(params.dy)[offset]); + if (ci < params.c) { + x_v2 = *reinterpret_cast(&reinterpret_cast(params.x)[offset]); + dy_v2 = *reinterpret_cast(&reinterpret_cast(params.dy)[offset]); } // Extract the two half values. - float2 x_f2 = IOTraits::unpack(x_v2); + float2 x_f2 = IOTraits::unpack(x_v2); float2 dy_f2 = IOTraits::unpack(dy_v2); // X - X_mean. @@ -414,18 +403,18 @@ __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params para dx.y = (dx_norm.y - (x_norm_f2.y * mean_1 + mean_2)) * rcp_x_stddev; // Store the scaled values. - if( ci < params.c ) { - *reinterpret_cast(&reinterpret_cast(params.dx)[offset]) = IOTraits::pack(dx); + if (ci < params.c) { + *reinterpret_cast(&reinterpret_cast(params.dx)[offset]) = IOTraits::pack(dx); } } // Load gamma/beta and convert to half. - if( blockIdx.y > 0 || blockIdx.z > 0 || ci >= params.c ) { + if (blockIdx.y > 0 || blockIdx.z > 0 || ci >= params.c) { return; } // The base pointer for the gradients for gamma and beta. - float *dgamma_beta_ptr = ¶ms.zeroed_red_buffer[params.n*params.groups*2]; + float *dgamma_beta_ptr = ¶ms.zeroed_red_buffer[params.n * params.groups * 2]; // The 1st channel in the output tensor. NOTE: Two channels per thread store interleaved. int ck = blockIdx.x * params.channels_per_block + threadIdx.x; @@ -433,22 +422,20 @@ __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params para // Load the FP32 version of dgamma and dbeta. float2 dgamma, dbeta; if (ck < params.c) { - dgamma.x = dgamma_beta_ptr[0*params.c + 0*blockDim.x + ck]; - dgamma.y = dgamma_beta_ptr[0*params.c + 1*blockDim.x + ck]; - dbeta.x = dgamma_beta_ptr[1*params.c + 0*blockDim.x + ck]; - dbeta.y = dgamma_beta_ptr[1*params.c + 1*blockDim.x + ck]; + dgamma.x = dgamma_beta_ptr[0 * params.c + 0 * blockDim.x + ck]; + dgamma.y = dgamma_beta_ptr[0 * params.c + 1 * blockDim.x + ck]; + dbeta.x = dgamma_beta_ptr[1 * params.c + 0 * blockDim.x + ck]; + dbeta.y = dgamma_beta_ptr[1 * params.c + 1 * blockDim.x + ck]; // Convert to half2 and store to memory. - *reinterpret_cast(&reinterpret_cast(params.dgamma)[ci]) = WTraits::pack(dgamma); - *reinterpret_cast(&reinterpret_cast(params.dbeta )[ci]) = WTraits::pack(dbeta); + *reinterpret_cast(&reinterpret_cast(params.dgamma)[ci]) = WTraits::pack(dgamma); + *reinterpret_cast(&reinterpret_cast(params.dbeta)[ci]) = WTraits::pack(dbeta); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params ¶ms, - cudaStream_t stream) { - +void group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params ¶ms, cudaStream_t stream) { // The dimension of the grid. dim3 grid; diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass.h b/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass.h old mode 100755 new mode 100644 index f1768aded..799bab9e6 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass.h +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass.h @@ -2,11 +2,13 @@ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ +#include + +#include + #include "group_norm_nhwc.h" #include "macros.h" #include "traits.h" -#include -#include //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -14,71 +16,68 @@ // //////////////////////////////////////////////////////////////////////////////////////////////////// -#define GN_FWD_SELECT(FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(4, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(8, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(10, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(12, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(14, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(16, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(20, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(26, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(24, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(28, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(30, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(32, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(40, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(42, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(48, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(56, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(60, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(64, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(70, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(80, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(84, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(96, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(98, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(112, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(120, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(128, FUNC_POSTFIX, function) \ - GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(160, FUNC_POSTFIX, function) \ - { \ - assert(false && "Not implemented"); \ +#define GN_FWD_SELECT(FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(4, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(8, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(10, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(12, FUNC_POSTFIX, function) \ + 
GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(14, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(16, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(20, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(26, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(24, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(28, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(30, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(32, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(40, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(42, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(48, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(56, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(60, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(64, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(70, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(80, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(84, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(96, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(98, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(112, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(120, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(128, FUNC_POSTFIX, function) \ + GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(160, FUNC_POSTFIX, function) { \ + assert(false && "Not implemented"); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// -#define GN_FWD_RUNNER_SELECT(function) \ - GN_FWD_SELECT(_run, function) +#define GN_FWD_RUNNER_SELECT(function) GN_FWD_SELECT(_run, function) -#define GN_FWD_BLOCKS_PER_SM_SELECT(function) \ - GN_FWD_SELECT(_blocks_per_sm, function) +#define GN_FWD_BLOCKS_PER_SM_SELECT(function) GN_FWD_SELECT(_blocks_per_sm, function) //////////////////////////////////////////////////////////////////////////////////////////////////// -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 4) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 8) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 10) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 12) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 14) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 16) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 20) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 26) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 24) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 28) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 30) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 32) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 40) 
-GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 42) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 48) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 56) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 60) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 64) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 70) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 80) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 84) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 96) -GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 98) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 4) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 8) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 10) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 12) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 14) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 16) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 20) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 26) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 24) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 28) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 30) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 32) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 40) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 42) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 48) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 56) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 60) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 64) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 70) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 80) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 84) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 96) +GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 98) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 112) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 120) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 128) @@ -86,20 +85,16 @@ GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 160) //////////////////////////////////////////////////////////////////////////////////////////////////// -inline void group_norm_nhwc_fwd_one_pass_setup(Group_norm_nhwc_fwd_params ¶ms, - size_t &barriers_elts, - size_t &red_buffer_elts, - dim3 &grid, - const cudaDeviceProp &props) { - +inline void group_norm_nhwc_fwd_one_pass_setup(Group_norm_nhwc_fwd_params ¶ms, size_t &barriers_elts, + size_t &red_buffer_elts, dim3 &grid, const cudaDeviceProp &props) { // The pre-computed dimensions. - params.hw = params.h * params.w; + params.hw = params.h * params.w; params.hwc = params.c * params.hw; // The number of channels per group. params.channels_per_group = params.c / params.groups; // The inverse to compute the mean/variance. - params.inv_hwc_per_group = 1.f / (float) (params.hw * params.channels_per_group); + params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group); // Select the kernel. using Function_t = int (*)(); @@ -108,17 +103,15 @@ inline void group_norm_nhwc_fwd_one_pass_setup(Group_norm_nhwc_fwd_params ¶m GN_FWD_BLOCKS_PER_SM_SELECT(blocks_per_sm_function); // Define how many activations are computed per block. 
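  // Roughly: larger activation maps get larger per-CTA tiles, while very wide groups on large
  // maps (channels_per_group >= 80 at hw >= 1024, or >= 160 at hw >= 256) fall back to the
  // small 8*16 tile; the thresholds themselves look empirical. As an illustrative example,
  // hw = 64 * 64 = 4096 with channels_per_group = 32 takes the `hw >= 512` branch, so
  // acts_per_block = 16 * 32 = 512 and each (sample, group) slice needs on the order of
  // ceil(4096 / 512) = 8 CTAs along the activation dimension (the exact grid is finalized
  // further down using the occupancy query above).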
- if( params.hw >= 1024 && params.channels_per_group >= 80 || - (params.hw >= 256 && params.channels_per_group >= 160) ) - { + if (params.hw >= 1024 && params.channels_per_group >= 80 || (params.hw >= 256 && params.channels_per_group >= 160)) { params.acts_per_block = 8 * 16; - } else if( params.hw >= 512 ) { + } else if (params.hw >= 512) { params.acts_per_block = 16 * 32; - } else if( params.hw >= 256 ) { + } else if (params.hw >= 256) { params.acts_per_block = 16 * 16; - } else if( params.hw >= 128 ) { + } else if (params.hw >= 128) { params.acts_per_block = 8 * 16; - } else if ( params.hw > 0 ) { + } else if (params.hw > 0) { params.acts_per_block = 8 * 8; } else { // We should never be here if params are set correctly. @@ -146,7 +139,7 @@ inline void group_norm_nhwc_fwd_one_pass_setup(Group_norm_nhwc_fwd_params ¶m barriers_elts = blocks_per_slice > 1 ? grid.y * 2 : 0; // The number of elements in the reduction buffer (for the sums and sums of squared). - if( blocks_per_slice == 1 ) { + if (blocks_per_slice == 1) { red_buffer_elts = 0; } else { // The first 2 is for double-buffering. The 2nd one is for the fact that we have two floats. @@ -154,13 +147,9 @@ inline void group_norm_nhwc_fwd_one_pass_setup(Group_norm_nhwc_fwd_params ¶m } } -inline void group_norm_nhwc_fwd_one_pass_run(const Group_norm_nhwc_fwd_params ¶ms, - const dim3 &grid, - cudaStream_t stream) { - - using Function_t = void (*)(const Group_norm_nhwc_fwd_params &, - const dim3 &, - cudaStream_t); +inline void group_norm_nhwc_fwd_one_pass_run(const Group_norm_nhwc_fwd_params ¶ms, const dim3 &grid, + cudaStream_t stream) { + using Function_t = void (*)(const Group_norm_nhwc_fwd_params &, const dim3 &, cudaStream_t); Function_t runner; GN_FWD_RUNNER_SELECT(runner); diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass_kernel.cuh b/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass_kernel.cuh old mode 100755 new mode 100644 index e0a65f3ba..f11f8e2af --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass_kernel.cuh +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass_kernel.cuh @@ -2,21 +2,22 @@ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc.h" -#include "traits.h" #include + #include +#include "group_norm_nhwc.h" +#include "traits.h" + //////////////////////////////////////////////////////////////////////////////////////////////////// // // F O R W A R D // //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Traits_, int ACTS_PER_BLOCK_, int CHANNELS_PER_GROUP_, int THREADS_PER_BLOCK_ > -__global__ __launch_bounds__(THREADS_PER_BLOCK_) - void group_norm_nhwc_fwd_one_pass_kernel(Group_norm_nhwc_fwd_params params) { - +template +__global__ __launch_bounds__(THREADS_PER_BLOCK_) void group_norm_nhwc_fwd_one_pass_kernel( + Group_norm_nhwc_fwd_params params) { // The traits. using Traits = Traits_; // The IO traits. @@ -48,7 +49,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) // The number of activations that are loaded per loop. constexpr int ACTS_PER_LOOP = THREADS_PER_BLOCK / THREADS_PER_ACT; // The number of rows per thread. - constexpr int ACTS_PER_THREAD = (ACTS_PER_BLOCK + ACTS_PER_LOOP-1) / ACTS_PER_LOOP; + constexpr int ACTS_PER_THREAD = (ACTS_PER_BLOCK + ACTS_PER_LOOP - 1) / ACTS_PER_LOOP; // The number of active threads. 
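  // A worked example of the tiling above (illustrative values, and assuming THREADS_PER_ACT
  // is the number of threads that cooperate on one activation, i.e. CHANNELS_PER_GROUP / 2 at
  // two channels per thread): with CHANNELS_PER_GROUP = 16, THREADS_PER_BLOCK = 256 and
  // ACTS_PER_BLOCK = 512, THREADS_PER_ACT = 8, so ACTS_PER_LOOP = 256 / 8 = 32 and
  // ACTS_PER_THREAD = (512 + 31) / 32 = 16; each thread then loads 16 activations spaced 32
  // apart. When THREADS_PER_BLOCK is not an exact multiple of THREADS_PER_ACT, the leftover
  // threads stay idle, which is what the count below captures.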
constexpr int ACTIVE_THREADS = THREADS_PER_BLOCK / THREADS_PER_ACT * THREADS_PER_ACT; @@ -69,39 +70,38 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) const bool is_active = threadIdx.x < ACTIVE_THREADS; // Iterate over the iterms in the batch. - for( int ngi = blockIdx.y, step = 0; ngi < params.n*params.groups; ngi += gridDim.y, ++step ) { - + for (int ngi = blockIdx.y, step = 0; ngi < params.n * params.groups; ngi += gridDim.y, ++step) { // The instance and the group. TODO: Use fast divmod? int ni = ngi / params.groups; int gi = ngi % params.groups; // The offset to the first activation loaded by that thread. - const int64_t offset = (int64_t) ni*params.hwc + gi*CHANNELS_PER_GROUP + ci; + const int64_t offset = (int64_t)ni * params.hwc + gi * CHANNELS_PER_GROUP + ci; // The pointer to the first activation loaded by that thread. - const IOType *x_ptr = &reinterpret_cast(params.x)[offset]; + const IOType *x_ptr = &reinterpret_cast(params.x)[offset]; // Load the activations into registers. IOType2 x[ACTS_PER_THREAD]; - #pragma unroll - for( int ii = 0; ii < ACTS_PER_THREAD; ++ii ) { - int hwj = hwi + ii*ACTS_PER_LOOP; +#pragma unroll + for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { + int hwj = hwi + ii * ACTS_PER_LOOP; x[ii] = IOTraits::zero(); - if( is_active && hwj < params.hw ) { - x[ii] = *reinterpret_cast(&x_ptr[hwj*params.c]); + if (is_active && hwj < params.hw) { + x[ii] = *reinterpret_cast(&x_ptr[hwj * params.c]); } } // Compute the sum and the sum of squares for each thread. float2 sums = make_float2(0.f, 0.f); - #pragma unroll - for( int ii = 0; ii < ACTS_PER_THREAD; ++ii ) { +#pragma unroll + for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { float2 f2 = IOTraits::unpack(x[ii]); sums.x += f2.x + f2.y; sums.y += f2.x * f2.x + f2.y * f2.y; } // Clear invalid threads. - if( ACTIVE_THREADS < THREADS_PER_BLOCK && !is_active ) { + if (ACTIVE_THREADS < THREADS_PER_BLOCK && !is_active) { sums = make_float2(0.f, 0.f); } @@ -111,32 +111,31 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) }); // The block leader stores to global memory, if needed. - if( gridDim.x > 1 ) { - + if (gridDim.x > 1) { // The index of the buffer (double-buffering). int red_buffer_idx = step & 1; // The barrier. - int *barrier = ¶ms.barriers[red_buffer_idx*gridDim.y + blockIdx.y]; + int *barrier = ¶ms.barriers[red_buffer_idx * gridDim.y + blockIdx.y]; // The offset to the reduction buffer. - int red_buffer_offset = red_buffer_idx*gridDim.x*gridDim.y*2; + int red_buffer_offset = red_buffer_idx * gridDim.x * gridDim.y * 2; // The reduction buffer. - float2 *red_buffer = reinterpret_cast(¶ms.red_buffer[red_buffer_offset]); + float2 *red_buffer = reinterpret_cast(¶ms.red_buffer[red_buffer_offset]); // The first thread stores its sums. - if( threadIdx.x == 0 ) { - red_buffer[blockIdx.x*gridDim.y + blockIdx.y] = sums; + if (threadIdx.x == 0) { + red_buffer[blockIdx.x * gridDim.y + blockIdx.y] = sums; } // Make sure the data is in memory. - if( threadIdx.x == 0 ) { + if (threadIdx.x == 0) { spin_wait_(barrier, (step & 2) ? -1 : 1, (step & 2) ? 0 : gridDim.x); } __syncthreads(); // Update the sums. 
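      // A note on the inter-CTA reduction above (only relevant when gridDim.x > 1):
      // red_buffer_idx = (step & 1) double-buffers the partial sums, presumably so a CTA that
      // has already moved on to the next (sample, group) slice cannot overwrite values its
      // peers are still reading. The spin-wait alternates direction with (step & 2): the
      // barrier counts up to gridDim.x on one pair of steps and back down to 0 on the next,
      // so the same barrier word is reused without an explicit reset. Once past the barrier,
      // each CTA folds in the partial sums written by the others, as in the loop below.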
- for( int ii = 0; ii < gridDim.x; ++ii ) { - if( ii != blockIdx.x && threadIdx.x == 0 ) { - float2 other_sums = red_buffer[ii*gridDim.y + blockIdx.y]; + for (int ii = 0; ii < gridDim.x; ++ii) { + if (ii != blockIdx.x && threadIdx.x == 0) { + float2 other_sums = red_buffer[ii * gridDim.y + blockIdx.y]; sums.x += other_sums.x; sums.y += other_sums.y; } @@ -144,12 +143,12 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) } // Store the result for other threads. - if( threadIdx.x == 0 ) { + if (threadIdx.x == 0) { smem_sums = sums; } // Store the results to global memory as well (for training). - if( params.sums != nullptr && blockIdx.x == 0 && threadIdx.x == 0 ) { + if (params.sums != nullptr && blockIdx.x == 0 && threadIdx.x == 0) { sums.x *= params.inv_hwc_per_group; sums.y *= params.inv_hwc_per_group; params.sums[ngi] = sums; @@ -159,10 +158,10 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) __syncthreads(); // Load gamma/beta. - float2 gamma_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.gamma)[gi*CHANNELS_PER_GROUP+ci])); - float2 beta_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.beta) [gi*CHANNELS_PER_GROUP+ci])); + float2 gamma_f2 = WTraits::unpack(*reinterpret_cast( + &reinterpret_cast(params.gamma)[gi * CHANNELS_PER_GROUP + ci])); + float2 beta_f2 = WTraits::unpack( + *reinterpret_cast(&reinterpret_cast(params.beta)[gi * CHANNELS_PER_GROUP + ci])); // Compute the mean. float mean = smem_sums.x * params.inv_hwc_per_group; @@ -172,11 +171,10 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) float inv_stddev = var <= 0.f ? 1.f : rsqrtf(var + params.epsilon); // The pointer to the first activation stored by that thread. - IOType *y_ptr = &reinterpret_cast(params.y)[offset]; + IOType *y_ptr = &reinterpret_cast(params.y)[offset]; // Iterate over the activations to normalize the activations and store the results. - for( int ii = 0; ii < ACTS_PER_THREAD; ++ii ) { - + for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { // Extract the two half values. float2 f2 = IOTraits::unpack(x[ii]); @@ -189,15 +187,15 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK_) f2.y = gamma_f2.y * f2.y + beta_f2.y; // Apply Swish if needed. - if( params.with_swish ) { + if (params.with_swish) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } // Store the scaled values. - int hwj = hwi + ii*ACTS_PER_LOOP; - if( is_active && hwj < params.hw ) { - *reinterpret_cast(&y_ptr[hwj*params.c]) = IOTraits::pack(f2); + int hwj = hwi + ii * ACTS_PER_LOOP; + if (is_active && hwj < params.hw) { + *reinterpret_cast(&y_ptr[hwj * params.c]) = IOTraits::pack(f2); } } } diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu old mode 100755 new mode 100644 index c21ac05c7..eff05d51c --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu @@ -2,11 +2,13 @@ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ +#include + +#include + #include "group_norm_nhwc.h" #include "macros.h" #include "traits.h" -#include -#include //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -14,9 +16,8 @@ // //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Traits_, int THREADS_PER_BLOCK > +template __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params) { - // The traits. using Traits = Traits_; // The IO traits. @@ -43,21 +44,20 @@ __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. - int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw); + int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw); // The sums. float sum = 0.f, sum_sq = 0.f; // Iterate over the activations to compute the sums. - for( int hwi = hw_begin; hwi < hw_end; ++hwi ) { - + for (int hwi = hw_begin; hwi < hw_end; ++hwi) { // The offset. - int64_t offset = (int64_t) ni*params.hwc + hwi*params.c + ci; + int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. IOType2 v2 = IOTraits::zero(); - if( ci < params.c ) { - v2 = *reinterpret_cast(&reinterpret_cast(params.x )[offset]); + if (ci < params.c) { + v2 = *reinterpret_cast(&reinterpret_cast(params.x)[offset]); } // Extract the two values. @@ -74,14 +74,14 @@ __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params int cj = threadIdx.x * 2 - params.channels_per_group * gj; // The data for the summations. - Group_sums inp {cj == 0 ? 1 : 0, sum, sum_sq}; + Group_sums inp{cj == 0 ? 1 : 0, sum, sum_sq}; // Do the segmented scan. Group_sums out; Block_scan(temp_storage).InclusiveScan(inp, out, Group_sums_op()); // Store the results for the groups in shared memory (to produce coalesced stores later). - if( cj == params.channels_per_group - 2 /* 2 channels per thread */ ) { + if (cj == params.channels_per_group - 2 /* 2 channels per thread */) { smem[gj] = make_float2(out.sum, out.sum_sq); } @@ -92,7 +92,7 @@ __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params int gk = blockIdx.x * params.groups_per_block + threadIdx.x; // Threads that have nothing left to do, exit. - if( threadIdx.x >= params.groups_per_block || gk >= params.groups ) { + if (threadIdx.x >= params.groups_per_block || gk >= params.groups) { return; } @@ -100,29 +100,27 @@ __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params float2 sums = smem[threadIdx.x]; // Store to global memory. - atomicAdd(¶ms.zeroed_red_buffer[(2*ni+0)*params.groups + gk], sums.x); - atomicAdd(¶ms.zeroed_red_buffer[(2*ni+1)*params.groups + gk], sums.y); + atomicAdd(¶ms.zeroed_red_buffer[(2 * ni + 0) * params.groups + gk], sums.x); + atomicAdd(¶ms.zeroed_red_buffer[(2 * ni + 1) * params.groups + gk], sums.y); } //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params ¶ms, - size_t &zeroed_red_buffer_elts) { - +void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params ¶ms, size_t &zeroed_red_buffer_elts) { // The pre-computed dimensions. 
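  // hw is the spatial size of one channel plane and hwc the per-sample stride in elements;
  // inv_hwc_per_group = 1 / (hw * channels_per_group) is what turns the accumulated sum and
  // sum of squares into a mean and a mean of squares. Further down, channels_per_block is
  // kept a multiple of channels_per_group (presumably so a group never straddles two CTAs in
  // the per-block segmented scan): 320 fails that for 30-channel groups, hence the
  // 240 = 8 * 30 fallback; c = 2240 (with 32 groups) gives 70-channel groups, hence
  // 280 = 4 * 70; and c = 832 (with 32 groups) gives 26-channel groups, hence 208 = 8 * 26.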
- params.hw = params.h * params.w; + params.hw = params.h * params.w; params.hwc = params.c * params.hw; // The number of channels per group. params.channels_per_group = params.c / params.groups; // The inverse to compute the mean/variance. - params.inv_hwc_per_group = 1.f / (float) (params.hw * params.channels_per_group); + params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group); // Define the number of blocks per activation map. That's a simple heuristic. int blocks_per_act_slice = 0; - if( params.c >= 1280 ) { + if (params.c >= 1280) { blocks_per_act_slice = 128 / params.n; - } else if( params.c >= 640 ) { + } else if (params.c >= 640) { blocks_per_act_slice = 256 / params.n; } else { blocks_per_act_slice = 512 / params.n; @@ -136,15 +134,14 @@ void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params ¶ms, // The number of channels per block. params.channels_per_block = 320; // Special case to deal with 30 channels per group. - if( params.channels_per_block % params.channels_per_group != 0 ) { + if (params.channels_per_block % params.channels_per_group != 0) { params.channels_per_block = 240; } // Special case to deal with 70 channels per group. - if( params.c == 2240 ) { + if (params.c == 2240) { params.channels_per_block = 280; - } - else if (params.c == 832){ + } else if (params.c == 832) { params.channels_per_block = 208; } @@ -180,9 +177,7 @@ void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params ¶ms, //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params ¶ms, - cudaStream_t stream) { - +void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params ¶ms, cudaStream_t stream) { // The dimension of the grid. dim3 grid; @@ -220,9 +215,8 @@ void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params ¶ms //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Traits_, int THREADS_PER_BLOCK > +template __global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params params) { - // The traits. using Traits = Traits_; // The IO traits. @@ -249,18 +243,16 @@ __global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params para // Load the sum and sum of squares for the group. float sum = 0.f, sum_sq = 0.f; - if( gi < params.groups ) { - sum = params.zeroed_red_buffer[(2*ni+0)*params.groups + gi]; - sum_sq = params.zeroed_red_buffer[(2*ni+1)*params.groups + gi]; + if (gi < params.groups) { + sum = params.zeroed_red_buffer[(2 * ni + 0) * params.groups + gi]; + sum_sq = params.zeroed_red_buffer[(2 * ni + 1) * params.groups + gi]; } // Load gamma/beta. float2 gamma_f2, beta_f2; - if( ci < params.c ) { - gamma_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.gamma)[ci])); - beta_f2 = WTraits::unpack(*reinterpret_cast( - &reinterpret_cast(params.beta) [ci])); + if (ci < params.c) { + gamma_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.gamma)[ci])); + beta_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.beta)[ci])); } // Compute the mean. @@ -273,18 +265,17 @@ __global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params para // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. 
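  // (hw_end below is just hw_begin + acts_per_block clamped to params.hw, so the last CTA
  // along the activation dimension does not run past the map. Within that range each element
  // is mapped to gamma * (x - mean) * rsqrt(var + epsilon) + beta, with the optional Swish
  // applied on top, as spelled out in the code that follows.)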
- int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw); + int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw); // Iterate over the activations to compute the sums. - for( int hwi = hw_begin; hwi < hw_end; ++hwi ) { - + for (int hwi = hw_begin; hwi < hw_end; ++hwi) { // The src/dst offset. - int64_t offset = (int64_t) ni*params.hwc + hwi*params.c + ci; + int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. IOType2 v2 = IOTraits::zero(); - if( ci < params.c ) { - v2 = *reinterpret_cast(&reinterpret_cast(params.x )[offset]); + if (ci < params.c) { + v2 = *reinterpret_cast(&reinterpret_cast(params.x)[offset]); } // Extract the two values. @@ -299,31 +290,29 @@ __global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params para f2.y = gamma_f2.y * f2.y + beta_f2.y; // Apply Swish if needed. - if( params.with_swish ) { + if (params.with_swish) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } // Store the scaled values. - if( ci < params.c ) { - *reinterpret_cast(&reinterpret_cast(params.y)[offset]) = IOTraits::pack(f2); + if (ci < params.c) { + *reinterpret_cast(&reinterpret_cast(params.y)[offset]) = IOTraits::pack(f2); } } // Write the sums if needed. - if( params.sums != nullptr && gi < params.groups ) { + if (params.sums != nullptr && gi < params.groups) { float2 sums; - sums.x = sum * params.inv_hwc_per_group; + sums.x = sum * params.inv_hwc_per_group; sums.y = sum_sq * params.inv_hwc_per_group; - params.sums[ni*params.groups + gi] = sums; + params.sums[ni * params.groups + gi] = sums; } } //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_fwd_two_passes_scale(const Group_norm_nhwc_fwd_params ¶ms, - cudaStream_t stream) { - +void group_norm_nhwc_fwd_two_passes_scale(const Group_norm_nhwc_fwd_params ¶ms, cudaStream_t stream) { // The dimension of the grid. 
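  // Judging from the indexing in the kernels above, grid.x is expected to cover the channels
  // in chunks of channels_per_block (two channels per thread), grid.y the activation map in
  // chunks of acts_per_block, and grid.z the batch; the launcher below fills in that grid and
  // presumably dispatches to the kernel instantiation matching the IO and weight types.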
dim3 grid; diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu old mode 100755 new mode 100644 index 26c3fcad2..aa7ea5519 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 10, /* THREADS_PER_BLOCK */ 640) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu old mode 100755 new mode 100644 index 07a9f59b3..d75ff5723 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 112, /* THREADS_PER_BLOCK */ 448) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu index 118929ee3..64e3ee83e 100644 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 12, /* THREADS_PER_BLOCK */ 384) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu old mode 100755 new mode 100644 index 3febb65f2..136f00728 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 120, /* THREADS_PER_BLOCK */ 480) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu old mode 100755 new mode 100644 index 7d0b20c15..f56f7ad16 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 128, /* THREADS_PER_BLOCK */ 512) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu old mode 100755 new mode 100644 index 72ce0b150..9d11ffea0 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu +++ 
b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 14, /* THREADS_PER_BLOCK */ 224) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu old mode 100755 new mode 100644 index 7e9c53c8a..2b3ed054e --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 16, /* THREADS_PER_BLOCK */ 256) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu old mode 100755 new mode 100644 index 0294a1c20..398e73345 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 160, /* THREADS_PER_BLOCK */ 640) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu old mode 100755 new mode 100644 index c227b7af3..3f7e039ae --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 20, /* THREADS_PER_BLOCK */ 640) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu old mode 100755 new mode 100644 index 4306f2b30..42406ab42 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 24, /* THREADS_PER_BLOCK */ 384) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu old mode 100755 new mode 100644 index 2e6a63856..c2adec60f --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" 
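// Each of these small translation units instantiates the templated one-pass forward and
// backward kernels for a single CHANNELS_PER_GROUP value, presumably to spread the template
// instantiations (and compile time) across files. The per-block thread counts appear to be
// multiples of CHANNELS_PER_GROUP / 2, i.e. of the threads needed to cover one activation at
// two channels per thread. Two illustrative checks, restating that relationship rather than
// anything taken from the build itself:
static_assert(416 % (26 / 2) == 0, "26 channels/group -> 13 threads per activation");
static_assert(224 % (14 / 2) == 0, "14 channels/group -> 7 threads per activation");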
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 26, /* THREADS_PER_BLOCK */ 416) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu old mode 100755 new mode 100644 index 8cb671c30..696be0b7e --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 28, /* THREADS_PER_BLOCK */ 448) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu old mode 100755 new mode 100644 index b0bed2466..a80c3b884 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 30, /* THREADS_PER_BLOCK */ 480) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu old mode 100755 new mode 100644 index f4214a13b..77e3b9b2d --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 32, /* THREADS_PER_BLOCK */ 512) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu old mode 100755 new mode 100644 index 1b004b957..f8ce73987 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu @@ -3,9 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" - GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 4, /* THREADS_PER_BLOCK */ 128) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu old mode 100755 new mode 100644 index df9c69bc0..1b2f806b3 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 40, /* THREADS_PER_BLOCK */ 640) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu old mode 100755 new mode 100644 index 89729e7c5..457e702f9 --- 
a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 42, /* THREADS_PER_BLOCK */ 672) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu old mode 100755 new mode 100644 index fbdbccbb0..6b7da2bcb --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 48, /* THREADS_PER_BLOCK */ 384) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu old mode 100755 new mode 100644 index e5af4321c..a924f7dcb --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 56, /* THREADS_PER_BLOCK */ 448) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu old mode 100755 new mode 100644 index d4473633a..15e9c71d2 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 60, /* THREADS_PER_BLOCK */ 480) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu old mode 100755 new mode 100644 index 92d0ed8c7..804f2e97c --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 64, /* THREADS_PER_BLOCK */ 512) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu old mode 100755 new mode 100644 index bd2fc6b8d..8a9f076aa --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include 
"group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 70, /* THREADS_PER_BLOCK */ 560) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu old mode 100755 new mode 100644 index 225d55cc1..4a150bcc3 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 8, /* THREADS_PER_BLOCK */ 128) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu old mode 100755 new mode 100644 index ece8fd677..acfca96cf --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 80, /* THREADS_PER_BLOCK */ 640) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu old mode 100755 new mode 100644 index 10273ec4b..116cc7cd6 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 84, /* THREADS_PER_BLOCK */ 672) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu old mode 100755 new mode 100644 index 969c26227..ba8ecc745 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 96, /* THREADS_PER_BLOCK */ 768) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu old mode 100755 new mode 100644 index 89ef8ff3e..63a75c315 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" +#include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 98, /* THREADS_PER_BLOCK */ 392) diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp b/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp index c80b49334..a51381324 100644 --- 
a/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp @@ -12,23 +12,18 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define CHECK_CUDA_STATUS(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \ - cudaGetErrorString(status_)); \ - exit(1); \ - } \ +#define CHECK_CUDA_STATUS(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ } while (0) -#define CHECK_CUDA(x) \ - TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_CHANNELS_LAST(x) \ - TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), \ - #x " must be channels last") +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CHANNELS_LAST(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be channels last") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) @@ -39,15 +34,13 @@ static bool initialized = false; static cudaDeviceProp props; -const std::unordered_set supported_c_values = { - 128, 256, 320, 384, 448, 512, 640, 768, 896, 960, 1024, 1280, 1344, - 1536, 1792, 1920, 2048, 2240, 2560, 2688, 3072, 3136, 3584, 4096}; +const std::unordered_set supported_c_values = {128, 256, 320, 384, 448, 512, 640, 768, + 896, 960, 1024, 1280, 1344, 1536, 1792, 1920, + 2048, 2240, 2560, 2688, 3072, 3136, 3584, 4096}; const std::unordered_set supported_groups_values = {16, 32}; -std::vector group_norm_fwd(torch::Tensor input, int groups, - torch::Tensor weight, - torch::Tensor bias, float eps, - int passes, bool with_swish = false) { +std::vector group_norm_fwd(torch::Tensor input, int groups, torch::Tensor weight, torch::Tensor bias, + float eps, int passes, bool with_swish = false) { if (!initialized) { CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0)); initialized = true; @@ -62,11 +55,9 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, int w = input.size(3); // Check kernel constraints - TORCH_CHECK(supported_groups_values.count(groups), - "`groups` of {16, 32} are only supported but ", groups, + TORCH_CHECK(supported_groups_values.count(groups), "`groups` of {16, 32} are only supported but ", groups, " is passed"); - TORCH_CHECK(supported_c_values.count(c), "`c` of ", c, - " is not included in supported_c_values"); + TORCH_CHECK(supported_c_values.count(c), "`c` of ", c, " is not included in supported_c_values"); // Allocate tensors auto options = at::TensorOptions(at::kCUDA); @@ -81,10 +72,8 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, params_fwd.y = reinterpret_cast(output.data_ptr()); params_fwd.sums = reinterpret_cast(sums_d.data_ptr()); params_fwd.x = const_cast(reinterpret_cast(input.data_ptr())); - params_fwd.gamma = - const_cast(reinterpret_cast(weight.data_ptr())); - params_fwd.beta = - const_cast(reinterpret_cast(bias.data_ptr())); + params_fwd.gamma = const_cast(reinterpret_cast(weight.data_ptr())); + params_fwd.beta = const_cast(reinterpret_cast(bias.data_ptr())); params_fwd.epsilon = eps; params_fwd.n = n; params_fwd.h = h; @@ -131,8 +120,7 @@ std::vector 
group_norm_fwd(torch::Tensor input, int groups, // Finalize the parameters. dim3 grid; if (passes == 1) { - group_norm_nhwc_fwd_one_pass_setup(params_fwd, barriers_elts, - red_buffer_elts, grid, props); + group_norm_nhwc_fwd_one_pass_setup(params_fwd, barriers_elts, red_buffer_elts, grid, props); } else { group_norm_nhwc_fwd_two_passes_setup(params_fwd, zeroed_red_buffer_elts); } @@ -144,8 +132,7 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, // Allocate the buffer if needed. auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt)); params_fwd.barriers = barriers.data_ptr(); - auto zeroed_red_buffer = - at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); + auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); params_fwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr(); if (passes == 1) { @@ -158,12 +145,9 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, return {output, sums_d}; } -std::vector group_norm_bwd(torch::Tensor grad_output, - torch::Tensor sums, - torch::Tensor input, int groups, - torch::Tensor weight, - torch::Tensor bias, float eps, - int passes, bool with_swish = false) { +std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tensor sums, torch::Tensor input, + int groups, torch::Tensor weight, torch::Tensor bias, float eps, int passes, + bool with_swish = false) { if (!initialized) { CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0)); initialized = true; @@ -178,11 +162,9 @@ std::vector group_norm_bwd(torch::Tensor grad_output, int w = input.size(3); // Check kernel constraints - TORCH_CHECK(supported_groups_values.count(groups), - "`groups` of {16, 32} are only supported but ", groups, + TORCH_CHECK(supported_groups_values.count(groups), "`groups` of {16, 32} are only supported but ", groups, " is passed"); - TORCH_CHECK(supported_c_values.count(c), "`c` of ", c, - " is not included in supported_c_values"); + TORCH_CHECK(supported_c_values.count(c), "`c` of ", c, " is not included in supported_c_values"); // Allocate tensors auto options = at::TensorOptions(at::kCUDA); @@ -199,16 +181,12 @@ std::vector group_norm_bwd(torch::Tensor grad_output, params_bwd.dx = reinterpret_cast(grad_input.data_ptr()); params_bwd.dgamma = reinterpret_cast(grad_weight.data_ptr()); params_bwd.dbeta = reinterpret_cast(grad_bias.data_ptr()); - params_bwd.sums = - const_cast(reinterpret_cast(sums.data_ptr())); - params_bwd.dy = - const_cast(reinterpret_cast(grad_output.data_ptr())); + params_bwd.sums = const_cast(reinterpret_cast(sums.data_ptr())); + params_bwd.dy = const_cast(reinterpret_cast(grad_output.data_ptr())); params_bwd.x = const_cast(reinterpret_cast(input.data_ptr())); ; - params_bwd.gamma = - const_cast(reinterpret_cast(weight.data_ptr())); - params_bwd.beta = - const_cast(reinterpret_cast(bias.data_ptr())); + params_bwd.gamma = const_cast(reinterpret_cast(weight.data_ptr())); + params_bwd.beta = const_cast(reinterpret_cast(bias.data_ptr())); ; params_bwd.epsilon = eps; params_bwd.n = n; @@ -256,9 +234,7 @@ std::vector group_norm_bwd(torch::Tensor grad_output, // Finalize the parameters. 
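  // As in the forward path above, passes == 1 selects the single-pass kernels: their setup
  // sizes the barriers and red_buffer, and (per macros.h) the launch goes through
  // cudaLaunchCooperativeKernel whenever more than one CTA covers a slice (grid.x > 1) so the
  // CTAs can synchronize on those barriers. passes == 2 instead only sizes the
  // zeroed_red_buffer allocated below, which the two-pass kernels accumulate into with
  // atomics.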
dim3 grid; if (passes == 1) { - group_norm_nhwc_bwd_one_pass_setup(params_bwd, barriers_elts, - red_buffer_elts, zeroed_red_buffer_elts, - grid, props); + group_norm_nhwc_bwd_one_pass_setup(params_bwd, barriers_elts, red_buffer_elts, zeroed_red_buffer_elts, grid, props); } else { group_norm_nhwc_bwd_two_passes_setup(params_bwd, zeroed_red_buffer_elts); } @@ -270,8 +246,7 @@ std::vector group_norm_bwd(torch::Tensor grad_output, // Allocate the buffer if needed. auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt)); params_bwd.barriers = barriers.data_ptr(); - auto zeroed_red_buffer = - at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); + auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); params_bwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr(); if (passes == 1) { diff --git a/apex/contrib/csrc/group_norm/macros.h b/apex/contrib/csrc/group_norm/macros.h old mode 100755 new mode 100644 index 674cc1909..4f6b5b84c --- a/apex/contrib/csrc/group_norm/macros.h +++ b/apex/contrib/csrc/group_norm/macros.h @@ -4,67 +4,53 @@ */ #define GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ -void group_norm_nhwc_ ## PASS_NAME ## _one_pass_ ## CHANNELS_PER_GROUP ## _ ## ACTS_PER_BLOCK ## _ ## Traits ## _run( \ - const Group_norm_nhwc_ ## PASS_NAME ## _params ¶ms, \ - const dim3 &grid, \ - cudaStream_t stream) - -#define GN_ONE_PASS_RUN_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ -GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) { \ - \ - auto kernel = group_norm_nhwc_ ## PASS_NAME ## _one_pass_kernel; \ - \ - const Group_norm_nhwc_ ## PASS_NAME ## _params *params_ = ¶ms; \ - if( grid.x > 1 ) { \ - CHECK_CUDA(cudaLaunchCooperativeKernel((const void*) kernel, \ - grid, \ - dim3(THREADS_PER_BLOCK), \ - (void**) ¶ms_, \ - 0, \ - stream)); \ - \ - } else { \ - CHECK_CUDA(cudaLaunchKernel((const void*) kernel, \ - grid, \ - dim3(THREADS_PER_BLOCK), \ - (void**) ¶ms_, \ - 0, \ - stream)); \ - \ - } \ - \ - CHECK_CUDA(cudaGetLastError()); \ -} + void group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##_run( \ + const Group_norm_nhwc_##PASS_NAME##_params ¶ms, const dim3 &grid, cudaStream_t stream) + +#define GN_ONE_PASS_RUN_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ + GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) { \ + auto kernel = \ + group_norm_nhwc_##PASS_NAME##_one_pass_kernel; \ + \ + const Group_norm_nhwc_##PASS_NAME##_params *params_ = ¶ms; \ + if (grid.x > 1) { \ + CHECK_CUDA(cudaLaunchCooperativeKernel((const void *)kernel, grid, dim3(THREADS_PER_BLOCK), (void **)¶ms_, \ + 0, stream)); \ + \ + } else { \ + CHECK_CUDA(cudaLaunchKernel((const void *)kernel, grid, dim3(THREADS_PER_BLOCK), (void **)¶ms_, 0, stream)); \ + } \ + \ + CHECK_CUDA(cudaGetLastError()); \ + } ////////////////////////////////////////////////////////////////////////////////////////////////// -#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ -int group_norm_nhwc_ ## PASS_NAME ## _one_pass_ ## CHANNELS_PER_GROUP ## _ ## ACTS_PER_BLOCK ## _ ## Traits ## _blocks_per_sm() - -#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ 
-GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) { \ - \ - auto kernel = group_norm_nhwc_ ## PASS_NAME ## _one_pass_kernel; \ - \ - int blocks_per_sm = 0; \ - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, \ - kernel, \ - THREADS_PER_BLOCK, \ - 0)); \ - \ - CHECK_CUDA(cudaGetLastError()); \ - return blocks_per_sm; \ -} +#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME) \ + int group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##_blocks_per_sm() + +#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ + GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) { \ + auto kernel = \ + group_norm_nhwc_##PASS_NAME##_one_pass_kernel; \ + \ + int blocks_per_sm = 0; \ + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, THREADS_PER_BLOCK, 0)); \ + \ + CHECK_CUDA(cudaGetLastError()); \ + return blocks_per_sm; \ + } ////////////////////////////////////////////////////////////////////////////////////////////////// #define GN_ONE_PASS_(FUNCTION, Traits, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ - FUNCTION(Traits, 512, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - FUNCTION(Traits, 256, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - FUNCTION(Traits, 128, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - FUNCTION(Traits, 64, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); + FUNCTION(Traits, 512, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ + FUNCTION(Traits, 256, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ + FUNCTION(Traits, 128, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ + FUNCTION(Traits, 64, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); -#define GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ +#define GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ @@ -75,7 +61,7 @@ GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GRO GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); -#define GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ +#define GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ @@ -86,7 +72,7 @@ GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GRO GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOFp32W, 
CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); -#define GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ +#define GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ @@ -97,112 +83,145 @@ GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GRO GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); -#define GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ - GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ - GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \ +#define GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ + GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME); \ + GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME); \ + GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME); \ + GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME); \ + GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME); \ + GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME); \ + GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME); \ + GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \ + PASS_NAME); \ GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); #define GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ -GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ -GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) + GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \ 
+ GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) #define GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \ -GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, fwd) + GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, fwd) #define GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \ -GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, bwd) + GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, bwd) #define GN_FWD_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \ -GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \ -GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) + GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \ + GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) //////////////////////////////////////////////////////////////////////////////////////////////////// -#define GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, PASS_NAME) \ - if( params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP && params.precision == PrecisionMode::PRECISION ) { \ - function = group_norm_nhwc_ ## PASS_NAME ## _one_pass_ ## CHANNELS_PER_GROUP ## _ ## ACTS_PER_BLOCK ## _ ## Traits ## FUNC_POSTFIX ; \ +#define GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK, \ + CHANNELS_PER_GROUP, PASS_NAME) \ + if (params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP && \ + params.precision == PrecisionMode::PRECISION) { \ + function = \ + group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##FUNC_POSTFIX; \ } else -#define GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, PASS_NAME, LIMIT_CPG) \ - if( params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP && params.precision == PrecisionMode::PRECISION && CHANNELS_PER_GROUP >= LIMIT_CPG ) { \ - function = group_norm_nhwc_ ## PASS_NAME ## _one_pass_ ## CHANNELS_PER_GROUP ## _ ## ACTS_PER_BLOCK ## _ ## Traits ## FUNC_POSTFIX ; \ +#define GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK, \ + CHANNELS_PER_GROUP, PASS_NAME, LIMIT_CPG) \ + if (params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP && \ + params.precision == PrecisionMode::PRECISION && CHANNELS_PER_GROUP >= LIMIT_CPG) { \ + function = \ + group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##FUNC_POSTFIX; \ } else -#define GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Traits, PRECISION, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, PASS_NAME) \ - GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 1024, 128, CHANNELS_PER_GROUP, PASS_NAME, 80) \ - GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 128, CHANNELS_PER_GROUP, PASS_NAME, 160) \ - GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 512, 512, CHANNELS_PER_GROUP, PASS_NAME) \ - GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 256, CHANNELS_PER_GROUP, PASS_NAME) \ - GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 128, 128, CHANNELS_PER_GROUP, PASS_NAME) \ - GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 0, 
64, CHANNELS_PER_GROUP, PASS_NAME) +#define GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Traits, PRECISION, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, PASS_NAME) \ + GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 1024, 128, CHANNELS_PER_GROUP, \ + PASS_NAME, 80) \ + GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 128, CHANNELS_PER_GROUP, \ + PASS_NAME, 160) \ + GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 512, 512, CHANNELS_PER_GROUP, PASS_NAME) \ + GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 256, CHANNELS_PER_GROUP, PASS_NAME) \ + GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 128, 128, CHANNELS_PER_GROUP, PASS_NAME) \ + GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 0, 64, CHANNELS_PER_GROUP, PASS_NAME) #define GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(CHANNELS_PER_GROUP, FUNC_POSTFIX, function) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, fwd) + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, fwd) #define 
GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(CHANNELS_PER_GROUP, FUNC_POSTFIX, function) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) \ -GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP, FUNC_POSTFIX, function, bwd) + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) \ + GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP, \ + FUNC_POSTFIX, function, bwd) //////////////////////////////////////////////////////////////////////////////////////////////////// -#define GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, PASS_NAME) \ -GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME) \ -GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME) +#define GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, PASS_NAME) \ + GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME) \ + GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME) -#define GN_FWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) \ -GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, fwd) +#define GN_FWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, fwd) -#define GN_BWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) \ -GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, bwd) +#define 
GN_BWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, bwd) //////////////////////////////////////////////////////////////////////////////////////////////////// -#define CALL_TWO_PASS_KERNEL(Kernel, Precision) \ - if( params.channels_per_block == 320 ) { \ +#define CALL_TWO_PASS_KERNEL(Kernel, Precision) \ + if (params.channels_per_block == 320) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 280 ) { \ + } else if (params.channels_per_block == 280) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 208 ) { \ + } else if (params.channels_per_block == 208) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 240 ) { \ + } else if (params.channels_per_block == 240) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 512 ) { \ + } else if (params.channels_per_block == 512) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 448 ) { \ + } else if (params.channels_per_block == 448) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 384 ) { \ + } else if (params.channels_per_block == 384) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 256 ) { \ + } else if (params.channels_per_block == 256) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 128 ) { \ - Kernel<<>>(params); \ - } else if( params.channels_per_block == 336 ) { \ + } else if (params.channels_per_block == 128) { \ + Kernel<<>>(params); \ + } else if (params.channels_per_block == 336) { \ Kernel<<>>(params); \ - } else if( params.channels_per_block == 392 ) { \ + } else if (params.channels_per_block == 392) { \ Kernel<<>>(params); \ - } else { \ - assert(false); \ + } else { \ + assert(false); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/group_norm/traits.h b/apex/contrib/csrc/group_norm/traits.h old mode 100755 new mode 100644 index c1eccc174..356c6be64 --- a/apex/contrib/csrc/group_norm/traits.h +++ b/apex/contrib/csrc/group_norm/traits.h @@ -4,67 +4,52 @@ */ #pragma once +#include +#include +#include #include #include #include #include -#include -#include -#include //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Fp32 -{ +struct Fp32 { // Type is float32_t using Type = float; // Doubled type using Type2 = float2; // Unpack input to accumulators type - static inline __device__ float2 unpack(const float2& f2) - { - return f2; - } + static inline __device__ float2 unpack(const float2& f2) { return f2; } // Pack the accumulators into outputs. - static inline __device__ float2 pack(const float2& f2) - { - return f2; - } + static inline __device__ float2 pack(const float2& f2) { return f2; } - static inline __device__ float2 zero() - { - return {0.f, 0.f}; - } + static inline __device__ float2 zero() { return {0.f, 0.f}; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Fp16 -{ +struct Fp16 { // Type is __half using Type = __half; // Doubled type using Type2 = __half2; // Unpack input to accumulators type - static inline __device__ float2 unpack(const __half2& h2) - { + static inline __device__ float2 unpack(const __half2& h2) { // FIXME(nkorobov): __half22float2 makes compilation error in container - return {__half2float(h2.x), - __half2float(h2.y)}; + return {__half2float(h2.x), __half2float(h2.y)}; } // Pack the accumulators into outputs. 
- static inline __device__ __half2 pack(const float2& f2) - { + static inline __device__ __half2 pack(const float2& f2) { // FIXME(nkorobov): __float22half2_rn makes compilation error in container return {__float2half_rn(f2.x), __float2half_rn(f2.y)}; } - static inline __device__ __half2 zero() - { + static inline __device__ __half2 zero() { uint32_t zero = 0; return *reinterpret_cast<__half2*>(&zero); } @@ -72,55 +57,46 @@ struct Fp16 //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Bf16 -{ +struct Bf16 { // Type is __nv_bfloat16 using Type = __nv_bfloat16; // Doubled type using Type2 = __nv_bfloat162; // Unpack input to accumulators type - static inline __device__ float2 unpack(const __nv_bfloat162& h2) - { + static inline __device__ float2 unpack(const __nv_bfloat162& h2) { // FIXME(nkorobov): __half22float2 makes compilation error in container - return {__bfloat162float(h2.x), - __bfloat162float(h2.y)}; + return {__bfloat162float(h2.x), __bfloat162float(h2.y)}; } // Pack the accumulators into outputs. - static inline __device__ __nv_bfloat162 pack(const float2& f2) - { + static inline __device__ __nv_bfloat162 pack(const float2& f2) { // FIXME(nkorobov): __float22bfloat162_rn makes compilation error in container - return {__float2bfloat16_rn(f2.x), __float2bfloat16_rn(f2.y)}; + return {__float2bfloat16_rn(f2.x), __float2bfloat16_rn(f2.y)}; } - static inline __device__ __nv_bfloat162 zero() - { + static inline __device__ __nv_bfloat162 zero() { uint32_t zero = 0; return *reinterpret_cast<__nv_bfloat162*>(&zero); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Fp32IOFp16W -{ +struct Fp32IOFp16W { // IO traits using IOTraits = Fp32; // Weigths traits using WTraits = Fp16; }; -struct Fp32IOBf16W -{ +struct Fp32IOBf16W { // IO traits using IOTraits = Fp32; // Weigths traits using WTraits = Bf16; }; - -struct Fp32IOFp32W -{ +struct Fp32IOFp32W { // IO traits using IOTraits = Fp32; // Weigths traits @@ -129,24 +105,21 @@ struct Fp32IOFp32W //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Fp16IOFp16W -{ +struct Fp16IOFp16W { // IO traits using IOTraits = Fp16; // Weigths traits using WTraits = Fp16; }; -struct Fp16IOBf16W -{ +struct Fp16IOBf16W { // IO traits using IOTraits = Fp16; // Weigths traits using WTraits = Bf16; }; -struct Fp16IOFp32W -{ +struct Fp16IOFp32W { // IO traits using IOTraits = Fp16; // Weigths traits @@ -154,24 +127,21 @@ struct Fp16IOFp32W }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Bf16IOFp16W -{ +struct Bf16IOFp16W { // IO traits using IOTraits = Bf16; // Weigths traits using WTraits = Fp16; }; -struct Bf16IOBf16W -{ +struct Bf16IOBf16W { // IO traits using IOTraits = Bf16; // Weigths traits using WTraits = Bf16; }; -struct Bf16IOFp32W -{ +struct Bf16IOFp32W { // IO traits using IOTraits = Bf16; // Weigths traits diff --git a/apex/contrib/csrc/group_norm_v2/gn.cpp b/apex/contrib/csrc/group_norm_v2/gn.cpp index 80d2f4c3b..61cfd2177 100644 --- a/apex/contrib/csrc/group_norm_v2/gn.cpp +++ b/apex/contrib/csrc/group_norm_v2/gn.cpp @@ -1,105 +1,112 @@ -#include -#include - #include "gn.hpp" +#include +#include namespace group_norm_v2 { -torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, bool silu, int num_groups, std::optional mean_var_out, int sm_margin) { - if (w.dtype() != b.dtype() || 
(mean_var_out.has_value() && mean_var_out->dtype() != torch::kFloat32)) { - throw std::invalid_argument("gn dtype mismatch"); - } - torch::Tensor out = torch::empty_like(x); - float *ptr_mean_var_out = mean_var_out.has_value() ? mean_var_out->data_ptr() : nullptr; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int device_id = at::cuda::getCurrentCUDAStream().device().index(); - group_norm_v2::Meta meta; - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { - group_norm_v2::gn_cuda( - (half *)out.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), - eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, - nullptr, nullptr, sm_margin, stream, device_id, &meta, true); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { - group_norm_v2::gn_cuda( - (__nv_bfloat16 *)out.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), - eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, - nullptr, nullptr, sm_margin, stream, device_id, &meta, true); - } else { - throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); - } - torch::Tensor red_buffer = torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - thread_local torch::Tensor barrier; - if (barrier.size(0) < meta.barrier_size) { - barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); - } - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { - group_norm_v2::gn_cuda( - (half *)out.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), - eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, - red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { - group_norm_v2::gn_cuda( - (__nv_bfloat16 *)out.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), - eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, - red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); - } else { - throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); - } - return out; +torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, bool silu, int num_groups, + std::optional mean_var_out, int sm_margin) { + if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != torch::kFloat32)) { + throw std::invalid_argument("gn dtype mismatch"); + } + torch::Tensor out = torch::empty_like(x); + float *ptr_mean_var_out = mean_var_out.has_value() ? 
mean_var_out->data_ptr() : nullptr; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + int device_id = at::cuda::getCurrentCUDAStream().device().index(); + group_norm_v2::Meta meta; + if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + group_norm_v2::gn_cuda((half *)out.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, + ptr_mean_var_out, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); + } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + group_norm_v2::gn_cuda((__nv_bfloat16 *)out.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), + (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), eps, silu, x.size(0), + x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, nullptr, + nullptr, sm_margin, stream, device_id, &meta, true); + } else { + throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); + } + torch::Tensor red_buffer = + torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + thread_local torch::Tensor barrier; + if (barrier.size(0) < meta.barrier_size) { + barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); + } + if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + group_norm_v2::gn_cuda((half *)out.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, + ptr_mean_var_out, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, + stream, device_id, nullptr, false); + } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + group_norm_v2::gn_cuda((__nv_bfloat16 *)out.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), + (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), eps, silu, x.size(0), + x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, + red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, + nullptr, false); + } else { + throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); + } + return out; } -auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch::Tensor b, torch::Tensor mean_var, float eps, bool silu, int num_groups, int sm_margin) { - if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != torch::kFloat32) { - throw std::invalid_argument("gn_bwd dtype mismatch"); - } - torch::Tensor grad_input = torch::empty_like(x); - torch::Tensor grad_weight = torch::empty_like(w); - torch::Tensor grad_bias = torch::empty_like(w); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int device_id = at::cuda::getCurrentCUDAStream().device().index(); - group_norm_v2::Meta meta; - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { - group_norm_v2::gn_bwd_cuda( - (half *)grad_input.data_ptr(), (half *)grad_weight.data_ptr(), (half *)grad_bias.data_ptr(), - (half *)grad_output.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), mean_var.data_ptr(), - eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, - nullptr, nullptr, sm_margin, stream, device_id, &meta, true); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { - group_norm_v2::gn_bwd_cuda( - (__nv_bfloat16 
*)grad_input.data_ptr(), (__nv_bfloat16 *)grad_weight.data_ptr(), (__nv_bfloat16 *)grad_bias.data_ptr(), - (__nv_bfloat16 *)grad_output.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), mean_var.data_ptr(), - eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, - nullptr, nullptr, sm_margin, stream, device_id, &meta, true); - } else { - throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); - } - torch::Tensor red_buffer = torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - thread_local torch::Tensor barrier; - if (barrier.size(0) < meta.barrier_size) { - barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); - } - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { - group_norm_v2::gn_bwd_cuda( - (half *)grad_input.data_ptr(), (half *)grad_weight.data_ptr(), (half *)grad_bias.data_ptr(), - (half *)grad_output.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), mean_var.data_ptr(), - eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, - red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { - group_norm_v2::gn_bwd_cuda( - (__nv_bfloat16 *)grad_input.data_ptr(), (__nv_bfloat16 *)grad_weight.data_ptr(), (__nv_bfloat16 *)grad_bias.data_ptr(), - (__nv_bfloat16 *)grad_output.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), mean_var.data_ptr(), - eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, - red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); - } else { - throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); - } - return std::make_tuple(grad_input, grad_weight, grad_bias); +auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch::Tensor b, torch::Tensor mean_var, + float eps, bool silu, int num_groups, int sm_margin) { + if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != torch::kFloat32) { + throw std::invalid_argument("gn_bwd dtype mismatch"); + } + torch::Tensor grad_input = torch::empty_like(x); + torch::Tensor grad_weight = torch::empty_like(w); + torch::Tensor grad_bias = torch::empty_like(w); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + int device_id = at::cuda::getCurrentCUDAStream().device().index(); + group_norm_v2::Meta meta; + if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + group_norm_v2::gn_bwd_cuda((half *)grad_input.data_ptr(), (half *)grad_weight.data_ptr(), + (half *)grad_bias.data_ptr(), (half *)grad_output.data_ptr(), (half *)x.data_ptr(), + (half *)w.data_ptr(), (half *)b.data_ptr(), mean_var.data_ptr(), eps, silu, + x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, nullptr, nullptr, + sm_margin, stream, device_id, &meta, true); + } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + group_norm_v2::gn_bwd_cuda((__nv_bfloat16 *)grad_input.data_ptr(), (__nv_bfloat16 *)grad_weight.data_ptr(), + (__nv_bfloat16 *)grad_bias.data_ptr(), (__nv_bfloat16 *)grad_output.data_ptr(), + (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), + (__nv_bfloat16 *)b.data_ptr(), 
mean_var.data_ptr(), eps, silu, x.size(0), + x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, nullptr, nullptr, sm_margin, + stream, device_id, &meta, true); + } else { + throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); + } + torch::Tensor red_buffer = + torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + thread_local torch::Tensor barrier; + if (barrier.size(0) < meta.barrier_size) { + barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); + } + if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + group_norm_v2::gn_bwd_cuda( + (half *)grad_input.data_ptr(), (half *)grad_weight.data_ptr(), (half *)grad_bias.data_ptr(), + (half *)grad_output.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), + mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, + red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); + } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + group_norm_v2::gn_bwd_cuda((__nv_bfloat16 *)grad_input.data_ptr(), (__nv_bfloat16 *)grad_weight.data_ptr(), + (__nv_bfloat16 *)grad_bias.data_ptr(), (__nv_bfloat16 *)grad_output.data_ptr(), + (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), + (__nv_bfloat16 *)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), + x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, red_buffer.data_ptr(), + barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); + } else { + throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); + } + return std::make_tuple(grad_input, grad_weight, grad_bias); } } // namespace group_norm_v2 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gn", &group_norm_v2::gn, py::arg("x"), py::arg("w"), py::arg("b"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("mean_var_out") = py::none(), py::arg("sm_margin") = 0, ""); - m.def("gn_bwd", &group_norm_v2::gn_bwd, py::arg("grad_output"), py::arg("x"), py::arg("w"), py::arg("b"), py::arg("mean_var"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("sm_margin") = 0, ""); + m.def("gn", &group_norm_v2::gn, py::arg("x"), py::arg("w"), py::arg("b"), py::arg("eps"), py::arg("silu"), + py::arg("num_groups"), py::arg("mean_var_out") = py::none(), py::arg("sm_margin") = 0, ""); + m.def("gn_bwd", &group_norm_v2::gn_bwd, py::arg("grad_output"), py::arg("x"), py::arg("w"), py::arg("b"), + py::arg("mean_var"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("sm_margin") = 0, ""); } diff --git a/apex/contrib/csrc/group_norm_v2/gn.hpp b/apex/contrib/csrc/group_norm_v2/gn.hpp index 7f1ec7c21..f8fa32d4e 100644 --- a/apex/contrib/csrc/group_norm_v2/gn.hpp +++ b/apex/contrib/csrc/group_norm_v2/gn.hpp @@ -1,28 +1,32 @@ #pragma once -#include #include +#include namespace group_norm_v2 { struct Meta { - int64_t red_buffer_size; - int64_t barrier_size; - int BLOCK_DIM_X; - int C_PER_BLOCK; - int ROWS_PER_BLOCK; - int VEC_ELEMS; - bool LOAD_TWICE; - int BLOCKS_PER_SM; - bool HARDWARE_CLUSTER; - int wgrad_sync_method; + int64_t red_buffer_size; + int64_t barrier_size; + int BLOCK_DIM_X; + int C_PER_BLOCK; + int ROWS_PER_BLOCK; + int VEC_ELEMS; + bool LOAD_TWICE; + int BLOCKS_PER_SM; + bool HARDWARE_CLUSTER; + int wgrad_sync_method; }; -template -void gn_cuda(T *out, T *x, T *w, T *b, float eps, bool 
silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only); +template +void gn_cuda(T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, + int channels_per_group, float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, + cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only); -template -void gn_bwd_cuda(T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only); +template +void gn_bwd_cuda(T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, + float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, + unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only); } // namespace group_norm_v2 diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda.cu index 05ed8ec54..5b77f365c 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda.cu @@ -1,47 +1,53 @@ -#include "gn.hpp" +#include +#include +#include #include #include #include -#include -#include -#include - -#include "gn_utils.hpp" +#include "gn.hpp" #include "gn_dispatch_hw_c.hpp" +#include "gn_utils.hpp" - -#define DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, NUM_GROUPS, SILU, ...) [&] { \ - if (num_groups == 16 && silu == true) { constexpr int NUM_GROUPS = 16; constexpr bool SILU = true; return __VA_ARGS__(); } \ - if (num_groups == 32 && silu == false) { constexpr int NUM_GROUPS = 32; constexpr bool SILU = false; return __VA_ARGS__(); } \ - throw std::invalid_argument("DISPATCH_NUM_GROUPS_AND_SILU " + std::to_string(num_groups) + " " + std::to_string(silu)); \ - }() +#define DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, NUM_GROUPS, SILU, ...) 
\ + [&] { \ + if (num_groups == 16 && silu == true) { \ + constexpr int NUM_GROUPS = 16; \ + constexpr bool SILU = true; \ + return __VA_ARGS__(); \ + } \ + if (num_groups == 32 && silu == false) { \ + constexpr int NUM_GROUPS = 32; \ + constexpr bool SILU = false; \ + return __VA_ARGS__(); \ + } \ + throw std::invalid_argument("DISPATCH_NUM_GROUPS_AND_SILU " + std::to_string(num_groups) + " " + \ + std::to_string(silu)); \ + }() namespace group_norm_v2 { -template +template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T)); -template +template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T)); -template +template void gn_cuda(GN_CUDA_HOST_PARAMS(T)) { - DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { - DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, [&] { - return gn_cuda_single_shape(GN_CUDA_HOST_ARGS); - }); - }); + DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { + DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, + [&] { return gn_cuda_single_shape(GN_CUDA_HOST_ARGS); }); + }); } -template +template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(T)) { - DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { - DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, [&] { - return gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_ARGS); - }); - }); + DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { + DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, + [&] { return gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_ARGS); }); + }); } template void gn_cuda(GN_CUDA_HOST_PARAMS(half)); diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_host_template.cuh b/apex/contrib/csrc/group_norm_v2/gn_cuda_host_template.cuh index 36c568e5f..3fb847593 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_host_template.cuh +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_host_template.cuh @@ -1,464 +1,522 @@ #pragma once +#include +#include +#include + #include #include -#include -#include -#include - -#include "gn_utils.hpp" #include "gn_cuda_kernel.cuh" - +#include "gn_utils.hpp" namespace group_norm_v2 { -#define DISPATCH_LOWER_BOUND_N(VALUE, CONST_NAME, ...) [&] { \ - if (VALUE >= 16) { constexpr int CONST_NAME = 16; return __VA_ARGS__(); } \ - if (VALUE >= 8) { constexpr int CONST_NAME = 8; return __VA_ARGS__(); } \ - if (VALUE >= 4) { constexpr int CONST_NAME = 4; return __VA_ARGS__(); } \ - if (VALUE >= 2) { constexpr int CONST_NAME = 2; return __VA_ARGS__(); } \ - if (VALUE >= 1) { constexpr int CONST_NAME = 1; return __VA_ARGS__(); } \ +#define DISPATCH_LOWER_BOUND_N(VALUE, CONST_NAME, ...) \ + [&] { \ + if (VALUE >= 16) { \ + constexpr int CONST_NAME = 16; \ + return __VA_ARGS__(); \ + } \ + if (VALUE >= 8) { \ + constexpr int CONST_NAME = 8; \ + return __VA_ARGS__(); \ + } \ + if (VALUE >= 4) { \ + constexpr int CONST_NAME = 4; \ + return __VA_ARGS__(); \ + } \ + if (VALUE >= 2) { \ + constexpr int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } \ + if (VALUE >= 1) { \ + constexpr int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } \ throw std::invalid_argument("DISPATCH_LOWER_BOUND_N " + std::to_string(VALUE)); \ - }() - -#define DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, ...) 
[&] { \ - if (runtime_cuda_arch == 1000 && sm_count >= 148) { constexpr int RUNTIME_CUDA_ARCH = 1000, LB_SM_COUNT = 148; return __VA_ARGS__(); } \ - throw std::invalid_argument("DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT " + std::to_string(runtime_cuda_arch) + " " + std::to_string(sm_count)); \ - }() - -#define DISPATCH_SM_MARGIN(VALUE, CONST_NAME, ...) [&] { \ - if (VALUE == 0) { constexpr int CONST_NAME = 0; return __VA_ARGS__(); } \ - if (VALUE == 32) { constexpr int CONST_NAME = 32; return __VA_ARGS__(); } \ + }() + +#define DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, ...) \ + [&] { \ + if (runtime_cuda_arch == 1000 && sm_count >= 148) { \ + constexpr int RUNTIME_CUDA_ARCH = 1000, LB_SM_COUNT = 148; \ + return __VA_ARGS__(); \ + } \ + throw std::invalid_argument("DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT " + std::to_string(runtime_cuda_arch) + \ + " " + std::to_string(sm_count)); \ + }() + +#define DISPATCH_SM_MARGIN(VALUE, CONST_NAME, ...) \ + [&] { \ + if (VALUE == 0) { \ + constexpr int CONST_NAME = 0; \ + return __VA_ARGS__(); \ + } \ + if (VALUE == 32) { \ + constexpr int CONST_NAME = 32; \ + return __VA_ARGS__(); \ + } \ throw std::invalid_argument("DISPATCH_SM_MARGIN " + std::to_string(VALUE)); \ - }() - + }() inline constexpr int get_max_cuda_arch() { - int cuda_arch_list[] = {__CUDA_ARCH_LIST__}; - int max_cuda_arch = -1; - for (int cuda_arch_item : cuda_arch_list) { - if (cuda_arch_item > max_cuda_arch) { - max_cuda_arch = cuda_arch_item; - } + int cuda_arch_list[] = {__CUDA_ARCH_LIST__}; + int max_cuda_arch = -1; + for (int cuda_arch_item : cuda_arch_list) { + if (cuda_arch_item > max_cuda_arch) { + max_cuda_arch = cuda_arch_item; } - return max_cuda_arch; + } + return max_cuda_arch; } -template +template constexpr auto compute_gn_params() { - constexpr int C = G * CPG; - - // Initialize each variable to comply with C++17 - int BLOCK_DIM_X = 0; - int C_PER_BLOCK = 0; - int ROWS_PER_BLOCK = 0; - bool LOAD_TWICE = false; - int BLOCKS_PER_SM = 0; - WgradSyncMethod wgrad_sync_method = WGRAD_SYNC_UNSPECIFIED; - - // There are two tiling strategies: - // - block sync: each block handles a whole group, i.e., a multiple of (G * HW) elements - // - virtual cluster sync: each virtual cluster handles a group - // Block sync can avoid cross-block synchronization latency, but it may cause low occupancy. - // Use block sync if the IO size is small, when latency rather than occupancy dominates the kernel running time. - - // Elements to load for forward pass is `x`, elements to load for backward pass are `x` and `grad_output`, hence there is a factor of (1 + BWD) - if (HW * CPG * (1 + BWD) * sizeof(T) <= 20480) { - // Strategy 1: block sync - C_PER_BLOCK = CPG; - ROWS_PER_BLOCK = HW; - BLOCK_DIM_X = lcm(32, C_PER_BLOCK); - while (BLOCK_DIM_X < 256) { - BLOCK_DIM_X *= 2; + constexpr int C = G * CPG; + + // Initialize each variable to comply with C++17 + int BLOCK_DIM_X = 0; + int C_PER_BLOCK = 0; + int ROWS_PER_BLOCK = 0; + bool LOAD_TWICE = false; + int BLOCKS_PER_SM = 0; + WgradSyncMethod wgrad_sync_method = WGRAD_SYNC_UNSPECIFIED; + + // There are two tiling strategies: + // - block sync: each block handles a whole group, i.e., a multiple of (G * HW) elements + // - virtual cluster sync: each virtual cluster handles a group + // Block sync can avoid cross-block synchronization latency, but it may cause low occupancy. 
+ // Use block sync if the IO size is small, when latency rather than occupancy dominates the kernel running time. + + // Elements to load for forward pass is `x`, elements to load for backward pass are `x` and `grad_output`, hence there + // is a factor of (1 + BWD) + if (HW * CPG * (1 + BWD) * sizeof(T) <= 20480) { + // Strategy 1: block sync + C_PER_BLOCK = CPG; + ROWS_PER_BLOCK = HW; + BLOCK_DIM_X = lcm(32, C_PER_BLOCK); + while (BLOCK_DIM_X < 256) { + BLOCK_DIM_X *= 2; + } + BLOCKS_PER_SM = 1; + // The size of registers is 65536 registers * 4 bytes per register. + // We have to leave some room for other variables and compiler optimizations, + // so we use 36000 as the threshold. + LOAD_TWICE = BLOCKS_PER_SM * ROWS_PER_BLOCK * C_PER_BLOCK * (1 + BWD) * sizeof(T) > 36000 * 4; + } else { + // Strategy 2: virtual cluster sync + // A virtual cluster is a group of blocks that are synchronized with each other. + // Each group, i.e., a multiple of (G * HW) elements, should be handled on the same virtual cluster. + // If the virtual cluster size is supported by the hardware, HARDWARE_CLUSTER is preferred; + // otherwise, cooperative groups are used (i.e., PERSISTENT kernels). + int c_per_cluster = lcm(128 / (int)sizeof(T), CPG); + + C_PER_BLOCK = c_per_cluster; + BLOCK_DIM_X = C_PER_BLOCK == 320 ? 320 : 480; + + // Maximum number of rows that should reside in registers + int register_max_rows = 36000 * 4 / (C_PER_BLOCK * (1 + BWD) * sizeof(T)); + + std::tuple best_candidate{}; + BLOCKS_PER_SM = 0; + ROWS_PER_BLOCK = 0; + for (int blocks_per_sm = 1; blocks_per_sm <= 3; blocks_per_sm++) { + for (int rows_per_block = HW; rows_per_block >= 1; rows_per_block /= 2) { + int virtual_cluster_size = (HW / rows_per_block) * (c_per_cluster / C_PER_BLOCK); + if (virtual_cluster_size > blocks_per_sm * (LB_SM_COUNT - SM_MARGIN)) { + continue; } - BLOCKS_PER_SM = 1; - // The size of registers is 65536 registers * 4 bytes per register. - // We have to leave some room for other variables and compiler optimizations, - // so we use 36000 as the threshold. - LOAD_TWICE = BLOCKS_PER_SM * ROWS_PER_BLOCK * C_PER_BLOCK * (1 + BWD) * sizeof(T) > 36000 * 4; - } else { - // Strategy 2: virtual cluster sync - // A virtual cluster is a group of blocks that are synchronized with each other. - // Each group, i.e., a multiple of (G * HW) elements, should be handled on the same virtual cluster. - // If the virtual cluster size is supported by the hardware, HARDWARE_CLUSTER is preferred; - // otherwise, cooperative groups are used (i.e., PERSISTENT kernels). - int c_per_cluster = lcm(128 / (int)sizeof(T), CPG); - - C_PER_BLOCK = c_per_cluster; - BLOCK_DIM_X = C_PER_BLOCK == 320 ? 
320 : 480; -\ - // Maximum number of rows that should reside in registers - int register_max_rows = 36000 * 4 / (C_PER_BLOCK * (1 + BWD) * sizeof(T)); - - std::tuple best_candidate{}; - BLOCKS_PER_SM = 0; - ROWS_PER_BLOCK = 0; - for (int blocks_per_sm = 1; blocks_per_sm <= 3; blocks_per_sm++) { - for (int rows_per_block = HW; rows_per_block >= 1; rows_per_block /= 2) { - int virtual_cluster_size = (HW / rows_per_block) * (c_per_cluster / C_PER_BLOCK); - if (virtual_cluster_size > blocks_per_sm * (LB_SM_COUNT - SM_MARGIN)) { - continue; - } - int num_clusters = blocks_per_sm * (LB_SM_COUNT - SM_MARGIN) / virtual_cluster_size; - int num_tasks = LB_N * (C / c_per_cluster); - int num_waves = up_div(num_tasks, num_clusters); - bool load_twice = rows_per_block > register_max_rows / blocks_per_sm; - - // Wave utilization: the percent of SMs that are used for each wave - // For example, SM_COUNT=100 and VIRTUAL_CLUSTER_SIZE=64, - // if BLOCKS_PER_SM=1, num_clusters=1, wave_util=64%; - // if BLOCKS_PER_SM=2, num_clusters=3, wave_util=96%. - // This helps select a good number of BLOCKS_PER_SM - int wave_util = 10000 * std::min(num_tasks, num_clusters) * virtual_cluster_size / (blocks_per_sm * (LB_SM_COUNT - SM_MARGIN)); - - decltype(best_candidate) candidate = { - true, - !load_twice, // Prefer no load twice - !(num_waves >= 2 && blocks_per_sm == 1), // When there are multiple waves, prefer multiple blocks per SM to ensure overlapping - -num_waves, // Prefer fewer waves - std::min(9000, wave_util), // Prefer high wave utilization - -blocks_per_sm, // Prefer fewer blocks per SM in order to reduce threads overhead - }; - if (candidate > best_candidate) { - // Assign each element respectively to comply with C++17 - std::get<0>(best_candidate) = std::get<0>(candidate); - std::get<1>(best_candidate) = std::get<1>(candidate); - std::get<2>(best_candidate) = std::get<2>(candidate); - std::get<3>(best_candidate) = std::get<3>(candidate); - std::get<4>(best_candidate) = std::get<4>(candidate); - std::get<5>(best_candidate) = std::get<5>(candidate); - static_assert(std::tuple_size::value == 6, "missing assignments"); - - BLOCKS_PER_SM = blocks_per_sm; - ROWS_PER_BLOCK = rows_per_block; - } - } + int num_clusters = blocks_per_sm * (LB_SM_COUNT - SM_MARGIN) / virtual_cluster_size; + int num_tasks = LB_N * (C / c_per_cluster); + int num_waves = up_div(num_tasks, num_clusters); + bool load_twice = rows_per_block > register_max_rows / blocks_per_sm; + + // Wave utilization: the percent of SMs that are used for each wave + // For example, SM_COUNT=100 and VIRTUAL_CLUSTER_SIZE=64, + // if BLOCKS_PER_SM=1, num_clusters=1, wave_util=64%; + // if BLOCKS_PER_SM=2, num_clusters=3, wave_util=96%. 
+ // This helps select a good number of BLOCKS_PER_SM + int wave_util = 10000 * std::min(num_tasks, num_clusters) * virtual_cluster_size / + (blocks_per_sm * (LB_SM_COUNT - SM_MARGIN)); + + decltype(best_candidate) candidate = { + true, + !load_twice, // Prefer no load twice + !(num_waves >= 2 && + blocks_per_sm == + 1), // When there are multiple waves, prefer multiple blocks per SM to ensure overlapping + -num_waves, // Prefer fewer waves + std::min(9000, wave_util), // Prefer high wave utilization + -blocks_per_sm, // Prefer fewer blocks per SM in order to reduce threads overhead + }; + if (candidate > best_candidate) { + // Assign each element respectively to comply with C++17 + std::get<0>(best_candidate) = std::get<0>(candidate); + std::get<1>(best_candidate) = std::get<1>(candidate); + std::get<2>(best_candidate) = std::get<2>(candidate); + std::get<3>(best_candidate) = std::get<3>(candidate); + std::get<4>(best_candidate) = std::get<4>(candidate); + std::get<5>(best_candidate) = std::get<5>(candidate); + static_assert(std::tuple_size::value == 6, "missing assignments"); + + BLOCKS_PER_SM = blocks_per_sm; + ROWS_PER_BLOCK = rows_per_block; } - - LOAD_TWICE = ROWS_PER_BLOCK > register_max_rows / BLOCKS_PER_SM; + } } - int c_per_cluster = lcm(CPG, C_PER_BLOCK); - int virtual_cluster_size = (c_per_cluster / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); - - // The occupancy is affected if cluster size is large. - // For example, on H100, when gridDim=128 and each block occupies the whole SM, - // if cluster is not used, all blocks can be active simultaneously. - // if cluster size is 16, not all blocks can be active simultaneously (which can be queried by cudaOccupancyMaxActiveClusters), - // so there will be two waves which impacts efficiency. - // When SM_MARGIN is set, no cluster should be used because other kernels may occupy a part of the cluster. - bool HARDWARE_CLUSTER = virtual_cluster_size <= 2 && virtual_cluster_size != 1 && SM_MARGIN == 0; - - int MAX_VEC_BYTES = 8; // Sometimes 4 or 16 is better, but there is no trivial way to select the best vectorization size. - int VEC_ELEMS = std::min(gcd(MAX_VEC_BYTES / (int)sizeof(T), C_PER_BLOCK), - gcd(MAX_VEC_BYTES / (int)sizeof(T), ROWS_PER_BLOCK * C_PER_BLOCK / BLOCK_DIM_X)); - - return std::make_tuple( - BLOCK_DIM_X, - C_PER_BLOCK, - ROWS_PER_BLOCK, - VEC_ELEMS, - LOAD_TWICE, - BLOCKS_PER_SM, - HARDWARE_CLUSTER, - wgrad_sync_method - ); + LOAD_TWICE = ROWS_PER_BLOCK > register_max_rows / BLOCKS_PER_SM; + } + + int c_per_cluster = lcm(CPG, C_PER_BLOCK); + int virtual_cluster_size = (c_per_cluster / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); + + // The occupancy is affected if cluster size is large. + // For example, on H100, when gridDim=128 and each block occupies the whole SM, + // if cluster is not used, all blocks can be active simultaneously. + // if cluster size is 16, not all blocks can be active simultaneously (which can be queried by + // cudaOccupancyMaxActiveClusters), + // so there will be two waves which impacts efficiency. + // When SM_MARGIN is set, no cluster should be used because other kernels may occupy a part of the cluster. + bool HARDWARE_CLUSTER = virtual_cluster_size <= 2 && virtual_cluster_size != 1 && SM_MARGIN == 0; + + int MAX_VEC_BYTES = + 8; // Sometimes 4 or 16 is better, but there is no trivial way to select the best vectorization size. 
+ int VEC_ELEMS = std::min(gcd(MAX_VEC_BYTES / (int)sizeof(T), C_PER_BLOCK), + gcd(MAX_VEC_BYTES / (int)sizeof(T), ROWS_PER_BLOCK * C_PER_BLOCK / BLOCK_DIM_X)); + + return std::make_tuple(BLOCK_DIM_X, C_PER_BLOCK, ROWS_PER_BLOCK, VEC_ELEMS, LOAD_TWICE, BLOCKS_PER_SM, + HARDWARE_CLUSTER, wgrad_sync_method); } // Save compilation time for unused CUDA_ARCHs -// For each template argument from DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT, the kernel is only compiled for the corresponding CUDA_ARCH -template +// For each template argument from DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT, the kernel is only compiled for the +// corresponding CUDA_ARCH +template class CompileCondition { -public: - __host__ __device__ static constexpr bool matches() { + public: + __host__ __device__ static constexpr bool matches() { #if defined(__CUDA_ARCH__) - return __CUDA_ARCH__ == EFFECTIVE_CUDA_ARCH; + return __CUDA_ARCH__ == EFFECTIVE_CUDA_ARCH; #else - return false; + return false; #endif - } + } }; -template +template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T)) { - if (out == x) { - throw std::invalid_argument("not __restrict__"); - } - - cudaDeviceProp const &deviceProp = get_device_prop(device_id); - int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10; - int sm_count = deviceProp.multiProcessorCount; - - DISPATCH_LOWER_BOUND_N(n, LB_N, [&] { - DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] { - DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] { - if (hw != HW) { - throw std::invalid_argument("wrong HW"); - } - if (num_groups * channels_per_group != C) { - throw std::invalid_argument("wrong C"); - } - if (num_groups != G) { - throw std::invalid_argument("wrong G"); - } - if (silu != SILU) { - throw std::invalid_argument("wrong SILU"); - } - if (n < LB_N) { - throw std::invalid_argument("wrong LB_N"); - } - if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) { - throw std::invalid_argument("wrong RUNTIME_CUDA_ARCH"); - } - if (sm_count < LB_SM_COUNT) { - throw std::invalid_argument("wrong LB_SM_COUNT"); - } - if (sm_margin != SM_MARGIN) { - throw std::invalid_argument("wrong SM_MARGIN"); - } - constexpr int EFFECTIVE_CUDA_ARCH = std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch()); // Assume the max CUDA_ARCH is used to generate PTX - - constexpr int CPG = C / G; - - constexpr auto params = compute_gn_params(); - constexpr int BLOCK_DIM_X = std::get<0>(params); - constexpr int C_PER_BLOCK = std::get<1>(params); - constexpr int ROWS_PER_BLOCK = std::get<2>(params); - constexpr int VEC_ELEMS = std::get<3>(params); - constexpr bool LOAD_TWICE = std::get<4>(params); - constexpr int BLOCKS_PER_SM = std::get<5>(params); - constexpr bool HARDWARE_CLUSTER = std::get<6>(params); - - constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK); - constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); - constexpr int NUM_VIRTUAL_CLUSTERS = ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE; - constexpr bool PERSISTENT = !HARDWARE_CLUSTER && VIRTUAL_CLUSTER_SIZE >= 2; // Only virtual cluster sync (not include hardware cluster sync) requires PERSISTENT kernels - - if (meta_ptr) { - constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; - meta_ptr->red_buffer_size = 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2; - meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS; - meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X; - meta_ptr->C_PER_BLOCK = C_PER_BLOCK; - meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK; - meta_ptr->VEC_ELEMS = VEC_ELEMS; - meta_ptr->LOAD_TWICE = LOAD_TWICE; - meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM; - meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER; - meta_ptr->wgrad_sync_method = (int)WGRAD_SYNC_UNSPECIFIED; - } - if (meta_only) { - return; - } + if (out == x) { + throw std::invalid_argument("not __restrict__"); + } + + cudaDeviceProp const &deviceProp = get_device_prop(device_id); + int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10; + int sm_count = deviceProp.multiProcessorCount; + + DISPATCH_LOWER_BOUND_N(n, LB_N, [&] { + DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] { + DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] { + if (hw != HW) { + throw std::invalid_argument("wrong HW"); + } + if (num_groups * channels_per_group != C) { + throw std::invalid_argument("wrong C"); + } + if (num_groups != G) { + throw std::invalid_argument("wrong G"); + } + if (silu != SILU) { + throw std::invalid_argument("wrong SILU"); + } + if (n < LB_N) { + throw std::invalid_argument("wrong LB_N"); + } + if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) { + throw std::invalid_argument("wrong RUNTIME_CUDA_ARCH"); + } + if (sm_count < LB_SM_COUNT) { + throw std::invalid_argument("wrong LB_SM_COUNT"); + } + if (sm_margin != SM_MARGIN) { + throw std::invalid_argument("wrong SM_MARGIN"); + } + constexpr int EFFECTIVE_CUDA_ARCH = + std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch()); // Assume the max CUDA_ARCH is used to generate PTX + + constexpr int CPG = C / G; + + constexpr auto params = compute_gn_params(); + constexpr int BLOCK_DIM_X = std::get<0>(params); + constexpr int C_PER_BLOCK = std::get<1>(params); + constexpr int ROWS_PER_BLOCK = std::get<2>(params); + constexpr int VEC_ELEMS = std::get<3>(params); + constexpr bool LOAD_TWICE = std::get<4>(params); + constexpr int BLOCKS_PER_SM = std::get<5>(params); + constexpr bool HARDWARE_CLUSTER = std::get<6>(params); + + constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK); + constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); + constexpr int NUM_VIRTUAL_CLUSTERS = ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE; + constexpr bool PERSISTENT = + !HARDWARE_CLUSTER && + VIRTUAL_CLUSTER_SIZE >= + 2; // Only virtual cluster sync (not include hardware cluster sync) requires PERSISTENT kernels + + if (meta_ptr) { + constexpr int MAX_NUM_GROUPS_PER_BLOCK = + C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; + meta_ptr->red_buffer_size = 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2; + meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS; + meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X; + meta_ptr->C_PER_BLOCK = C_PER_BLOCK; + meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK; + meta_ptr->VEC_ELEMS = VEC_ELEMS; + meta_ptr->LOAD_TWICE = LOAD_TWICE; + meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM; + meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER; + meta_ptr->wgrad_sync_method = (int)WGRAD_SYNC_UNSPECIFIED; + } + if (meta_only) { + return; + } - cudaLaunchConfig_t config = {0}; - config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT ? std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), 1); - config.blockDim = BLOCK_DIM_X; - config.stream = stream; - - cudaLaunchAttribute attribute[2]; - if constexpr (HARDWARE_CLUSTER) { - attribute[0].id = cudaLaunchAttributeClusterDimension; - attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; // Cluster size in X-dimension - attribute[0].val.clusterDim.y = 1; - attribute[0].val.clusterDim.z = 1; - config.attrs = attribute; - config.numAttrs++; - } - if constexpr (PERSISTENT) { - attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; - attribute[config.numAttrs].val.cooperative = 1; - config.attrs = attribute; - config.numAttrs++; - } + cudaLaunchConfig_t config = {0}; + config.gridDim = dim3( + VIRTUAL_CLUSTER_SIZE, + PERSISTENT ? std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), 1); + config.blockDim = BLOCK_DIM_X; + config.stream = stream; + + cudaLaunchAttribute attribute[2]; + if constexpr (HARDWARE_CLUSTER) { + attribute[0].id = cudaLaunchAttributeClusterDimension; + attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; // Cluster size in X-dimension + attribute[0].val.clusterDim.y = 1; + attribute[0].val.clusterDim.z = 1; + config.attrs = attribute; + config.numAttrs++; + } + if constexpr (PERSISTENT) { + attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; + attribute[config.numAttrs].val.cooperative = 1; + config.attrs = attribute; + config.numAttrs++; + } - auto kernel = &gn_cuda_kernel >; - if constexpr (HARDWARE_CLUSTER) { - if constexpr (VIRTUAL_CLUSTER_SIZE > 8) { - CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1)); - } - int max_cluster_size; - int active_clusters; - CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void *)kernel, &config)); - if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) { - attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; - CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void *)kernel, &config)); - } - if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && (!PERSISTENT || PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters)) { - attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; - } else { - // Fallback to cooperative groups because hardware cluster cannot be active simultaneously - constexpr bool HARDWARE_CLUSTER_NEW = false; - constexpr bool PERSISTENT_NEW = !HARDWARE_CLUSTER_NEW && VIRTUAL_CLUSTER_SIZE >= 2; - config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT_NEW ? 
std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), 1); - config.attrs = nullptr; - config.numAttrs = 0; - if constexpr (PERSISTENT_NEW) { - attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; - attribute[config.numAttrs].val.cooperative = 1; - config.attrs = attribute; - config.numAttrs++; - } - kernel = &gn_cuda_kernel >; - } + auto kernel = &gn_cuda_kernel >; + if constexpr (HARDWARE_CLUSTER) { + if constexpr (VIRTUAL_CLUSTER_SIZE > 8) { + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1)); + } + int max_cluster_size; + int active_clusters; + CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void *)kernel, &config)); + if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) { + attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; + CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void *)kernel, &config)); + } + if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && + (!PERSISTENT || PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters)) { + attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; + } else { + // Fallback to cooperative groups because hardware cluster cannot be active simultaneously + constexpr bool HARDWARE_CLUSTER_NEW = false; + constexpr bool PERSISTENT_NEW = !HARDWARE_CLUSTER_NEW && VIRTUAL_CLUSTER_SIZE >= 2; + config.gridDim = dim3( + VIRTUAL_CLUSTER_SIZE, + PERSISTENT_NEW ? std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), + 1); + config.attrs = nullptr; + config.numAttrs = 0; + if constexpr (PERSISTENT_NEW) { + attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; + attribute[config.numAttrs].val.cooperative = 1; + config.attrs = attribute; + config.numAttrs++; } - CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, out, x, w, b, eps, n, mean_var_out, red_buffer, barrier)); - }); - }); + kernel = &gn_cuda_kernel >; + } + } + CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, out, x, w, b, eps, n, mean_var_out, red_buffer, barrier)); + }); }); + }); } - -template +template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T)) { - if (grad_input == grad_output || grad_input == x) { - throw std::invalid_argument("not __restrict__"); - } - - cudaDeviceProp const &deviceProp = get_device_prop(device_id); - int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10; - int sm_count = deviceProp.multiProcessorCount; - - DISPATCH_LOWER_BOUND_N(n, LB_N, [&] { - DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] { - DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] { - if (hw != HW) { - throw std::invalid_argument("wrong HW"); - } - if (num_groups * channels_per_group != C) { - throw std::invalid_argument("wrong C"); - } - if (num_groups != G) { - throw std::invalid_argument("wrong G"); - } - if (silu != SILU) { - throw std::invalid_argument("wrong SILU"); - } - if (n < LB_N) { - throw std::invalid_argument("wrong LB_N"); - } - if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) { - throw std::invalid_argument("wrong RUNTIME_CUDA_ARCH"); - } - if (sm_count < LB_SM_COUNT) { - throw std::invalid_argument("wrong LB_SM_COUNT"); - } - if (sm_margin != SM_MARGIN) { - throw std::invalid_argument("wrong SM_MARGIN"); - } - constexpr int EFFECTIVE_CUDA_ARCH = std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch()); // Assume the max CUDA_ARCH is used to generate PTX - - constexpr bool REQUIRES_WGRAD = true; - constexpr int CPG = C / G; - - constexpr auto params = 
compute_gn_params(); - constexpr int BLOCK_DIM_X = std::get<0>(params); - constexpr int C_PER_BLOCK = std::get<1>(params); - constexpr int ROWS_PER_BLOCK = std::get<2>(params); - constexpr int VEC_ELEMS = std::get<3>(params); - constexpr bool LOAD_TWICE = std::get<4>(params); - constexpr int BLOCKS_PER_SM = std::get<5>(params); - constexpr bool HARDWARE_CLUSTER = std::get<6>(params); - constexpr WgradSyncMethod wgrad_sync_method_hint = std::get<7>(params); - - constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK); - constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); - constexpr int NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED = ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE; - - // PERSISTENT is required because wgrad reduction requires synchronization. - // TODO: specilize for the case that REQUIRES_WGRAD == false - constexpr bool PERSISTENT = true; - - // Determine whether to align each virtual cluster to a fixed range of channels - // If aligned, WGRAD_REUSE_SUM_SYNC_GROUP can be used, then less local wgrad memory is used (leave more room for compiler - // optimizations), and wgrad reduction is more efficient. - // However, aligning can cause low occupancy. - // There is a trade-off, and the condition to align is `NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / C_PER_CLUSTER)` - constexpr WgradSyncMethod wgrad_sync_method = - wgrad_sync_method_hint == WGRAD_SYNC_UNSPECIFIED ? - NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / C_PER_CLUSTER) || NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED % (C / C_PER_CLUSTER) == 0 ? - (HARDWARE_CLUSTER ? WGRAD_ARRIVE_AND_WAIT_GROUP : WGRAD_REUSE_SUM_SYNC_GROUP) : - WGRAD_REUSE_SUM_SYNC_GRID : - wgrad_sync_method_hint; - constexpr int NUM_VIRTUAL_CLUSTERS = - wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP ? - NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED / (C / C_PER_CLUSTER) * (C / C_PER_CLUSTER) : - NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED; - - if (meta_ptr) { - constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; - meta_ptr->red_buffer_size = 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2 + - std::max(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) * (HW / ROWS_PER_BLOCK) * C * 2; - meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS + C / C_PER_CLUSTER; - meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X; - meta_ptr->C_PER_BLOCK = C_PER_BLOCK; - meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK; - meta_ptr->VEC_ELEMS = VEC_ELEMS; - meta_ptr->LOAD_TWICE = LOAD_TWICE; - meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM; - meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER; - meta_ptr->wgrad_sync_method = (int)wgrad_sync_method; - } - if (meta_only) { - return; - } + if (grad_input == grad_output || grad_input == x) { + throw std::invalid_argument("not __restrict__"); + } + + cudaDeviceProp const &deviceProp = get_device_prop(device_id); + int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10; + int sm_count = deviceProp.multiProcessorCount; + + DISPATCH_LOWER_BOUND_N(n, LB_N, [&] { + DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] { + DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] { + if (hw != HW) { + throw std::invalid_argument("wrong HW"); + } + if (num_groups * channels_per_group != C) { + throw std::invalid_argument("wrong C"); + } + if (num_groups != G) { + throw std::invalid_argument("wrong G"); + } + if (silu != SILU) { + throw std::invalid_argument("wrong SILU"); + } + if (n < LB_N) { + throw std::invalid_argument("wrong LB_N"); + } + if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) { + throw std::invalid_argument("wrong RUNTIME_CUDA_ARCH"); + } + if (sm_count < LB_SM_COUNT) { + throw std::invalid_argument("wrong LB_SM_COUNT"); + } + if (sm_margin != SM_MARGIN) { + throw std::invalid_argument("wrong SM_MARGIN"); + } + constexpr int EFFECTIVE_CUDA_ARCH = + std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch()); // Assume the max CUDA_ARCH is used to generate PTX + + constexpr bool REQUIRES_WGRAD = true; + constexpr int CPG = C / G; + + constexpr auto params = compute_gn_params(); + constexpr int BLOCK_DIM_X = std::get<0>(params); + constexpr int C_PER_BLOCK = std::get<1>(params); + constexpr int ROWS_PER_BLOCK = std::get<2>(params); + constexpr int VEC_ELEMS = std::get<3>(params); + constexpr bool LOAD_TWICE = std::get<4>(params); + constexpr int BLOCKS_PER_SM = std::get<5>(params); + constexpr bool HARDWARE_CLUSTER = std::get<6>(params); + constexpr WgradSyncMethod wgrad_sync_method_hint = std::get<7>(params); + + constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK); + constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); + constexpr int NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED = + ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE; + + // PERSISTENT is required because wgrad reduction requires synchronization. + // TODO: specilize for the case that REQUIRES_WGRAD == false + constexpr bool PERSISTENT = true; + + // Determine whether to align each virtual cluster to a fixed range of channels + // If aligned, WGRAD_REUSE_SUM_SYNC_GROUP can be used, then less local wgrad memory is used (leave more room + // for compiler + // optimizations), and wgrad reduction is more efficient. + // However, aligning can cause low occupancy. 
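+        // (Illustration with hypothetical values, not taken from any real configuration: if
+        // NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED were 7 and C / C_PER_CLUSTER were 5, aligning would round the cluster
+        // count down to 7 / 5 * 5 = 5 and leave roughly 2/7 of the available blocks idle.)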
+ // There is a trade-off, and the condition to align is `NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / + // C_PER_CLUSTER)` + constexpr WgradSyncMethod wgrad_sync_method = + wgrad_sync_method_hint == WGRAD_SYNC_UNSPECIFIED + ? NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / C_PER_CLUSTER) || + NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED % (C / C_PER_CLUSTER) == 0 + ? (HARDWARE_CLUSTER ? WGRAD_ARRIVE_AND_WAIT_GROUP : WGRAD_REUSE_SUM_SYNC_GROUP) + : WGRAD_REUSE_SUM_SYNC_GRID + : wgrad_sync_method_hint; + constexpr int NUM_VIRTUAL_CLUSTERS = + wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP + ? NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED / (C / C_PER_CLUSTER) * (C / C_PER_CLUSTER) + : NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED; + + if (meta_ptr) { + constexpr int MAX_NUM_GROUPS_PER_BLOCK = + C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; + meta_ptr->red_buffer_size = + 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2 + + std::max(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) * (HW / ROWS_PER_BLOCK) * C * 2; + meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS + C / C_PER_CLUSTER; + meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X; + meta_ptr->C_PER_BLOCK = C_PER_BLOCK; + meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK; + meta_ptr->VEC_ELEMS = VEC_ELEMS; + meta_ptr->LOAD_TWICE = LOAD_TWICE; + meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM; + meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER; + meta_ptr->wgrad_sync_method = (int)wgrad_sync_method; + } + if (meta_only) { + return; + } - cudaLaunchConfig_t config = {0}; - config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT ? NUM_VIRTUAL_CLUSTERS : n * (C / C_PER_CLUSTER), 1); - config.blockDim = BLOCK_DIM_X; - config.stream = stream; - - cudaLaunchAttribute attribute[2]; - if constexpr (HARDWARE_CLUSTER) { - attribute[0].id = cudaLaunchAttributeClusterDimension; - attribute[0].val.clusterDim.x = 1; // Cluster size in X-dimension - attribute[0].val.clusterDim.y = 1; - attribute[0].val.clusterDim.z = 1; - config.attrs = attribute; - config.numAttrs++; - } - if constexpr (PERSISTENT) { - attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; - attribute[config.numAttrs].val.cooperative = 1; - config.attrs = attribute; - config.numAttrs++; - } + cudaLaunchConfig_t config = {0}; + config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT ? 
NUM_VIRTUAL_CLUSTERS : n * (C / C_PER_CLUSTER), 1); + config.blockDim = BLOCK_DIM_X; + config.stream = stream; + + cudaLaunchAttribute attribute[2]; + if constexpr (HARDWARE_CLUSTER) { + attribute[0].id = cudaLaunchAttributeClusterDimension; + attribute[0].val.clusterDim.x = 1; // Cluster size in X-dimension + attribute[0].val.clusterDim.y = 1; + attribute[0].val.clusterDim.z = 1; + config.attrs = attribute; + config.numAttrs++; + } + if constexpr (PERSISTENT) { + attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; + attribute[config.numAttrs].val.cooperative = 1; + config.attrs = attribute; + config.numAttrs++; + } - auto kernel = &gn_bwd_cuda_kernel >; - if constexpr (HARDWARE_CLUSTER) { - if constexpr (VIRTUAL_CLUSTER_SIZE > 8) { - CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1)); - } - int max_cluster_size; - int active_clusters; - CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void *)kernel, &config)); - if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) { - attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; - CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void *)kernel, &config)); - } - if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && (!PERSISTENT || PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters)) { - attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; - } else { - // Fallback to cooperative groups for dgrad computation because hardware cluster cannot be active simultaneously - attribute[0].val.clusterDim.x = 1; - kernel = &gn_bwd_cuda_kernel >; - } - } - CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, n, red_buffer, barrier)); - }); - }); + auto kernel = + &gn_bwd_cuda_kernel >; + if constexpr (HARDWARE_CLUSTER) { + if constexpr (VIRTUAL_CLUSTER_SIZE > 8) { + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1)); + } + int max_cluster_size; + int active_clusters; + CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void *)kernel, &config)); + if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) { + attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; + CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void *)kernel, &config)); + } + if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && + (!PERSISTENT || PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters)) { + attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; + } else { + // Fallback to cooperative groups for dgrad computation because hardware cluster cannot be active + // simultaneously + attribute[0].val.clusterDim.x = 1; + kernel = + &gn_bwd_cuda_kernel >; + } + } + CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, grad_input, grad_weight, grad_bias, grad_output, x, w, b, + mean_var, eps, n, red_buffer, barrier)); + }); }); + }); } -#define GN_CUDA_INST_DEFINE(HW, C) \ - template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(half)); \ - template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(half)); \ - template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(half)); \ - template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(half)); \ - template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); \ - template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); \ - template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); \ - 
template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); +#define GN_CUDA_INST_DEFINE(HW, C) \ + template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(half)); \ + template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(half)); \ + template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(half)); \ + template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(half)); \ + template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); \ + template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); \ + template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); \ + template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); } // namespace group_norm_v2 diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu index 5e5e721f8..9ff6d256a 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(1024, 1280) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu index 9793c8fad..2794ad378 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(1024, 1920) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu index d1b1b5d10..1a0e5bff6 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(1024, 320) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu index c79f5b72a..ce2bacd64 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(1024, 640) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu index 40f3cfb32..6e9b69f9c 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(1024, 960) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu index 5c3265ab8..ac11c7ca5 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(256, 1280) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu index eda3646c0..b3f144dc0 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu 
+++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(256, 1920) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu index 69c9de2de..323f11917 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(256, 2560) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu index 63e5de07a..5a5dc0e22 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(256, 640) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu index a24c88cf1..2f818e840 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(4096, 320) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu index 83c1e58bc..609cbc9c0 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(4096, 640) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu index f6fea9470..a052a2751 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(4096, 960) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu index e92de897e..7ed63f23e 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(64, 1280) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu index 37212deb6..208077687 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu @@ -1,6 +1,5 @@ #include "gn_cuda_host_template.cuh" - namespace group_norm_v2 { GN_CUDA_INST_DEFINE(64, 2560) diff --git a/apex/contrib/csrc/group_norm_v2/gn_cuda_kernel.cuh b/apex/contrib/csrc/group_norm_v2/gn_cuda_kernel.cuh index 46d36f957..f92780596 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_cuda_kernel.cuh +++ b/apex/contrib/csrc/group_norm_v2/gn_cuda_kernel.cuh @@ -4,1106 +4,1213 @@ #include "gn_utils.hpp" - namespace group_norm_v2 { namespace cg = cooperative_groups; -template -inline constexpr T up_div(T a, T b) { return (a + b - 1) / b; } +template +inline constexpr T up_div(T a, T b) { + return (a + b - 1) / b; +} -template -inline constexpr T round_up(T a, T b) { return 
up_div(a, b) * b; } +template +inline constexpr T round_up(T a, T b) { + return up_div(a, b) * b; +} inline constexpr unsigned round_up_pow2(unsigned x) { - int log = 0; - x--; - while (x) { - x /= 2; - log++; - } - return 1U << log; + int log = 0; + x--; + while (x) { + x /= 2; + log++; + } + return 1U << log; } -inline constexpr unsigned round_down_pow2(unsigned x) { - return round_up_pow2(x + 1) / 2; -} +inline constexpr unsigned round_down_pow2(unsigned x) { return round_up_pow2(x + 1) / 2; } -template +template inline constexpr T gcd(T a, T b) { - while (b != 0) { - int t = b; - b = a % b; - a = t; - } - return a; + while (b != 0) { + int t = b; + b = a % b; + a = t; + } + return a; } -template -inline constexpr T lcm(T a, T b) { return (a * b) / gcd(a, b); } +template +inline constexpr T lcm(T a, T b) { + return (a * b) / gcd(a, b); +} -template +template inline constexpr T relative_prime(T x, T min) { - int p = min; - while (gcd(p, x) != 1) { - p++; - } - return p; + int p = min; + while (gcd(p, x) != 1) { + p++; + } + return p; } -template +template inline constexpr T max_divisor(T x, T max) { - int p = max; - while (x % p != 0) { - p--; - } - return p; + int p = max; + while (x % p != 0) { + p--; + } + return p; } constexpr unsigned FINAL_MASK = 0xffffffff; -template +template __device__ void virtual_cluster_sync(unsigned int *barrier) { - if constexpr (VIRTUAL_CLUSTER_SIZE == 1) { - __syncthreads(); - } else if constexpr (HARDWARE_CLUSTER) { - cg::this_cluster().sync(); - } else { - static_assert(PERSISTENT, "potential deadlock"); - volatile unsigned int *arrived = &barrier[blockIdx.y]; - __syncthreads(); - if (threadIdx.x == 0) { - unsigned int expected = VIRTUAL_CLUSTER_SIZE; - bool gpu_master = blockIdx.x == 0; - unsigned int nb = 1; - if (gpu_master) { - nb = 0x80000000 - (expected - 1); - } - unsigned int oldArrive; - asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(oldArrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), "r"(nb) : "memory"); - unsigned int current_arrive; - do { - asm volatile("ld.acquire.gpu.u32 %0,[%1];" : "=r"(current_arrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived) : "memory"); - } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive)); - } - __syncthreads(); + if constexpr (VIRTUAL_CLUSTER_SIZE == 1) { + __syncthreads(); + } else if constexpr (HARDWARE_CLUSTER) { + cg::this_cluster().sync(); + } else { + static_assert(PERSISTENT, "potential deadlock"); + volatile unsigned int *arrived = &barrier[blockIdx.y]; + __syncthreads(); + if (threadIdx.x == 0) { + unsigned int expected = VIRTUAL_CLUSTER_SIZE; + bool gpu_master = blockIdx.x == 0; + unsigned int nb = 1; + if (gpu_master) { + nb = 0x80000000 - (expected - 1); + } + unsigned int oldArrive; + asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" + : "=r"(oldArrive) + : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived), "r"(nb) + : "memory"); + unsigned int current_arrive; + do { + asm volatile("ld.acquire.gpu.u32 %0,[%1];" + : "=r"(current_arrive) + : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived) + : "memory"); + } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive)); } + __syncthreads(); + } } -template +template __device__ unsigned int group_barrier_arrive(unsigned int *barrier, bool gpu_master) { - static_assert(PERSISTENT, "potential deadlock"); - volatile unsigned int *arrived = &barrier[0]; - __syncthreads(); - if (threadIdx.x == 0) { - unsigned int expected = NUM_BLOCKS; - unsigned int nb = 1; - if (gpu_master) { - nb = 
0x80000000 - (expected - 1); - } - unsigned int oldArrive; - asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(oldArrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), "r"(nb) : "memory"); - return oldArrive; - } else { - return 0; + static_assert(PERSISTENT, "potential deadlock"); + volatile unsigned int *arrived = &barrier[0]; + __syncthreads(); + if (threadIdx.x == 0) { + unsigned int expected = NUM_BLOCKS; + unsigned int nb = 1; + if (gpu_master) { + nb = 0x80000000 - (expected - 1); } + unsigned int oldArrive; + asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" + : "=r"(oldArrive) + : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived), "r"(nb) + : "memory"); + return oldArrive; + } else { + return 0; + } } __device__ inline void group_barrier_wait(unsigned int *barrier, unsigned int oldArrive) { - volatile unsigned int *arrived = &barrier[0]; - if (threadIdx.x == 0) { - unsigned int current_arrive; - do { - asm volatile("ld.acquire.gpu.u32 %0,[%1];" : "=r"(current_arrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived) : "memory"); - } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive)); - } - __syncthreads(); + volatile unsigned int *arrived = &barrier[0]; + if (threadIdx.x == 0) { + unsigned int current_arrive; + do { + asm volatile("ld.acquire.gpu.u32 %0,[%1];" + : "=r"(current_arrive) + : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived) + : "memory"); + } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive)); + } + __syncthreads(); } // Calculate `n` (batch id) and `c` (channel range id) for each loop -template +template class NCScheduler; -template +template class NCScheduler { -public: - __device__ NCScheduler(int64_t n) { - nc_loop_ = blockIdx.y; - at_end_ = nc_loop_ >= n * (C / C_PER_CLUSTER); - } - __device__ auto get_nc() { - int64_t n_loop = nc_loop_ / (C / C_PER_CLUSTER); - int c_loop = nc_loop_ % (C / C_PER_CLUSTER); - return std::make_tuple(n_loop, c_loop); + public: + __device__ NCScheduler(int64_t n) { + nc_loop_ = blockIdx.y; + at_end_ = nc_loop_ >= n * (C / C_PER_CLUSTER); + } + __device__ auto get_nc() { + int64_t n_loop = nc_loop_ / (C / C_PER_CLUSTER); + int c_loop = nc_loop_ % (C / C_PER_CLUSTER); + return std::make_tuple(n_loop, c_loop); + } + __device__ void next(int64_t n) { + if constexpr (PERSISTENT) { + nc_loop_ += NUM_VIRTUAL_CLUSTERS; + at_end_ = nc_loop_ >= n * (C / C_PER_CLUSTER); } - __device__ void next(int64_t n) { - if constexpr (PERSISTENT) { - nc_loop_ += NUM_VIRTUAL_CLUSTERS; - at_end_ = nc_loop_ >= n * (C / C_PER_CLUSTER); - } - } - __device__ bool at_end(int64_t n) { - return !PERSISTENT || at_end_; - } -private: - int64_t nc_loop_; - bool at_end_; + } + __device__ bool at_end(int64_t n) { return !PERSISTENT || at_end_; } + + private: + int64_t nc_loop_; + bool at_end_; }; -template +template class NCScheduler { -public: - __device__ NCScheduler(int64_t n) { - n_loop_ = blockIdx.y / (C / C_PER_CLUSTER); - c_loop_ = blockIdx.y % (C / C_PER_CLUSTER); + public: + __device__ NCScheduler(int64_t n) { + n_loop_ = blockIdx.y / (C / C_PER_CLUSTER); + c_loop_ = blockIdx.y % (C / C_PER_CLUSTER); + } + __device__ auto get_nc() { return std::make_tuple(n_loop_, c_loop_); } + __device__ void next(int64_t n) { + if constexpr (PERSISTENT) { + n_loop_ += NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER); } - __device__ auto get_nc() { - return std::make_tuple(n_loop_, c_loop_); - } - __device__ void next(int64_t n) { - if constexpr (PERSISTENT) { - n_loop_ += NUM_VIRTUAL_CLUSTERS / (C / 
C_PER_CLUSTER); - } - } - __device__ bool at_end(int64_t n) { - return !PERSISTENT || n_loop_ >= n; - } -private: - int64_t n_loop_; - int c_loop_; + } + __device__ bool at_end(int64_t n) { return !PERSISTENT || n_loop_ >= n; } + + private: + int64_t n_loop_; + int c_loop_; }; class CompileConditionAlwaysTrue { -public: - __device__ static constexpr bool matches() { - return true; - } + public: + __device__ static constexpr bool matches() { return true; } }; -template -__global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_cuda_kernel(T *__restrict__ out, T const *__restrict__ x, T const *__restrict__ w, T const *__restrict__ b, float eps, int64_t n, float *__restrict__ mean_var_out, float *__restrict__ red_buffer, unsigned *__restrict__ barrier) { - // Procedure Overview - // 1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE) - // 2. Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is used) - // 3. Group sum: read from gmem, write mean&var to smem - // 4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem - - static_assert(BLOCK_DIM_X % 32 == 0, "warp shuffle error"); - - constexpr int C = G * CPG; - static_assert(C % C_PER_CLUSTER == 0, "cannot divide channels into clusters"); - static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, "cannot divide a cluster into blocks"); - static_assert(C_PER_CLUSTER % CPG == 0, "no reduce between clusters, would produce incorrect results"); - static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK), "inefficient configuration, please reduce C_PER_CLUSTER"); - - static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, "cannot divide tile into threads"); - struct alignas(VEC_ELEMS * sizeof(T)) U { - T data[VEC_ELEMS]; - }; - - auto compute_mean_var = [&](float2 sum) { - float mean = sum.x / (HW * CPG); - float var = std::max(0.f, sum.y / (HW * CPG) - mean * mean); - return float2{mean, var}; - }; - - static_assert(HW % ROWS_PER_BLOCK == 0, "HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis"); - constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; - constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); - constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK; - constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK; - int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x; - int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x; - - if constexpr (CompileCondition::matches()) { - int step = 0; - constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0; - NCScheduler nc_scheduler(n); - while (true) { // TODO: unroll the loop - if constexpr (PERSISTENT) { - if (nc_scheduler.at_end(n)) { - break; - } +template +__global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_cuda_kernel( + T *__restrict__ out, T const *__restrict__ x, T const *__restrict__ w, T const *__restrict__ b, float eps, + int64_t n, float *__restrict__ mean_var_out, float *__restrict__ red_buffer, unsigned *__restrict__ barrier) { + // Procedure Overview + // 1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE) + // 2. 
Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is + // used) + // 3. Group sum: read from gmem, write mean&var to smem + // 4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem + + static_assert(BLOCK_DIM_X % 32 == 0, "warp shuffle error"); + + constexpr int C = G * CPG; + static_assert(C % C_PER_CLUSTER == 0, "cannot divide channels into clusters"); + static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, "cannot divide a cluster into blocks"); + static_assert(C_PER_CLUSTER % CPG == 0, "no reduce between clusters, would produce incorrect results"); + static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK), + "inefficient configuration, please reduce C_PER_CLUSTER"); + + static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, "cannot divide tile into threads"); + struct alignas(VEC_ELEMS * sizeof(T)) U { + T data[VEC_ELEMS]; + }; + + auto compute_mean_var = [&](float2 sum) { + float mean = sum.x / (HW * CPG); + float var = std::max(0.f, sum.y / (HW * CPG) - mean * mean); + return float2{mean, var}; + }; + + static_assert(HW % ROWS_PER_BLOCK == 0, + "HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis"); + constexpr int MAX_NUM_GROUPS_PER_BLOCK = + C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; + constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); + constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK; + constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK; + int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x; + int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x; + + if constexpr (CompileCondition::matches()) { + int step = 0; + constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0; + NCScheduler nc_scheduler(n); + while (true) { // TODO: unroll the loop + if constexpr (PERSISTENT) { + if (nc_scheduler.at_end(n)) { + break; + } + } + auto [n_loop, c_loop] = nc_scheduler.get_nc(); + if constexpr (PERSISTENT) { + nc_scheduler.next(n); + } + static_assert(C_PER_BLOCK % VEC_ELEMS == 0, "cannot vectorize"); + static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0, + "each block should load one or more C_PER_BLOCK at once"); + constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK; + static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, "cannot determine the IO times per batch"); + int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; + int block_group_start = block_channel_start / CPG; + int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS; + U frag[ROWS_PER_BLOCK / ROWS_PER_IO]; + + // GCD_VEC_CPG is an important constant that determines how many channels can be merged in reduction computation + // For example, VEC_ELEMS=4 and CPG=10, then GCD_VEC_CPG=2, + // so we need to store only 2 sums on each thread, and compute only 2 mean&var for each thread. + constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG); + + // If each block handles only one group, run warpReduce and store the sum to `sum_per_channel_single_group`; + // otherwise store (VEC_ELEMS / GCD_VEC_CPG) sums to `sum_per_channel_multi_group`, where `relative_prime` is used + // for swizzle. 
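+      // (Added note, inferred rather than stated in the original comment: relative_prime(128 / sizeof(float2),
+      // ROWS_PER_IO) is the smallest value >= ROWS_PER_IO that is coprime with 16, the number of float2 values in
+      // one 128-byte span of shared-memory banks, so the row stride of `sum_per_channel_multi_group` is presumably
+      // chosen to spread rows across banks and avoid conflicts.)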
+ constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0; + [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32]; + [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime( + 128 / (int)sizeof(float2), ROWS_PER_IO)]; + + if constexpr (LOAD_TWICE) { + float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{}; + for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { + int64_t input_idx = + n_loop * HW * C + + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + + thread_channel_start; + U val = *reinterpret_cast(&x[input_idx]); + for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { + float2 sum = frag_sum_per_channel[i]; + for (int k = 0; k < GCD_VEC_CPG; k++) { + sum.x += (float)val.data[i * GCD_VEC_CPG + k]; + sum.y += (float)val.data[i * GCD_VEC_CPG + k] * (float)val.data[i * GCD_VEC_CPG + k]; } - auto [n_loop, c_loop] = nc_scheduler.get_nc(); - if constexpr (PERSISTENT) { - nc_scheduler.next(n); + frag_sum_per_channel[i] = sum; + } + } + for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { + if constexpr (SINGLE_GROUP_PER_BLOCK) { + for (int mask = 16; mask > 0; mask >>= 1) { + frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32); + frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32); } - static_assert(C_PER_BLOCK % VEC_ELEMS == 0, "cannot vectorize"); - static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0, "each block should load one or more C_PER_BLOCK at once"); - constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK; - static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, "cannot determine the IO times per batch"); - int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; - int block_group_start = block_channel_start / CPG; - int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS; - U frag[ROWS_PER_BLOCK / ROWS_PER_IO]; - - // GCD_VEC_CPG is an important constant that determines how many channels can be merged in reduction computation - // For example, VEC_ELEMS=4 and CPG=10, then GCD_VEC_CPG=2, - // so we need to store only 2 sums on each thread, and compute only 2 mean&var for each thread. - constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG); - - // If each block handles only one group, run warpReduce and store the sum to `sum_per_channel_single_group`; - // otherwise store (VEC_ELEMS / GCD_VEC_CPG) sums to `sum_per_channel_multi_group`, where `relative_prime` is used for swizzle. 
- constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0; - [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32]; - [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime(128 / (int)sizeof(float2), ROWS_PER_IO)]; - - if constexpr (LOAD_TWICE) { - float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{}; - for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { - int64_t input_idx = n_loop * HW * C + - (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + - thread_channel_start; - U val = *reinterpret_cast(&x[input_idx]); - for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { - float2 sum = frag_sum_per_channel[i]; - for (int k = 0; k < GCD_VEC_CPG; k++) { - sum.x += (float)val.data[i * GCD_VEC_CPG + k]; - sum.y += (float)val.data[i * GCD_VEC_CPG + k] * (float)val.data[i * GCD_VEC_CPG + k]; - } - frag_sum_per_channel[i] = sum; - } - } - for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { - if constexpr (SINGLE_GROUP_PER_BLOCK) { - for (int mask = 16; mask > 0; mask >>= 1) { - frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32); - frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32); - } - static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); - if (threadIdx.x % 32 == 0) { - sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i]; - } - } else { - sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)][threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i]; - } - } - __syncthreads(); - } else { - for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { - int64_t input_idx = n_loop * HW * C + - (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + - thread_channel_start; - frag[j] = *reinterpret_cast(&x[input_idx]); - } - - for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { - float2 sum = {0.f, 0.f}; - for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { - for (int k = 0; k < GCD_VEC_CPG; k++) { - sum.x += (float)frag[j].data[i * GCD_VEC_CPG + k]; - sum.y += (float)frag[j].data[i * GCD_VEC_CPG + k] * (float)frag[j].data[i * GCD_VEC_CPG + k]; - } - } - if constexpr (SINGLE_GROUP_PER_BLOCK) { - for (int mask = 16; mask > 0; mask >>= 1) { - sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32); - sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32); - } - static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); - if (threadIdx.x % 32 == 0) { - sum_per_channel_single_group[threadIdx.x / 32] = sum; - } - } else { - sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)][threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum; - } - } - __syncthreads(); + static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); + if (threadIdx.x % 32 == 0) { + sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i]; } + } else { + sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)] + [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i]; + } + } + __syncthreads(); + } else { + for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { + int64_t input_idx = + n_loop * HW * C + + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / 
VEC_ELEMS)) * C + + thread_channel_start; + frag[j] = *reinterpret_cast(&x[input_idx]); + } - U uw = *reinterpret_cast(&w[thread_channel_start]); - U ub = *reinterpret_cast(&b[thread_channel_start]); - - // Three cases for the red_buffer: - // - Block sync (VIRTUAL_CLUSTER_SIZE=1): use shared memory - // - Virtual cluster sync with HARDWARE_CLUSTER: use distributed shared memory - // - Virtual cluster sync without HARDWARE_CLUSTER: use global memory, i.e., `red_buffer` - constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1; - - // Specialize for the case that each group is handled by only one block - // For common cases, blockSum produces partial sum and stores it to the red_buffer, and groupSum produces mean&var - // For the special case, blockSum produces mean&var directly - constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER = VIRTUAL_CLUSTER_SIZE == 1 && MAX_NUM_GROUPS_PER_BLOCK == 1; // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented - - [[maybe_unused]] __align__(16) __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)]; - - // Block sum - if constexpr (SINGLE_GROUP_PER_BLOCK) { - // block reduce - if (threadIdx.x < 32) { - float2 sum_local_group = threadIdx.x < BLOCK_DIM_X / 32 ? sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f}; - constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32); - for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) { - sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); - sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); - } - if (threadIdx.x == 0) { - if constexpr (USE_SHARED_RED_BUFFER) { - if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { - shared_red_buffer[0] = compute_mean_var(sum_local_group); - } else { - shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group; - } - } else { - *reinterpret_cast(&red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + - virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + - // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + - virtual_block_idx_y) * 2]) = sum_local_group; - } - } - } - } else { - // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) - constexpr int THREADS_PER_GROUP = std::min(std::min(32U, - round_up_pow2(ROWS_PER_IO)), - round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); - static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); - float2 sum_local_group = {0.f, 0.f}; - if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { - int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; - // TODO: map threads to both the CPG loop and the ROWS loop - for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) { - int c = local_group_idx * CPG + local_c_loop; - if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) { - for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO; src_thread_tile_y += THREADS_PER_GROUP) { - int channel_idx = (c - block_channel_start) / GCD_VEC_CPG; - channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) + channel_idx / (VEC_ELEMS / GCD_VEC_CPG); - sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x; - sum_local_group.y += 
sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y; - } - } - } - } - static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); - for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { - sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); - sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); - } - if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { - if constexpr (USE_SHARED_RED_BUFFER) { - static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, "no distributed shared memory"); - if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { - shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group); - } else { - shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group; - } - } else { - *reinterpret_cast(&red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + - virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + - (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + - virtual_block_idx_y) * 2]) = sum_local_group; - } - } + for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { + float2 sum = {0.f, 0.f}; + for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { + for (int k = 0; k < GCD_VEC_CPG; k++) { + sum.x += (float)frag[j].data[i * GCD_VEC_CPG + k]; + sum.y += (float)frag[j].data[i * GCD_VEC_CPG + k] * (float)frag[j].data[i * GCD_VEC_CPG + k]; } - - virtual_cluster_sync(barrier); - - // Group sum - __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK]; - if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { - // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) - constexpr int THREADS_PER_GROUP = std::min(std::min(32U, - round_up_pow2(virtual_cluster_dim_y)), - round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); - static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); - float2 sum_global_group = {0.f, 0.f}; - if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { - if constexpr (C_PER_BLOCK % CPG == 0) { - // Special case: no cross-virtual_cluster_dim_x reduction - float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)]; - for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { - float2 val; - if constexpr (USE_SHARED_RED_BUFFER) { - if constexpr (VIRTUAL_CLUSTER_SIZE == 1) { - val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; - } else { - static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); - float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i * virtual_cluster_dim_x + virtual_block_idx_x); - val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; - } - } else { - val = *reinterpret_cast(&red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + - virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + - (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + - i) * 2]); - } - buffer[i / THREADS_PER_GROUP] = val; - } - for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { - float2 val = buffer[i / THREADS_PER_GROUP]; - sum_global_group.x += val.x; - sum_global_group.y += val.y; - } - } else { - // Common case: 
cross-virtual_cluster_dim_x reduction - int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; - for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) { - int src_virtual_block_idx_x = i % virtual_cluster_dim_x; - int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; - int src_block_group_start = src_block_channel_start / CPG; - int relative_group_idx = local_group_idx - src_block_group_start; - if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) { - float2 val; - if constexpr (USE_SHARED_RED_BUFFER) { - static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); - static_assert(VIRTUAL_CLUSTER_SIZE != 1, "layout error: should not add (step * MAX_NUM_GROUPS_PER_BLOCK)"); - float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i); - val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx]; - } else { - val = *reinterpret_cast(&red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + - src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + - relative_group_idx * virtual_cluster_dim_y + - i / virtual_cluster_dim_x) * 2]); - } - sum_global_group.x += val.x; - sum_global_group.y += val.y; - } - } - } - } - if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { - // Need cluster sync after distributed shared memory access, otherwise behavior is undefined - if constexpr (PERSISTENT) { - if (nc_scheduler.at_end(n)) { - cg::this_cluster().barrier_arrive(); - } - } else { - cg::this_cluster().barrier_arrive(); - } - } - static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); - for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { - sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32); - sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32); - } - if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { - mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group); - } - __syncthreads(); + } + if constexpr (SINGLE_GROUP_PER_BLOCK) { + for (int mask = 16; mask > 0; mask >>= 1) { + sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32); + sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32); } - - auto get_mean_var = [&](int relative_group_idx) { - return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 
shared_red_buffer[relative_group_idx] : mean_var[relative_group_idx]; - }; - - if (mean_var_out) { - static_assert(MAX_NUM_GROUPS_PER_BLOCK <= BLOCK_DIM_X, "need loop"); - if (virtual_block_idx_y == 0 && threadIdx.x < MAX_NUM_GROUPS_PER_BLOCK) { - int g = block_group_start + threadIdx.x; - if (C_PER_BLOCK % CPG == 0 || g < G) { - *reinterpret_cast(&mean_var_out[(n_loop * G + g) * 2]) = get_mean_var(threadIdx.x); - } - } + static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); + if (threadIdx.x % 32 == 0) { + sum_per_channel_single_group[threadIdx.x / 32] = sum; } - - float frag_mean[VEC_ELEMS / GCD_VEC_CPG]; - float frag_var[VEC_ELEMS / GCD_VEC_CPG]; - for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) { - frag_mean[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x; - frag_var[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y; + } else { + sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)] + [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum; + } + } + __syncthreads(); + } + + U uw = *reinterpret_cast(&w[thread_channel_start]); + U ub = *reinterpret_cast(&b[thread_channel_start]); + + // Three cases for the red_buffer: + // - Block sync (VIRTUAL_CLUSTER_SIZE=1): use shared memory + // - Virtual cluster sync with HARDWARE_CLUSTER: use distributed shared memory + // - Virtual cluster sync without HARDWARE_CLUSTER: use global memory, i.e., `red_buffer` + constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1; + + // Specialize for the case that each group is handled by only one block + // For common cases, blockSum produces partial sum and stores it to the red_buffer, and groupSum produces + // mean&var For the special case, blockSum produces mean&var directly + constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER = + VIRTUAL_CLUSTER_SIZE == 1 && + MAX_NUM_GROUPS_PER_BLOCK == 1; // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented + + [[maybe_unused]] __align__(16) + __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)]; + + // Block sum + if constexpr (SINGLE_GROUP_PER_BLOCK) { + // block reduce + if (threadIdx.x < 32) { + float2 sum_local_group = + threadIdx.x < BLOCK_DIM_X / 32 ? 
sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f}; + constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32); + for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) { + sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); + sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); + } + if (threadIdx.x == 0) { + if constexpr (USE_SHARED_RED_BUFFER) { + if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { + shared_red_buffer[0] = compute_mean_var(sum_local_group); + } else { + shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group; + } + } else { + *reinterpret_cast( + &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + + // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + + virtual_block_idx_y) * + 2]) = sum_local_group; } - - for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { - int64_t input_idx = n_loop * HW * C + - (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + - thread_channel_start; - U val; - if constexpr (LOAD_TWICE) { - val = *reinterpret_cast(&x[input_idx]); + } + } + } else { + // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) + constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(ROWS_PER_IO)), + round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); + static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); + float2 sum_local_group = {0.f, 0.f}; + if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { + int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; + // TODO: map threads to both the CPG loop and the ROWS loop + for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) { + int c = local_group_idx * CPG + local_c_loop; + if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) { + for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO; + src_thread_tile_y += THREADS_PER_GROUP) { + int channel_idx = (c - block_channel_start) / GCD_VEC_CPG; + channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) + + channel_idx / (VEC_ELEMS / GCD_VEC_CPG); + sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x; + sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y; + } + } + } + } + static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); + for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { + sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); + sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); + } + if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { + if constexpr (USE_SHARED_RED_BUFFER) { + static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, "no distributed shared memory"); + if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { + shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group); + } else { + shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group; + } + } else { + *reinterpret_cast( + &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * 
MAX_NUM_GROUPS_PER_BLOCK + + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + + (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) * + 2]) = sum_local_group; + } + } + } + + virtual_cluster_sync(barrier); + + // Group sum + __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK]; + if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { + // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) + constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(virtual_cluster_dim_y)), + round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); + static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); + float2 sum_global_group = {0.f, 0.f}; + if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { + if constexpr (C_PER_BLOCK % CPG == 0) { + // Special case: no cross-virtual_cluster_dim_x reduction + float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)]; + for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { + float2 val; + if constexpr (USE_SHARED_RED_BUFFER) { + if constexpr (VIRTUAL_CLUSTER_SIZE == 1) { + val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; } else { - val = frag[j]; - } - for (int k = 0; k < VEC_ELEMS; k++) { - float f = ((float)val.data[k] - frag_mean[k / GCD_VEC_CPG]) * rsqrtf(frag_var[k / GCD_VEC_CPG] + eps) * (float)uw.data[k] + (float)ub.data[k]; - if constexpr (SILU) f = f / (1.f + expf(-f)); - val.data[k] = f; + static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); + float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank( + shared_red_buffer, i * virtual_cluster_dim_x + virtual_block_idx_x); + val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; } - *reinterpret_cast(&out[input_idx]) = val; + } else { + val = *reinterpret_cast( + &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + + (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + i) * + 2]); + } + buffer[i / THREADS_PER_GROUP] = val; } - - if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { - if constexpr (PERSISTENT) { - if (nc_scheduler.at_end(n)) { - cg::this_cluster().barrier_wait(); - } + for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { + float2 val = buffer[i / THREADS_PER_GROUP]; + sum_global_group.x += val.x; + sum_global_group.y += val.y; + } + } else { + // Common case: cross-virtual_cluster_dim_x reduction + int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; + for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) { + int src_virtual_block_idx_x = i % virtual_cluster_dim_x; + int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; + int src_block_group_start = src_block_channel_start / CPG; + int relative_group_idx = local_group_idx - src_block_group_start; + if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) { + float2 val; + if constexpr (USE_SHARED_RED_BUFFER) { + static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); + static_assert(VIRTUAL_CLUSTER_SIZE != 1, + "layout error: should not add (step * 
MAX_NUM_GROUPS_PER_BLOCK)"); + float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i); + val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx]; } else { - cg::this_cluster().barrier_wait(); + val = *reinterpret_cast( + &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + + src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + + relative_group_idx * virtual_cluster_dim_y + i / virtual_cluster_dim_x) * + 2]); } + sum_global_group.x += val.x; + sum_global_group.y += val.y; + } } - - if constexpr (!PERSISTENT) { - break; + } + } + if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { + // Need cluster sync after distributed shared memory access, otherwise behavior is undefined + if constexpr (PERSISTENT) { + if (nc_scheduler.at_end(n)) { + cg::this_cluster().barrier_arrive(); } - step ^= 1; + } else { + cg::this_cluster().barrier_arrive(); + } } + static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); + for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { + sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32); + sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32); + } + if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { + mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group); + } + __syncthreads(); + } + + auto get_mean_var = [&](int relative_group_idx) { + return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? shared_red_buffer[relative_group_idx] + : mean_var[relative_group_idx]; + }; + + if (mean_var_out) { + static_assert(MAX_NUM_GROUPS_PER_BLOCK <= BLOCK_DIM_X, "need loop"); + if (virtual_block_idx_y == 0 && threadIdx.x < MAX_NUM_GROUPS_PER_BLOCK) { + int g = block_group_start + threadIdx.x; + if (C_PER_BLOCK % CPG == 0 || g < G) { + *reinterpret_cast(&mean_var_out[(n_loop * G + g) * 2]) = get_mean_var(threadIdx.x); + } + } + } + + float frag_mean[VEC_ELEMS / GCD_VEC_CPG]; + float frag_var[VEC_ELEMS / GCD_VEC_CPG]; + for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) { + frag_mean[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x; + frag_var[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y; + } + + for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { + int64_t input_idx = + n_loop * HW * C + + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + + thread_channel_start; + U val; + if constexpr (LOAD_TWICE) { + val = *reinterpret_cast(&x[input_idx]); + } else { + val = frag[j]; + } + for (int k = 0; k < VEC_ELEMS; k++) { + float f = ((float)val.data[k] - frag_mean[k / GCD_VEC_CPG]) * rsqrtf(frag_var[k / GCD_VEC_CPG] + eps) * + (float)uw.data[k] + + (float)ub.data[k]; + if constexpr (SILU) f = f / (1.f + expf(-f)); + val.data[k] = f; + } + *reinterpret_cast(&out[input_idx]) = val; + } + + if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { + if constexpr (PERSISTENT) { + if (nc_scheduler.at_end(n)) { + cg::this_cluster().barrier_wait(); + } + } else { + cg::this_cluster().barrier_wait(); + } + } + + if constexpr (!PERSISTENT) { + break; + } + step ^= 1; } + } } enum WgradSyncMethod { - WGRAD_ARRIVE_AND_WAIT_GRID = 0, // grid arrive after the last virtual cluster sync - WGRAD_ARRIVE_AND_WAIT_GROUP, // group arrive after the 
last virtual cluster sync (a group sync means synchronizing all clusters cooperating on the same groups) - WGRAD_REUSE_SUM_SYNC_GRID, // grid sync together with the last virtual cluster sync - WGRAD_REUSE_SUM_SYNC_GROUP, // group sync together with the last virtual cluster sync - WGRAD_SYNC_AT_LAST, // add a sync at the end of NC loops - WGRAD_SYNC_UNSPECIFIED, + WGRAD_ARRIVE_AND_WAIT_GRID = 0, // grid arrive after the last virtual cluster sync + WGRAD_ARRIVE_AND_WAIT_GROUP, // group arrive after the last virtual cluster sync (a group sync means synchronizing + // all clusters cooperating on the same groups) + WGRAD_REUSE_SUM_SYNC_GRID, // grid sync together with the last virtual cluster sync + WGRAD_REUSE_SUM_SYNC_GROUP, // group sync together with the last virtual cluster sync + WGRAD_SYNC_AT_LAST, // add a sync at the end of NC loops + WGRAD_SYNC_UNSPECIFIED, }; -template -__global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_bwd_cuda_kernel(T *__restrict__ grad_input, T *__restrict__ grad_weight, T *__restrict__ grad_bias, T const *__restrict__ grad_output, T const *__restrict__ x, T const *__restrict__ w, T const *__restrict__ b, float const *__restrict__ mean_var, float eps, int64_t n, float *__restrict__ red_buffer, unsigned *__restrict__ barrier) { - // Procedure Overview - // 1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE) - // 2. Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is used), - // write wgrad to gmem at the last loop (at each loop if not CONSTANT_C_LOOP) - // 3. Group sum: read from gmem, write mean&var to smem - // 4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem - // 5. Wgrad sum: read from gmem, write to gmem - - static_assert(BLOCK_DIM_X % 32 == 0, "warp shuffle error"); - - constexpr int C = G * CPG; - static_assert(C % C_PER_CLUSTER == 0, "cannot divide channels into clusters"); - static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, "cannot divide a cluster into blocks"); - static_assert(C_PER_CLUSTER % CPG == 0, "no reduce between clusters, would produce incorrect results"); - static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK), "inefficient configuration, please reduce C_PER_CLUSTER"); - - static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, "cannot divide tile into threads"); - struct alignas(VEC_ELEMS * sizeof(T)) U { - T data[VEC_ELEMS]; - }; - - // This function computes mean_dyw and mean_xdyw. - // The function name is not changed because it has the same logic as the forward pass. - auto compute_mean_var = [&](float2 sum) { - float mean_dyw = sum.x / (HW * CPG); - float mean_xdyw = sum.y / (HW * CPG); - return float2{mean_dyw, mean_xdyw}; - }; - - static_assert(HW % ROWS_PER_BLOCK == 0, "HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis"); - constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; - constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); - constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK; - constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK; - int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x; - int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x; - - if constexpr (CompileCondition::matches()) { - int step = 0; - constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0; - if constexpr (!CONSTANT_C_LOOP) { - static_assert(wgrad_sync_method != WGRAD_ARRIVE_AND_WAIT_GROUP && wgrad_sync_method != WGRAD_REUSE_SUM_SYNC_GROUP, - "grid sync is required when each block is responsible for multiple channel ranges"); +template +__global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_bwd_cuda_kernel( + T *__restrict__ grad_input, T *__restrict__ grad_weight, T *__restrict__ grad_bias, + T const *__restrict__ grad_output, T const *__restrict__ x, T const *__restrict__ w, T const *__restrict__ b, + float const *__restrict__ mean_var, float eps, int64_t n, float *__restrict__ red_buffer, + unsigned *__restrict__ barrier) { + // Procedure Overview + // 1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE) + // 2. Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is + // used), + // write wgrad to gmem at the last loop (at each loop if not CONSTANT_C_LOOP) + // 3. Group sum: read from gmem, write mean&var to smem + // 4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem + // 5. Wgrad sum: read from gmem, write to gmem + + static_assert(BLOCK_DIM_X % 32 == 0, "warp shuffle error"); + + constexpr int C = G * CPG; + static_assert(C % C_PER_CLUSTER == 0, "cannot divide channels into clusters"); + static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, "cannot divide a cluster into blocks"); + static_assert(C_PER_CLUSTER % CPG == 0, "no reduce between clusters, would produce incorrect results"); + static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK), + "inefficient configuration, please reduce C_PER_CLUSTER"); + + static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, "cannot divide tile into threads"); + struct alignas(VEC_ELEMS * sizeof(T)) U { + T data[VEC_ELEMS]; + }; + + // This function computes mean_dyw and mean_xdyw. + // The function name is not changed because it has the same logic as the forward pass. + auto compute_mean_var = [&](float2 sum) { + float mean_dyw = sum.x / (HW * CPG); + float mean_xdyw = sum.y / (HW * CPG); + return float2{mean_dyw, mean_xdyw}; + }; + + static_assert(HW % ROWS_PER_BLOCK == 0, + "HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis"); + constexpr int MAX_NUM_GROUPS_PER_BLOCK = + C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; + constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); + constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK; + constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK; + int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x; + int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x; + + if constexpr (CompileCondition::matches()) { + int step = 0; + constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0; + if constexpr (!CONSTANT_C_LOOP) { + static_assert(wgrad_sync_method != WGRAD_ARRIVE_AND_WAIT_GROUP && wgrad_sync_method != WGRAD_REUSE_SUM_SYNC_GROUP, + "grid sync is required when each block is responsible for multiple channel ranges"); + } + NCScheduler nc_scheduler( + n); // TODO: I don't know why the template specialization with CONSTANT_C_LOOP=true is slower. + + [[maybe_unused]] int virtual_cluster_idx_c = blockIdx.y % (C / C_PER_CLUSTER); + [[maybe_unused]] cg::grid_group::arrival_token wgrad_sync_token; + [[maybe_unused]] float dw_thread[VEC_ELEMS]; + [[maybe_unused]] float db_thread[VEC_ELEMS]; + [[maybe_unused]] __shared__ union { + float2 dwdb_block_buffer[BLOCK_DIM_X][VEC_ELEMS]; + struct { + float wgrad_buffer[BLOCK_DIM_X / 32][32]; + float bgrad_buffer[BLOCK_DIM_X / 32][32]; + } transpose_buffer; + } union_smem; + if constexpr (REQUIRES_WGRAD && CONSTANT_C_LOOP) { + for (int i = 0; i < VEC_ELEMS; i++) { + dw_thread[i] = 0.f; + db_thread[i] = 0.f; + } + } + float *red_buffer_wgrad = + &red_buffer[(2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK) * 2]; + unsigned *barrier_wgrad = barrier + NUM_VIRTUAL_CLUSTERS; + if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) { + if (nc_scheduler.at_end(n)) { + static_assert(PERSISTENT, "persistent is a must for reducing wgrad"); + if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { + wgrad_sync_token = group_barrier_arrive( + barrier_wgrad, blockIdx.x + blockIdx.y == 0); + } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { + wgrad_sync_token = + group_barrier_arrive( + barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); + } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) { + wgrad_sync_token = group_barrier_arrive( + barrier_wgrad, blockIdx.x + blockIdx.y == 0); + group_barrier_wait(barrier_wgrad, wgrad_sync_token); + } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) { + wgrad_sync_token = + group_barrier_arrive( + barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); + group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); + } + } + } + + while (true) { // TODO: unroll the loop + if constexpr (PERSISTENT) { + if (nc_scheduler.at_end(n)) { + break; + } + } + auto [n_loop, c_loop] = nc_scheduler.get_nc(); + if constexpr (PERSISTENT) { + nc_scheduler.next(n); + } + static_assert(C_PER_BLOCK % VEC_ELEMS == 0, "cannot vectorize"); + static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0, + "each block should load one or more C_PER_BLOCK at once"); + constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK; + static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, "cannot determine the IO times per batch"); + int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * 
C_PER_CLUSTER; + int block_group_start = block_channel_start / CPG; + int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS; + U frag_x[ROWS_PER_BLOCK / ROWS_PER_IO]; + U frag_dy[ROWS_PER_BLOCK / ROWS_PER_IO]; + + constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG); + + constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0; + [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime( + 128 / (int)sizeof(float2), ROWS_PER_IO)]; + [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32]; + + float frag_mean[VEC_ELEMS / GCD_VEC_CPG]; + float frag_var[VEC_ELEMS / GCD_VEC_CPG]; + for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) { + float2 value = + *reinterpret_cast(&mean_var[(n_loop * G + (thread_channel_start + k) / CPG) * 2]); + frag_mean[k / GCD_VEC_CPG] = value.x; + frag_var[k / GCD_VEC_CPG] = value.y; + } + + U uw = *reinterpret_cast(&w[thread_channel_start]); + U ub; + if constexpr (SILU) { + ub = *reinterpret_cast(&b[thread_channel_start]); + } + if constexpr (REQUIRES_WGRAD && !CONSTANT_C_LOOP) { + for (int i = 0; i < VEC_ELEMS; i++) { + dw_thread[i] = 0.f; + db_thread[i] = 0.f; } - NCScheduler nc_scheduler(n); // TODO: I don't know why the template specialization with CONSTANT_C_LOOP=true is slower. - - [[maybe_unused]] int virtual_cluster_idx_c = blockIdx.y % (C / C_PER_CLUSTER); - [[maybe_unused]] cg::grid_group::arrival_token wgrad_sync_token; - [[maybe_unused]] float dw_thread[VEC_ELEMS]; - [[maybe_unused]] float db_thread[VEC_ELEMS]; - [[maybe_unused]] __shared__ union { - float2 dwdb_block_buffer[BLOCK_DIM_X][VEC_ELEMS]; - struct { - float wgrad_buffer[BLOCK_DIM_X / 32][32]; - float bgrad_buffer[BLOCK_DIM_X / 32][32]; - } transpose_buffer; - } union_smem; - if constexpr (REQUIRES_WGRAD && CONSTANT_C_LOOP) { - for (int i = 0; i < VEC_ELEMS; i++) { - dw_thread[i] = 0.f; - db_thread[i] = 0.f; + } + + if constexpr (LOAD_TWICE) { + float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{}; + for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { + int64_t input_idx = + n_loop * HW * C + + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + + thread_channel_start; + U ux = *reinterpret_cast(&x[input_idx]); + U udy = *reinterpret_cast(&grad_output[input_idx]); + for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { + float2 sum = frag_sum_per_channel[i]; + for (int k = 0; k < GCD_VEC_CPG; k++) { + float rnorm = rsqrtf(frag_var[i] + eps); + float x_norm = + ((float)ux.data[i * GCD_VEC_CPG + k] - frag_mean[i]) * rnorm; // TODO: store rsqrtf in mean_var + float grad_gn = udy.data[i * GCD_VEC_CPG + k]; + if constexpr (SILU) { + float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k]; + float s = 1.f / (1.f + expf(-x_gn)); + grad_gn *= s * (1.f + x_gn * (1.f - s)); + } + sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]; + sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]); + if constexpr (REQUIRES_WGRAD) { + dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn; + db_thread[i * GCD_VEC_CPG + k] += grad_gn; + } } + frag_sum_per_channel[i] = sum; + } } - float *red_buffer_wgrad = &red_buffer[(2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK) * 2]; - unsigned *barrier_wgrad = barrier + NUM_VIRTUAL_CLUSTERS; - if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) { - if (nc_scheduler.at_end(n)) { - 
static_assert(PERSISTENT, "persistent is a must for reducing wgrad"); - if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { - wgrad_sync_token = group_barrier_arrive(barrier_wgrad, blockIdx.x + blockIdx.y == 0); - } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { - wgrad_sync_token = group_barrier_arrive(barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); - } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) { - wgrad_sync_token = group_barrier_arrive(barrier_wgrad, blockIdx.x + blockIdx.y == 0); - group_barrier_wait(barrier_wgrad, wgrad_sync_token); - } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) { - wgrad_sync_token = group_barrier_arrive(barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); - group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); - } + for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { + if constexpr (SINGLE_GROUP_PER_BLOCK) { + for (int mask = 16; mask > 0; mask >>= 1) { + frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32); + frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32); + } + static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); + if (threadIdx.x % 32 == 0) { + sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i]; } + } else { + sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)] + [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i]; + } + } + __syncthreads(); + } else { + for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { + int64_t input_idx = + n_loop * HW * C + + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + + thread_channel_start; + frag_x[j] = *reinterpret_cast(&x[input_idx]); + frag_dy[j] = *reinterpret_cast(&grad_output[input_idx]); } - while (true) { // TODO: unroll the loop - if constexpr (PERSISTENT) { - if (nc_scheduler.at_end(n)) { - break; - } + for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { + float2 sum = {0.f, 0.f}; + for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { + for (int k = 0; k < GCD_VEC_CPG; k++) { + float rnorm = rsqrtf(frag_var[i] + eps); + float x_norm = ((float)frag_x[j].data[i * GCD_VEC_CPG + k] - frag_mean[i]) * + rnorm; // TODO: store rsqrtf in mean_var + float grad_gn = frag_dy[j].data[i * GCD_VEC_CPG + k]; + if constexpr (SILU) { + float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k]; + float s = 1.f / (1.f + expf(-x_gn)); + grad_gn *= s * (1.f + x_gn * (1.f - s)); + } + sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]; + sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]); + if constexpr (REQUIRES_WGRAD) { + dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn; + db_thread[i * GCD_VEC_CPG + k] += grad_gn; + } } - auto [n_loop, c_loop] = nc_scheduler.get_nc(); - if constexpr (PERSISTENT) { - nc_scheduler.next(n); + } + if constexpr (SINGLE_GROUP_PER_BLOCK) { + for (int mask = 16; mask > 0; mask >>= 1) { + sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32); + sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32); } - static_assert(C_PER_BLOCK % VEC_ELEMS == 0, "cannot vectorize"); - static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0, "each block should load one or more C_PER_BLOCK at once"); - constexpr int 
ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK; - static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, "cannot determine the IO times per batch"); - int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; - int block_group_start = block_channel_start / CPG; - int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS; - U frag_x[ROWS_PER_BLOCK / ROWS_PER_IO]; - U frag_dy[ROWS_PER_BLOCK / ROWS_PER_IO]; - - constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG); - - constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0; - [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime(128 / (int)sizeof(float2), ROWS_PER_IO)]; - [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32]; - - float frag_mean[VEC_ELEMS / GCD_VEC_CPG]; - float frag_var[VEC_ELEMS / GCD_VEC_CPG]; - for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) { - float2 value = *reinterpret_cast(&mean_var[(n_loop * G + (thread_channel_start + k) / CPG) * 2]); - frag_mean[k / GCD_VEC_CPG] = value.x; - frag_var[k / GCD_VEC_CPG] = value.y; + static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); + if (threadIdx.x % 32 == 0) { + sum_per_channel_single_group[threadIdx.x / 32] = sum; } - - U uw = *reinterpret_cast(&w[thread_channel_start]); - U ub; - if constexpr (SILU) { - ub = *reinterpret_cast(&b[thread_channel_start]); + } else { + sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)] + [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum; + } + } + __syncthreads(); + } + + if ((CONSTANT_C_LOOP && nc_scheduler.at_end(n)) || !CONSTANT_C_LOOP) { + constexpr int NT_C = max_divisor(C_PER_BLOCK, BLOCK_DIM_X); // Number of threads on the C axis + constexpr int NT_R = + 1; // std::min(32, (int)round_down_pow2(BLOCK_DIM_X / NT_C)); // Number of threads on the ROWS axis + // TODO: swizzle for NT_R + for (int i = 0; i < VEC_ELEMS; i++) { + union_smem.dwdb_block_buffer[threadIdx.x][i ^ ((threadIdx.x / (16 / VEC_ELEMS)) & (VEC_ELEMS - 1))] = + float2{dw_thread[i], db_thread[i]}; + } + __syncthreads(); + static_assert(NT_C * NT_R <= BLOCK_DIM_X, "not enough threads"); + static_assert(C_PER_BLOCK % NT_C == 0, "need to loop once more and check c < C_PER_BLOCK"); + for (int i = 0; i < C_PER_BLOCK / NT_C; i++) { + int c = i * NT_C + threadIdx.x / NT_R; + float dw_block = 0.f; + float db_block = 0.f; + if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) { + for (int j = threadIdx.x % NT_R; j < ROWS_PER_IO; j += NT_R) { + int src_thread = j * (C_PER_BLOCK / VEC_ELEMS) + c / VEC_ELEMS; + float2 val = union_smem.dwdb_block_buffer[src_thread][(c % VEC_ELEMS) ^ ((src_thread / (16 / VEC_ELEMS)) & + (VEC_ELEMS - 1))]; + dw_block += val.x; + db_block += val.y; } - if constexpr (REQUIRES_WGRAD && !CONSTANT_C_LOOP) { - for (int i = 0; i < VEC_ELEMS; i++) { - dw_thread[i] = 0.f; - db_thread[i] = 0.f; - } + } + static_assert(32 % NT_R == 0, "cannot shuffle"); + for (int mask = NT_R / 2; mask > 0; mask >>= 1) { + dw_block += __shfl_xor_sync(FINAL_MASK, dw_block, mask, 32); + db_block += __shfl_xor_sync(FINAL_MASK, db_block, mask, 32); + } + if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) { + if (threadIdx.x % NT_R == 0) { + if constexpr (CONSTANT_C_LOOP) { + *reinterpret_cast( + &red_buffer_wgrad + [((blockIdx.y / (C / C_PER_CLUSTER) * virtual_cluster_dim_y + virtual_block_idx_y) * C + + c_loop * C_PER_CLUSTER + 
virtual_block_idx_x * C_PER_BLOCK + c) * + 2]) = float2{dw_block, db_block}; + } else { + *reinterpret_cast( + &red_buffer_wgrad[((n_loop * virtual_cluster_dim_y + virtual_block_idx_y) * C + + c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) * + 2]) = float2{dw_block, db_block}; + } } - - if constexpr (LOAD_TWICE) { - float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{}; - for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { - int64_t input_idx = n_loop * HW * C + - (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + - thread_channel_start; - U ux = *reinterpret_cast(&x[input_idx]); - U udy = *reinterpret_cast(&grad_output[input_idx]); - for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { - float2 sum = frag_sum_per_channel[i]; - for (int k = 0; k < GCD_VEC_CPG; k++) { - float rnorm = rsqrtf(frag_var[i] + eps); - float x_norm = ((float)ux.data[i * GCD_VEC_CPG + k] - frag_mean[i]) * rnorm; // TODO: store rsqrtf in mean_var - float grad_gn = udy.data[i * GCD_VEC_CPG + k]; - if constexpr (SILU) { - float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k]; - float s = 1.f / (1.f + expf(-x_gn)); - grad_gn *= s * (1.f + x_gn * (1.f - s)); - } - sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]; - sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]); - if constexpr (REQUIRES_WGRAD) { - dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn; - db_thread[i * GCD_VEC_CPG + k] += grad_gn; - } - } - frag_sum_per_channel[i] = sum; - } - } - for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { - if constexpr (SINGLE_GROUP_PER_BLOCK) { - for (int mask = 16; mask > 0; mask >>= 1) { - frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32); - frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32); - } - static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); - if (threadIdx.x % 32 == 0) { - sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i]; - } - } else { - sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)][threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i]; - } - } - __syncthreads(); + } + } + } + + constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1; + constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER = + VIRTUAL_CLUSTER_SIZE == 1 && + MAX_NUM_GROUPS_PER_BLOCK == 1; // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented + [[maybe_unused]] __align__(16) + __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)]; + + // Block sum + if constexpr (SINGLE_GROUP_PER_BLOCK) { + // block reduce + if (threadIdx.x < 32) { + float2 sum_local_group = + threadIdx.x < BLOCK_DIM_X / 32 ? 
sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f}; + constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32); + for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) { + sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); + sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); + } + if (threadIdx.x == 0) { + if constexpr (USE_SHARED_RED_BUFFER) { + if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { + shared_red_buffer[0] = compute_mean_var(sum_local_group); + } else { + shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group; + } } else { - for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { - int64_t input_idx = n_loop * HW * C + - (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + - thread_channel_start; - frag_x[j] = *reinterpret_cast(&x[input_idx]); - frag_dy[j] = *reinterpret_cast(&grad_output[input_idx]); - } - - for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { - float2 sum = {0.f, 0.f}; - for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { - for (int k = 0; k < GCD_VEC_CPG; k++) { - float rnorm = rsqrtf(frag_var[i] + eps); - float x_norm = ((float)frag_x[j].data[i * GCD_VEC_CPG + k] - frag_mean[i]) * rnorm; // TODO: store rsqrtf in mean_var - float grad_gn = frag_dy[j].data[i * GCD_VEC_CPG + k]; - if constexpr (SILU) { - float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k]; - float s = 1.f / (1.f + expf(-x_gn)); - grad_gn *= s * (1.f + x_gn * (1.f - s)); - } - sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]; - sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]); - if constexpr (REQUIRES_WGRAD) { - dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn; - db_thread[i * GCD_VEC_CPG + k] += grad_gn; - } - } - } - if constexpr (SINGLE_GROUP_PER_BLOCK) { - for (int mask = 16; mask > 0; mask >>= 1) { - sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32); - sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32); - } - static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); - if (threadIdx.x % 32 == 0) { - sum_per_channel_single_group[threadIdx.x / 32] = sum; - } - } else { - sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)][threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum; - } - } - __syncthreads(); + *reinterpret_cast( + &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + + // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + + virtual_block_idx_y) * + 2]) = sum_local_group; } - - if ((CONSTANT_C_LOOP && nc_scheduler.at_end(n)) || !CONSTANT_C_LOOP) { - constexpr int NT_C = max_divisor(C_PER_BLOCK, BLOCK_DIM_X); // Number of threads on the C axis - constexpr int NT_R = 1; // std::min(32, (int)round_down_pow2(BLOCK_DIM_X / NT_C)); // Number of threads on the ROWS axis - // TODO: swizzle for NT_R - for (int i = 0; i < VEC_ELEMS; i++) { - union_smem.dwdb_block_buffer[threadIdx.x][i ^ ((threadIdx.x / (16 / VEC_ELEMS)) & (VEC_ELEMS - 1))] = float2{dw_thread[i], db_thread[i]}; - } - __syncthreads(); - static_assert(NT_C * NT_R <= BLOCK_DIM_X, "not enough threads"); - static_assert(C_PER_BLOCK % NT_C == 0, "need to loop once more and check c < C_PER_BLOCK"); - for (int i = 0; i < C_PER_BLOCK / NT_C; i++) { - int c = i * NT_C + threadIdx.x / NT_R; - float dw_block = 0.f; - 
float db_block = 0.f; - if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) { - for (int j = threadIdx.x % NT_R; j < ROWS_PER_IO; j += NT_R) { - int src_thread = j * (C_PER_BLOCK / VEC_ELEMS) + c / VEC_ELEMS; - float2 val = union_smem.dwdb_block_buffer[src_thread][(c % VEC_ELEMS) ^ ((src_thread / (16 / VEC_ELEMS)) & (VEC_ELEMS - 1))]; - dw_block += val.x; - db_block += val.y; - } - } - static_assert(32 % NT_R == 0, "cannot shuffle"); - for (int mask = NT_R / 2; mask > 0; mask >>= 1) { - dw_block += __shfl_xor_sync(FINAL_MASK, dw_block, mask, 32); - db_block += __shfl_xor_sync(FINAL_MASK, db_block, mask, 32); - } - if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) { - if (threadIdx.x % NT_R == 0) { - if constexpr (CONSTANT_C_LOOP) { - *reinterpret_cast(&red_buffer_wgrad[((blockIdx.y / (C / C_PER_CLUSTER) * virtual_cluster_dim_y + virtual_block_idx_y) * C + c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) * 2]) = float2{dw_block, db_block}; - } else { - *reinterpret_cast(&red_buffer_wgrad[((n_loop * virtual_cluster_dim_y + virtual_block_idx_y) * C + c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) * 2]) = float2{dw_block, db_block}; - } - } - } - } + } + } + } else { + // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) + constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(ROWS_PER_IO)), + round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); + static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); + float2 sum_local_group = {0.f, 0.f}; + if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { + int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; + // TODO: map threads to both the CPG loop and the ROWS loop + for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) { + int c = local_group_idx * CPG + local_c_loop; + if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) { + for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO; + src_thread_tile_y += THREADS_PER_GROUP) { + int channel_idx = (c - block_channel_start) / GCD_VEC_CPG; + channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) + + channel_idx / (VEC_ELEMS / GCD_VEC_CPG); + sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x; + sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y; + } } - - constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1; - constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER = VIRTUAL_CLUSTER_SIZE == 1 && MAX_NUM_GROUPS_PER_BLOCK == 1; // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented - [[maybe_unused]] __align__(16) __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)]; - - // Block sum - if constexpr (SINGLE_GROUP_PER_BLOCK) { - // block reduce - if (threadIdx.x < 32) { - float2 sum_local_group = threadIdx.x < BLOCK_DIM_X / 32 ? 
sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f}; - constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32); - for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) { - sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); - sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); - } - if (threadIdx.x == 0) { - if constexpr (USE_SHARED_RED_BUFFER) { - if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { - shared_red_buffer[0] = compute_mean_var(sum_local_group); - } else { - shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group; - } - } else { - *reinterpret_cast(&red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + - virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + - // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + - virtual_block_idx_y) * 2]) = sum_local_group; - } - } - } + } + } + static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); + for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { + sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); + sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); + } + if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { + if constexpr (USE_SHARED_RED_BUFFER) { + static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, "no distributed shared memory"); + if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { + shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group); } else { - // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) - constexpr int THREADS_PER_GROUP = std::min(std::min(32U, - round_up_pow2(ROWS_PER_IO)), - round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); - static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); - float2 sum_local_group = {0.f, 0.f}; - if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { - int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; - // TODO: map threads to both the CPG loop and the ROWS loop - for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) { - int c = local_group_idx * CPG + local_c_loop; - if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) { - for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO; src_thread_tile_y += THREADS_PER_GROUP) { - int channel_idx = (c - block_channel_start) / GCD_VEC_CPG; - channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) + channel_idx / (VEC_ELEMS / GCD_VEC_CPG); - sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x; - sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y; - } - } - } - } - static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); - for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { - sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); - sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); - } - if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { - if constexpr (USE_SHARED_RED_BUFFER) { - static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, "no distributed shared memory"); - if 
constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { - shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group); - } else { - shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group; - } - } else { - *reinterpret_cast(&red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + - virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + - (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + - virtual_block_idx_y) * 2]) = sum_local_group; - } - } + shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group; } + } else { + *reinterpret_cast( + &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + + (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) * + 2]) = sum_local_group; + } + } + } - if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) { - if (nc_scheduler.at_end(n)) { - static_assert(PERSISTENT, "persistent is a must for reducing wgrad"); - if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { - virtual_cluster_sync(barrier); - wgrad_sync_token = group_barrier_arrive(barrier_wgrad, blockIdx.x + blockIdx.y == 0); - } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { - virtual_cluster_sync(barrier); - wgrad_sync_token = group_barrier_arrive(barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); - } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) { - static_assert(!HARDWARE_CLUSTER, "Distributed smem sync cannot reuse gmem sync. Use WGRAD_ARRIVE_AND_WAIT_GRID instead."); - wgrad_sync_token = group_barrier_arrive(barrier_wgrad, blockIdx.x + blockIdx.y == 0); - group_barrier_wait(barrier_wgrad, wgrad_sync_token); - } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) { - static_assert(!HARDWARE_CLUSTER, "Distributed smem sync cannot reuse gmem sync. Use WGRAD_ARRIVE_AND_WAIT_GROUP instead."); - wgrad_sync_token = group_barrier_arrive(barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); - group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); - } + if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) { + if (nc_scheduler.at_end(n)) { + static_assert(PERSISTENT, "persistent is a must for reducing wgrad"); + if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { + virtual_cluster_sync(barrier); + wgrad_sync_token = group_barrier_arrive( + barrier_wgrad, blockIdx.x + blockIdx.y == 0); + } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { + virtual_cluster_sync(barrier); + wgrad_sync_token = + group_barrier_arrive( + barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); + } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) { + static_assert(!HARDWARE_CLUSTER, + "Distributed smem sync cannot reuse gmem sync. Use WGRAD_ARRIVE_AND_WAIT_GRID instead."); + wgrad_sync_token = group_barrier_arrive( + barrier_wgrad, blockIdx.x + blockIdx.y == 0); + group_barrier_wait(barrier_wgrad, wgrad_sync_token); + } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) { + static_assert(!HARDWARE_CLUSTER, + "Distributed smem sync cannot reuse gmem sync. 
Use WGRAD_ARRIVE_AND_WAIT_GROUP instead."); + wgrad_sync_token = + group_barrier_arrive( + barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); + group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); + } + } else { + virtual_cluster_sync(barrier); + } + } else { + virtual_cluster_sync(barrier); + } + + // Group sum + __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK]; + if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { + // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) + constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(virtual_cluster_dim_y)), + round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); + static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); + float2 sum_global_group = {0.f, 0.f}; + if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { + if constexpr (C_PER_BLOCK % CPG == 0) { + // Special case: no cross-virtual_cluster_dim_x reduction + float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)]; + for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { + float2 val; + if constexpr (USE_SHARED_RED_BUFFER) { + if constexpr (VIRTUAL_CLUSTER_SIZE == 1) { + val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; } else { - virtual_cluster_sync(barrier); + static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); + float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank( + shared_red_buffer, i * virtual_cluster_dim_x + virtual_block_idx_x); + val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; } - } else { - virtual_cluster_sync(barrier); + } else { + val = *reinterpret_cast( + &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + + (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + i) * + 2]); + } + buffer[i / THREADS_PER_GROUP] = val; } - - // Group sum - __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK]; - if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { - // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) - constexpr int THREADS_PER_GROUP = std::min(std::min(32U, - round_up_pow2(virtual_cluster_dim_y)), - round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); - static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); - float2 sum_global_group = {0.f, 0.f}; - if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { - if constexpr (C_PER_BLOCK % CPG == 0) { - // Special case: no cross-virtual_cluster_dim_x reduction - float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)]; - for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { - float2 val; - if constexpr (USE_SHARED_RED_BUFFER) { - if constexpr (VIRTUAL_CLUSTER_SIZE == 1) { - val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; - } else { - static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); - float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i * virtual_cluster_dim_x + virtual_block_idx_x); - val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; - } 
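// --- Illustrative sketch (hypothetical helper, shown only to document the reduction pattern used in
// the block-sum and group-sum steps nearby; it is not a definition taken from this file) ---
// The recurring loop `for (mask = THREADS_PER_GROUP/2; mask > 0; mask >>= 1)` with __shfl_xor_sync is an
// XOR "butterfly": at each step every lane adds the partner whose lane index differs in that bit, so after
// log2(THREADS_PER_GROUP) steps every lane of the sub-group holds the full (sum.x, sum.y). A minimal
// standalone version, assuming THREADS_PER_GROUP divides 32 as the surrounding static_asserts require:
template <int THREADS_PER_GROUP>
__device__ __forceinline__ float2 butterfly_reduce_f2_sketch(float2 v, unsigned lane_mask = 0xffffffffu) {
  static_assert(32 % THREADS_PER_GROUP == 0, "sub-group must evenly divide a warp");
  for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {
    v.x += __shfl_xor_sync(lane_mask, v.x, mask, 32);
    v.y += __shfl_xor_sync(lane_mask, v.y, mask, 32);
  }
  return v;  // every lane of each aligned THREADS_PER_GROUP-wide sub-group now holds the same sums
}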
- } else { - val = *reinterpret_cast(&red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + - virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + - (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + - i) * 2]); - } - buffer[i / THREADS_PER_GROUP] = val; - } - for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { - float2 val = buffer[i / THREADS_PER_GROUP]; - sum_global_group.x += val.x; - sum_global_group.y += val.y; - } - } else { - // Common case: cross-virtual_cluster_dim_x reduction - int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; - for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) { - int src_virtual_block_idx_x = i % virtual_cluster_dim_x; - int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; - int src_block_group_start = src_block_channel_start / CPG; - int relative_group_idx = local_group_idx - src_block_group_start; - if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) { - float2 val; - if constexpr (USE_SHARED_RED_BUFFER) { - static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); - static_assert(VIRTUAL_CLUSTER_SIZE != 1, "layout error: should not add (step * MAX_NUM_GROUPS_PER_BLOCK)"); - float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i); - val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx]; - } else { - val = *reinterpret_cast(&red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + - src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + - relative_group_idx * virtual_cluster_dim_y + - i / virtual_cluster_dim_x) * 2]); - } - sum_global_group.x += val.x; - sum_global_group.y += val.y; - } - } - } - } - if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { - // Need cluster sync after distributed shared memory access, otherwise behavior is undefined - if constexpr (PERSISTENT) { - if (nc_scheduler.at_end(n)) { - cg::this_cluster().barrier_arrive(); - } - } else { - cg::this_cluster().barrier_arrive(); - } - } - static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); - for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { - sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32); - sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32); - } - if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { - mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group); - } - __syncthreads(); - } - - auto get_mean_var = [&](int relative_group_idx) { - return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 
shared_red_buffer[relative_group_idx] : mean_var[relative_group_idx]; - }; - - float frag_dyw[VEC_ELEMS / GCD_VEC_CPG]; - float frag_xdyw[VEC_ELEMS / GCD_VEC_CPG]; - for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) { - frag_dyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x; - frag_xdyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y; + for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { + float2 val = buffer[i / THREADS_PER_GROUP]; + sum_global_group.x += val.x; + sum_global_group.y += val.y; } - - for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { - int64_t input_idx = n_loop * HW * C + - (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + - thread_channel_start; - U ux; - U udy; - if constexpr (LOAD_TWICE) { - ux = *reinterpret_cast(&x[input_idx]); - udy = *reinterpret_cast(&grad_output[input_idx]); + } else { + // Common case: cross-virtual_cluster_dim_x reduction + int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; + for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) { + int src_virtual_block_idx_x = i % virtual_cluster_dim_x; + int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; + int src_block_group_start = src_block_channel_start / CPG; + int relative_group_idx = local_group_idx - src_block_group_start; + if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) { + float2 val; + if constexpr (USE_SHARED_RED_BUFFER) { + static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); + static_assert(VIRTUAL_CLUSTER_SIZE != 1, + "layout error: should not add (step * MAX_NUM_GROUPS_PER_BLOCK)"); + float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i); + val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx]; } else { - ux = frag_x[j]; - udy = frag_dy[j]; + val = *reinterpret_cast( + &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + + src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + + relative_group_idx * virtual_cluster_dim_y + i / virtual_cluster_dim_x) * + 2]); } - U val; - for (int k = 0; k < VEC_ELEMS; k++) { - float rnorm = rsqrtf(frag_var[k / GCD_VEC_CPG] + eps); - float x_norm = ((float)ux.data[k] - frag_mean[k / GCD_VEC_CPG]) * rnorm; // TODO: store rsqrtf in mean_var - float grad_gn = udy.data[k]; - if constexpr (SILU) { - float x_gn = x_norm * (float)uw.data[k] + (float)ub.data[k]; - float s = 1.f / (1.f + expf(-x_gn)); - grad_gn *= s * (1.f + x_gn * (1.f - s)); - } - val.data[k] = (grad_gn * (float)uw.data[k] - frag_dyw[k / GCD_VEC_CPG] - frag_xdyw[k / GCD_VEC_CPG] * x_norm) * rnorm; - } - *reinterpret_cast(&grad_input[input_idx]) = val; + sum_global_group.x += val.x; + sum_global_group.y += val.y; + } } - - if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { - if constexpr (PERSISTENT) { - if (nc_scheduler.at_end(n)) { - cg::this_cluster().barrier_wait(); - } - } else { - cg::this_cluster().barrier_wait(); - } + } + } + if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { + // Need cluster sync after distributed shared memory access, otherwise behavior is undefined + if constexpr (PERSISTENT) { + if (nc_scheduler.at_end(n)) { + 
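// --- Illustrative note (hypothetical scalar reference helpers, assuming the same definitions as the
// kernel above; they are sketches, not functions defined elsewhere in this file) ---
// SiLU backward used in the scale step: with s = sigmoid(x_gn),
//   d/dx_gn [x_gn * s] = s + x_gn * s * (1 - s) = s * (1 + x_gn * (1 - s)),
// which is exactly the factor applied to grad_gn when SILU is enabled.
// GroupNorm backward used for grad_input: with g = grad_gn * w and the per-group means
// mean(g), mean(x_norm * g) produced by compute_mean_var (sums divided by HW * CPG),
//   d_x = (g - mean(g) - x_norm * mean(x_norm * g)) * rsqrt(var + eps).
__device__ __forceinline__ float silu_bwd_factor_sketch(float x_gn) {
  float s = 1.f / (1.f + expf(-x_gn));
  return s * (1.f + x_gn * (1.f - s));
}
__device__ __forceinline__ float gn_dx_sketch(float g, float x_norm, float mean_g, float mean_xg, float var,
                                              float eps) {
  return (g - mean_g - x_norm * mean_xg) * rsqrtf(var + eps);
}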
cg::this_cluster().barrier_arrive(); } + } else { + cg::this_cluster().barrier_arrive(); + } + } + static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); + for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { + sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32); + sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32); + } + if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { + mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group); + } + __syncthreads(); + } + + auto get_mean_var = [&](int relative_group_idx) { + return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? shared_red_buffer[relative_group_idx] + : mean_var[relative_group_idx]; + }; + + float frag_dyw[VEC_ELEMS / GCD_VEC_CPG]; + float frag_xdyw[VEC_ELEMS / GCD_VEC_CPG]; + for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) { + frag_dyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x; + frag_xdyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y; + } + + for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { + int64_t input_idx = + n_loop * HW * C + + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + + thread_channel_start; + U ux; + U udy; + if constexpr (LOAD_TWICE) { + ux = *reinterpret_cast(&x[input_idx]); + udy = *reinterpret_cast(&grad_output[input_idx]); + } else { + ux = frag_x[j]; + udy = frag_dy[j]; + } + U val; + for (int k = 0; k < VEC_ELEMS; k++) { + float rnorm = rsqrtf(frag_var[k / GCD_VEC_CPG] + eps); + float x_norm = ((float)ux.data[k] - frag_mean[k / GCD_VEC_CPG]) * rnorm; // TODO: store rsqrtf in mean_var + float grad_gn = udy.data[k]; + if constexpr (SILU) { + float x_gn = x_norm * (float)uw.data[k] + (float)ub.data[k]; + float s = 1.f / (1.f + expf(-x_gn)); + grad_gn *= s * (1.f + x_gn * (1.f - s)); + } + val.data[k] = + (grad_gn * (float)uw.data[k] - frag_dyw[k / GCD_VEC_CPG] - frag_xdyw[k / GCD_VEC_CPG] * x_norm) * rnorm; + } + *reinterpret_cast(&grad_input[input_idx]) = val; + } - if constexpr (!PERSISTENT) { - break; - } - step ^= 1; + if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { + if constexpr (PERSISTENT) { + if (nc_scheduler.at_end(n)) { + cg::this_cluster().barrier_wait(); + } + } else { + cg::this_cluster().barrier_wait(); } + } - // Wgrad sum - if constexpr (REQUIRES_WGRAD) { - static_assert(PERSISTENT, "cannot reduce wgrad"); - static_assert(C % 32 == 0, "cannot reduce wgrad"); - if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { - group_barrier_wait(barrier_wgrad, wgrad_sync_token); - } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { - group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); - } else if constexpr (wgrad_sync_method == WGRAD_SYNC_AT_LAST) { - cg::this_grid().sync(); - } + if constexpr (!PERSISTENT) { + break; + } + step ^= 1; + } - // If group sync, map blocks that are responsible for the same range of channels to these channels (named "split channels"); - // otherwise, map all blocks to all channels. - constexpr bool split_channels = wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP; - - for (int c = split_channels ? 
- virtual_cluster_idx_c * C_PER_CLUSTER + 32 * (blockIdx.y / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE + blockIdx.x) : - 32 * (blockIdx.y * VIRTUAL_CLUSTER_SIZE + blockIdx.x); - split_channels ? - c < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER : - c < C; - c += split_channels ? - 32 * (NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE) : - 32 * (NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE)) { - int64_t rows = (CONSTANT_C_LOOP ? std::min(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) : n) * virtual_cluster_dim_y; - float sum_wgrad = 0.f; - float sum_bgrad = 0.f; - if ((split_channels && (C_PER_CLUSTER % 32 == 0 || c + threadIdx.x % 32 < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) || - (!split_channels && (C % 32 == 0 || c + threadIdx.x % 32 < C))) { - for (int64_t i = threadIdx.x / 32; i < rows; i += BLOCK_DIM_X / 32) { - float2 val = *reinterpret_cast(&red_buffer_wgrad[(i * C + c + threadIdx.x % 32) * 2]); - sum_wgrad += val.x; - sum_bgrad += val.y; - } - } - constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32); - union_smem.transpose_buffer.wgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] = sum_wgrad; - union_smem.transpose_buffer.bgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] = sum_bgrad; - __syncthreads(); - for (int i = threadIdx.x / warp_num_pow2; - i < 32 && ((split_channels && (C_PER_CLUSTER % 32 == 0 || c + i < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) || - (!split_channels && (C % 32 == 0 || c + i < C))); - i += BLOCK_DIM_X / warp_num_pow2) { - int j = threadIdx.x % warp_num_pow2; - float sum_wgrad = j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.wgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f; - float sum_bgrad = j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.bgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f; - for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) { - sum_wgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_wgrad, mask, warp_num_pow2); - sum_bgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_bgrad, mask, warp_num_pow2); - } - if (j == 0) { - grad_weight[c + i] = sum_wgrad; - grad_bias[c + i] = sum_bgrad; - } - } - __syncthreads(); - } + // Wgrad sum + if constexpr (REQUIRES_WGRAD) { + static_assert(PERSISTENT, "cannot reduce wgrad"); + static_assert(C % 32 == 0, "cannot reduce wgrad"); + if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { + group_barrier_wait(barrier_wgrad, wgrad_sync_token); + } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { + group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); + } else if constexpr (wgrad_sync_method == WGRAD_SYNC_AT_LAST) { + cg::this_grid().sync(); + } + + // If group sync, map blocks that are responsible for the same range of channels to these channels (named "split + // channels"); otherwise, map all blocks to all channels. + constexpr bool split_channels = + wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP; + + for (int c = split_channels ? virtual_cluster_idx_c * C_PER_CLUSTER + + 32 * (blockIdx.y / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE + blockIdx.x) + : 32 * (blockIdx.y * VIRTUAL_CLUSTER_SIZE + blockIdx.x); + split_channels ? c < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER : c < C; + c += split_channels ? 
32 * (NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE) + : 32 * (NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE)) { + int64_t rows = (CONSTANT_C_LOOP ? std::min(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) : n) * + virtual_cluster_dim_y; + float sum_wgrad = 0.f; + float sum_bgrad = 0.f; + if ((split_channels && + (C_PER_CLUSTER % 32 == 0 || c + threadIdx.x % 32 < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) || + (!split_channels && (C % 32 == 0 || c + threadIdx.x % 32 < C))) { + for (int64_t i = threadIdx.x / 32; i < rows; i += BLOCK_DIM_X / 32) { + float2 val = *reinterpret_cast(&red_buffer_wgrad[(i * C + c + threadIdx.x % 32) * 2]); + sum_wgrad += val.x; + sum_bgrad += val.y; + } } + constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32); + union_smem.transpose_buffer + .wgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] = + sum_wgrad; + union_smem.transpose_buffer + .bgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] = + sum_bgrad; + __syncthreads(); + for (int i = threadIdx.x / warp_num_pow2; + i < 32 && + ((split_channels && (C_PER_CLUSTER % 32 == 0 || c + i < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) || + (!split_channels && (C % 32 == 0 || c + i < C))); + i += BLOCK_DIM_X / warp_num_pow2) { + int j = threadIdx.x % warp_num_pow2; + float sum_wgrad = + j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.wgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f; + float sum_bgrad = + j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.bgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f; + for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) { + sum_wgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_wgrad, mask, warp_num_pow2); + sum_bgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_bgrad, mask, warp_num_pow2); + } + if (j == 0) { + grad_weight[c + i] = sum_wgrad; + grad_bias[c + i] = sum_bgrad; + } + } + __syncthreads(); + } } + } } } // namespace group_norm_v2 diff --git a/apex/contrib/csrc/group_norm_v2/gn_dispatch_hw_c.hpp b/apex/contrib/csrc/group_norm_v2/gn_dispatch_hw_c.hpp index 20f9f327a..a2b96908b 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_dispatch_hw_c.hpp +++ b/apex/contrib/csrc/group_norm_v2/gn_dispatch_hw_c.hpp @@ -1,19 +1,62 @@ #pragma once -#define DISPATCH_HW_C(hw, c, HW, C, ...) 
[&] { \ - if (hw == 64 && c == 1280) { constexpr int HW = 64, C = 1280; return __VA_ARGS__(); } \ - if (hw == 64 && c == 2560) { constexpr int HW = 64, C = 2560; return __VA_ARGS__(); } \ - if (hw == 256 && c == 640) { constexpr int HW = 256, C = 640; return __VA_ARGS__(); } \ - if (hw == 256 && c == 1280) { constexpr int HW = 256, C = 1280; return __VA_ARGS__(); } \ - if (hw == 256 && c == 1920) { constexpr int HW = 256, C = 1920; return __VA_ARGS__(); } \ - if (hw == 256 && c == 2560) { constexpr int HW = 256, C = 2560; return __VA_ARGS__(); } \ - if (hw == 1024 && c == 320) { constexpr int HW = 1024, C = 320; return __VA_ARGS__(); } \ - if (hw == 1024 && c == 640) { constexpr int HW = 1024, C = 640; return __VA_ARGS__(); } \ - if (hw == 1024 && c == 960) { constexpr int HW = 1024, C = 960; return __VA_ARGS__(); } \ - if (hw == 1024 && c == 1280) { constexpr int HW = 1024, C = 1280; return __VA_ARGS__(); } \ - if (hw == 1024 && c == 1920) { constexpr int HW = 1024, C = 1920; return __VA_ARGS__(); } \ - if (hw == 4096 && c == 320) { constexpr int HW = 4096, C = 320; return __VA_ARGS__(); } \ - if (hw == 4096 && c == 640) { constexpr int HW = 4096, C = 640; return __VA_ARGS__(); } \ - if (hw == 4096 && c == 960) { constexpr int HW = 4096, C = 960; return __VA_ARGS__(); } \ +#define DISPATCH_HW_C(hw, c, HW, C, ...) \ + [&] { \ + if (hw == 64 && c == 1280) { \ + constexpr int HW = 64, C = 1280; \ + return __VA_ARGS__(); \ + } \ + if (hw == 64 && c == 2560) { \ + constexpr int HW = 64, C = 2560; \ + return __VA_ARGS__(); \ + } \ + if (hw == 256 && c == 640) { \ + constexpr int HW = 256, C = 640; \ + return __VA_ARGS__(); \ + } \ + if (hw == 256 && c == 1280) { \ + constexpr int HW = 256, C = 1280; \ + return __VA_ARGS__(); \ + } \ + if (hw == 256 && c == 1920) { \ + constexpr int HW = 256, C = 1920; \ + return __VA_ARGS__(); \ + } \ + if (hw == 256 && c == 2560) { \ + constexpr int HW = 256, C = 2560; \ + return __VA_ARGS__(); \ + } \ + if (hw == 1024 && c == 320) { \ + constexpr int HW = 1024, C = 320; \ + return __VA_ARGS__(); \ + } \ + if (hw == 1024 && c == 640) { \ + constexpr int HW = 1024, C = 640; \ + return __VA_ARGS__(); \ + } \ + if (hw == 1024 && c == 960) { \ + constexpr int HW = 1024, C = 960; \ + return __VA_ARGS__(); \ + } \ + if (hw == 1024 && c == 1280) { \ + constexpr int HW = 1024, C = 1280; \ + return __VA_ARGS__(); \ + } \ + if (hw == 1024 && c == 1920) { \ + constexpr int HW = 1024, C = 1920; \ + return __VA_ARGS__(); \ + } \ + if (hw == 4096 && c == 320) { \ + constexpr int HW = 4096, C = 320; \ + return __VA_ARGS__(); \ + } \ + if (hw == 4096 && c == 640) { \ + constexpr int HW = 4096, C = 640; \ + return __VA_ARGS__(); \ + } \ + if (hw == 4096 && c == 960) { \ + constexpr int HW = 4096, C = 960; \ + return __VA_ARGS__(); \ + } \ throw std::invalid_argument("DISPATCH_HW_C " + std::to_string(hw) + " " + std::to_string(c)); \ - }() + }() diff --git a/apex/contrib/csrc/group_norm_v2/gn_utils.cpp b/apex/contrib/csrc/group_norm_v2/gn_utils.cpp index 019bf29a7..22550fb0f 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_utils.cpp +++ b/apex/contrib/csrc/group_norm_v2/gn_utils.cpp @@ -3,21 +3,20 @@ #include #include - namespace group_norm_v2 { cudaDeviceProp const &get_device_prop(int device_id) { - static std::vector device_props; - static std::once_flag flag; - std::call_once(flag, [&] { - int count; - CUDA_CHECK(cudaGetDeviceCount(&count)); - device_props.resize(count); - for (int i = 0; i < count; i++) { - CUDA_CHECK(cudaGetDeviceProperties(&device_props[i], i)); - } - }); 
- return device_props.at(device_id); + static std::vector device_props; + static std::once_flag flag; + std::call_once(flag, [&] { + int count; + CUDA_CHECK(cudaGetDeviceCount(&count)); + device_props.resize(count); + for (int i = 0; i < count; i++) { + CUDA_CHECK(cudaGetDeviceProperties(&device_props[i], i)); + } + }); + return device_props.at(device_id); } } // namespace group_norm_v2 diff --git a/apex/contrib/csrc/group_norm_v2/gn_utils.hpp b/apex/contrib/csrc/group_norm_v2/gn_utils.hpp index fb81c608f..33e83b5af 100644 --- a/apex/contrib/csrc/group_norm_v2/gn_utils.hpp +++ b/apex/contrib/csrc/group_norm_v2/gn_utils.hpp @@ -1,34 +1,41 @@ #pragma once +#include + #include #include #include -#include - #include "gn.hpp" - // Definition of CUDA_CHECK macro -#define CUDA_CHECK(call) \ -do { \ - cudaError_t err_ = call; \ - if (err_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", \ - __FILE__, __LINE__, err_, cudaGetErrorString(err_), #call); \ - exit(EXIT_FAILURE); \ - } \ -} while (0) - - -#define GN_CUDA_HOST_PARAMS(T) T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only - -#define GN_BWD_CUDA_HOST_PARAMS(T) T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only - -#define GN_CUDA_HOST_ARGS out, x, w, b, eps, silu, n, hw, num_groups, channels_per_group, mean_var_out, red_buffer, barrier, sm_margin, stream, device_id, meta_ptr, meta_only - -#define GN_BWD_CUDA_HOST_ARGS grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, silu, n, hw, num_groups, channels_per_group, red_buffer, barrier, sm_margin, stream, device_id, meta_ptr, meta_only - +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err_ = call; \ + if (err_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", __FILE__, __LINE__, err_, cudaGetErrorString(err_), \ + #call); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define GN_CUDA_HOST_PARAMS(T) \ + T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, \ + float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, \ + Meta *meta_ptr, bool meta_only + +#define GN_BWD_CUDA_HOST_PARAMS(T) \ + T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, float eps, \ + bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, \ + int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only + +#define GN_CUDA_HOST_ARGS \ + out, x, w, b, eps, silu, n, hw, num_groups, channels_per_group, mean_var_out, red_buffer, barrier, sm_margin, \ + stream, device_id, meta_ptr, meta_only + +#define GN_BWD_CUDA_HOST_ARGS \ + grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, silu, n, hw, num_groups, \ + channels_per_group, red_buffer, barrier, sm_margin, stream, device_id, meta_ptr, meta_only namespace group_norm_v2 { @@ -36,12 +43,12 @@ cudaDeviceProp const &get_device_prop(int device_id); #ifdef __CUDA_ARCH__ -template +template __host__ 
__device__ inline int print_rank_0(char const *fmt, Ts &&...args) { - if (threadIdx.x + threadIdx.y + threadIdx.z == 0 && blockIdx.x + blockIdx.y + blockIdx.z == 0) { - return printf(fmt, std::forward(args)...); - } - return 0; + if (threadIdx.x + threadIdx.y + threadIdx.z == 0 && blockIdx.x + blockIdx.y + blockIdx.z == 0) { + return printf(fmt, std::forward(args)...); + } + return 0; } #endif diff --git a/apex/contrib/csrc/groupbn/batch_norm.cu b/apex/contrib/csrc/groupbn/batch_norm.cu index d6744a8e6..8ff278d43 100644 --- a/apex/contrib/csrc/groupbn/batch_norm.cu +++ b/apex/contrib/csrc/groupbn/batch_norm.cu @@ -1,26 +1,21 @@ #include #include #include +#include #include "batch_norm.h" -#include +#define cudaCheckErrors(msg) \ + do { \ + cudaError_t __err = cudaGetLastError(); \ + if (__err != cudaSuccess) { \ + fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \ + fprintf(stderr, "*** FAILED - ABORTING\n"); \ + exit(1); \ + } \ + } while (0) -#define cudaCheckErrors(msg) \ - do { \ - cudaError_t __err = cudaGetLastError(); \ - if (__err != cudaSuccess) { \ - fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \ - msg, cudaGetErrorString(__err), \ - __FILE__, __LINE__); \ - fprintf(stderr, "*** FAILED - ABORTING\n"); \ - exit(1); \ - } \ - } while (0) - -static size_t round_up_to_multiple(size_t x, int multiple) { - return ((x + multiple - 1) / multiple) * multiple; -} +static size_t round_up_to_multiple(size_t x, int multiple) { return ((x + multiple - 1) / multiple) * multiple; } struct Workspace { Workspace(size_t size) : size(size), data(NULL) { @@ -39,28 +34,13 @@ struct Workspace { }; // Return {y} -at::Tensor nhwc_bn_fwd_train( - const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - const bool fuse_relu, - void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop) { - +at::Tensor nhwc_bn_fwd_train(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + const at::Tensor& ret_cta, const float momentum, const float epsilon, const bool fuse_relu, + void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, + const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, + const bool coop) { const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); @@ -74,7 +54,7 @@ at::Tensor nhwc_bn_fwd_train( at::Tensor y = at::empty({N, H, W, C}, x.options()); // Create wrapper - NhwcBatchNorm *bn = new NhwcBatchNorm(); + NhwcBatchNorm* bn = new NhwcBatchNorm(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); @@ -82,10 +62,7 @@ at::Tensor nhwc_bn_fwd_train( bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.data_ptr(), - nullptr, - y.data_ptr(), - nullptr); + bn->setInputOutputPointers(x.data_ptr(), nullptr, y.data_ptr(), nullptr); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {nullptr, 
nullptr}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); @@ -108,18 +85,18 @@ at::Tensor nhwc_bn_fwd_train( // Allocate the workspace Workspace ws(total_workspace_bytes); - std::vector workspace; + std::vector workspace; workspace.push_back(minibatch_mean.data_ptr()); workspace.push_back(minibatch_inv_var.data_ptr()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; void* retired_ctas = ret_cta.data_ptr(); - assert(ret_cta.size(0)>=retired_cta_bytes); + assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 3; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-3]; + void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 3]; workspace.push_back(ptr); } @@ -131,18 +108,10 @@ at::Tensor nhwc_bn_fwd_train( return y; } -at::Tensor nhwc_bn_fwd_eval( - const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& ret_cta, - const int bn_group, - const float momentum, - const float epsilon, - const bool fuse_relu) { - +at::Tensor nhwc_bn_fwd_eval(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, + const at::Tensor& ret_cta, const int bn_group, const float momentum, const float epsilon, + const bool fuse_relu) { const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); @@ -152,7 +121,7 @@ at::Tensor nhwc_bn_fwd_eval( at::Tensor y = at::empty({N, H, W, C}, x.options()); // Create wrapper - NhwcBatchNorm *bn = new NhwcBatchNorm(); + NhwcBatchNorm* bn = new NhwcBatchNorm(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); @@ -160,10 +129,7 @@ at::Tensor nhwc_bn_fwd_eval( bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.data_ptr(), - nullptr, - y.data_ptr(), - nullptr); + bn->setInputOutputPointers(x.data_ptr(), nullptr, y.data_ptr(), nullptr); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {nullptr, nullptr}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); @@ -186,18 +152,18 @@ at::Tensor nhwc_bn_fwd_eval( // Allocate the workspace Workspace ws(total_workspace_bytes); - std::vector workspace; + std::vector workspace; workspace.push_back(nullptr); workspace.push_back(nullptr); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; void* retired_ctas = ret_cta.data_ptr(); - assert(ret_cta.size(0)>=retired_cta_bytes); + assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 3; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-3]; + void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 3]; workspace.push_back(ptr); } @@ -207,31 +173,16 @@ at::Tensor nhwc_bn_fwd_eval( bn->fwdInference(stream, fuse_relu); return y; - } -std::vector nhwc_bn_bwd( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& ret_cta, - const float 
momentum, - const float epsilon, - const bool fuse_relu, - void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop) { +std::vector nhwc_bn_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, + const float momentum, const float epsilon, const bool fuse_relu, void* my_data, + void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, + const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, + const bool coop) { // shape const int N = x.size(0); const int H = x.size(1); @@ -251,7 +202,7 @@ std::vector nhwc_bn_bwd( bias_grad = at::empty_like(bias); // Create wrapper - NhwcBatchNorm *bn = new NhwcBatchNorm(); + NhwcBatchNorm* bn = new NhwcBatchNorm(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); @@ -259,12 +210,10 @@ std::vector nhwc_bn_bwd( bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.data_ptr(), - x_grad.data_ptr(), - nullptr, - dy.data_ptr()); + bn->setInputOutputPointers(x.data_ptr(), x_grad.data_ptr(), nullptr, dy.data_ptr()); - bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {scale_grad.data_ptr(), bias_grad.data_ptr()}); + bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, + {scale_grad.data_ptr(), bias_grad.data_ptr()}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); // deal with workspace(s) @@ -285,42 +234,41 @@ std::vector nhwc_bn_bwd( // Allocate the workspace Workspace ws(total_workspace_bytes); - std::vector workspace; + std::vector workspace; workspace.push_back(minibatch_mean.data_ptr()); workspace.push_back(minibatch_inv_var.data_ptr()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; void* retired_ctas = ret_cta.data_ptr(); - assert(ret_cta.size(0)>=retired_cta_bytes); + assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 3; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-3]; + void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 3]; workspace.push_back(ptr); } bn->setWorkspacePointers(workspace, workspace_bytes); - bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); + bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, + coop); return std::vector{x_grad, scale_grad, bias_grad}; } int nhwc_bn_fwd_occupancy() { - int device_id=-1; - cudaGetDevice(&device_id); + int device_id = -1; + cudaGetDevice(&device_id); - //max occupancy supported by the code is 2 - return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2); + // max occupancy supported by the code is 2 + return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2); } int nhwc_bn_bwd_occupancy() { - int device_id=-1; - cudaGetDevice(&device_id); - - //max occupancy supported by the code is 2 - return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2); -} - + int device_id = -1; + cudaGetDevice(&device_id); + 
// max occupancy supported by the code is 2 + return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2); +} diff --git a/apex/contrib/csrc/groupbn/batch_norm.h b/apex/contrib/csrc/groupbn/batch_norm.h index bb79d6758..e753b4b42 100644 --- a/apex/contrib/csrc/groupbn/batch_norm.h +++ b/apex/contrib/csrc/groupbn/batch_norm.h @@ -22,20 +22,19 @@ * \file nhwc_batch_norm.h * \brief CUDA NHWC Batch Normalization code * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer -*/ + */ #ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ #define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ #include #include -#include -#include #include +#include +#include -#include "nhwc_batch_norm_kernel.h" #include "cuda_utils.h" - +#include "nhwc_batch_norm_kernel.h" #define VERBOSE_DEFAULT false @@ -57,15 +56,16 @@ class NhwcBatchNorm { exit(-1); } - void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); - void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); + void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, + const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); + void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, + const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); void fwdInference(cudaStream_t stream, bool use_relu); - dim3 calc_fwd_grid(int *loop, const int grid_dim_x); - dim3 calc_bwd_grid(int *loop, const int grid_dim_x); + dim3 calc_fwd_grid(int* loop, const int grid_dim_x); + dim3 calc_bwd_grid(int* loop, const int grid_dim_x); - void setInputDescriptor(const cudnnTensorFormat_t format, - const cudnnDataType_t data_type, - int n, int c, int h, int w, int bn_group) { + void setInputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, int w, + int bn_group) { m_ = n * h * w; int m_bn_adjusted = m_ * bn_group; c_ = c; @@ -78,40 +78,36 @@ class NhwcBatchNorm { setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w); } - void setOutputDescriptor(const cudnnTensorFormat_t format, - const cudnnDataType_t data_type, - int n, int c, int h, int w) { + void setOutputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, + int w) { setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w); } const std::vector numWorkspaceBytes() const; - void setWorkspacePointers( - const std::vector& workspace, - const std::vector& num_workspace_bytes); + void setWorkspacePointers(const std::vector& workspace, const std::vector& num_workspace_bytes); - void setInputOutputPointers(void* X, void* dX, void* Y, void *dY) { + void setInputOutputPointers(void* X, void* dX, void* Y, void* dY) { X_ = X; - dX_ = dX; - Y_ = Y; - dY_ = dY; + dX_ = dX; + Y_ = Y; + dY_ = dY; } // Sets the pointers for the scale and weight (in that order) data and derivative buffers. 
- void setWeightPointers(const std::vector& weight_pointers, - const std::vector& deriv_pointers) { + void setWeightPointers(const std::vector& weight_pointers, const std::vector& deriv_pointers) { assert(weight_pointers.size() == 2); - assert(deriv_pointers.size() == 2); - scale_ = static_cast(weight_pointers[0]); - bias_ = static_cast(weight_pointers[1]); + assert(deriv_pointers.size() == 2); + scale_ = static_cast(weight_pointers[0]); + bias_ = static_cast(weight_pointers[1]); dscale_ = static_cast(deriv_pointers[0]); - dbias_ = static_cast(deriv_pointers[1]); + dbias_ = static_cast(deriv_pointers[1]); } // Sets the pointers for the population mean and variance buffers, in that order. void setParameterPointers(const std::vector& param_pointers) { assert(param_pointers.size() == 2); - population_mean_ = static_cast(param_pointers[0]); + population_mean_ = static_cast(param_pointers[0]); population_variance_ = static_cast(param_pointers[1]); } @@ -120,8 +116,7 @@ class NhwcBatchNorm { eps_ = eps; } - void processCudnnStatus(const cudnnStatus_t& status, - const std::string& string = std::string(), + void processCudnnStatus(const cudnnStatus_t& status, const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { if (status != CUDNN_STATUS_SUCCESS) LOG(FATAL) << string << " " << cudnnGetErrorString(status); @@ -129,8 +124,7 @@ class NhwcBatchNorm { LOG(INFO) << string << " " << cudnnGetErrorString(status); } - void checkCudaStatus(const std::string& string = std::string(), - bool verbose = VERBOSE_DEFAULT) { + void checkCudaStatus(const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { cudaError_t status = cudaGetLastError(); if (status != cudaSuccess) LOG(FATAL) << string << " " << cudaGetErrorString(status); @@ -141,34 +135,33 @@ class NhwcBatchNorm { size_t size_retired_ctas(int grid_y) const { // Note that the value of max_grid_y to handle known GPUs is about 160. const int max_grid_y = 1024; - if (grid_y > max_grid_y) - LOG(INFO) << "GPU capabilities exceeds assumptions."; + if (grid_y > max_grid_y) LOG(INFO) << "GPU capabilities exceeds assumptions."; const int retired_cta_bytes = max_grid_y * 2 * sizeof(int); // Since the region will be initialized once and used for many kernels, // the idea is to return an ample size that will cover all uses. return retired_cta_bytes; } - cudnnTensorDescriptor_t X_tensor_desc_ = nullptr; - cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr; + cudnnTensorDescriptor_t X_tensor_desc_ = nullptr; + cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr; - void* X_ = nullptr; + void* X_ = nullptr; void* dX_ = nullptr; - void* Y_ = nullptr; + void* Y_ = nullptr; void* dY_ = nullptr; // Learned scale and bias weights. - float* scale_ = nullptr; + float* scale_ = nullptr; float* dscale_ = nullptr; - float* bias_ = nullptr; - float* dbias_ = nullptr; + float* bias_ = nullptr; + float* dbias_ = nullptr; // Computed population mean and variance parameters. - float* population_mean_ = nullptr; + float* population_mean_ = nullptr; float* population_variance_ = nullptr; // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd). - float* minibatch_mean_ = nullptr; + float* minibatch_mean_ = nullptr; float* minibatch_variance_ = nullptr; int m_ = 0; // Number of values per channel that BN is normalizing. 
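// The status helpers reformatted above (processCudnnStatus, checkCudaStatus) gate every
// cuDNN call and kernel launch in this wrapper. A minimal standalone sketch of the same
// checking pattern, with fprintf/exit standing in for the LOG(FATAL)/LOG(INFO) macros
// (those come from a logging header that is not part of this diff):
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <cudnn.h>

static void check_cudnn(cudnnStatus_t status, const char* what) {
  // Mirrors processCudnnStatus: fail loudly on any non-success cuDNN status.
  if (status != CUDNN_STATUS_SUCCESS) {
    std::fprintf(stderr, "%s: %s\n", what, cudnnGetErrorString(status));
    std::exit(EXIT_FAILURE);
  }
}

static void check_cuda_last_error(const char* what) {
  // Mirrors checkCudaStatus: inspect the sticky error left by the most recent launch.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "%s: %s\n", what, cudaGetErrorString(err));
    std::exit(EXIT_FAILURE);
  }
}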
@@ -178,20 +171,18 @@ class NhwcBatchNorm { float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance double exp_avg_factor_ = 0.; - double eps_ = 0.; + double eps_ = 0.; std::string name_; private: - void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, - cudnnTensorFormat_t format, - cudnnDataType_t data_type, + void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, cudnnTensorFormat_t format, cudnnDataType_t data_type, int n, int c, int h, int w) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w); processCudnnStatus(status, "set tensor descriptor"); } - void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) { + void createTensorDescriptor(cudnnTensorDescriptor_t* descriptor) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnCreateTensorDescriptor(descriptor); processCudnnStatus(status, "create tensor_descriptor"); @@ -204,13 +195,13 @@ class NhwcBatchNorm { } protected: - float *partial_sums_ = nullptr; - int *partial_counts_ = nullptr; - int *retired_ctas_ = nullptr; + float* partial_sums_ = nullptr; + int* partial_counts_ = nullptr; + int* retired_ctas_ = nullptr; - void _setFwdParams(NhwcBatchNormFwdParams *params) const; - void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const; - void _setBwdParams(NhwcBatchNormBwdParams *params) const; + void _setFwdParams(NhwcBatchNormFwdParams* params) const; + void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const; + void _setBwdParams(NhwcBatchNormBwdParams* params) const; // @todo: ability to configure these? // Kernel params @@ -222,31 +213,26 @@ class NhwcBatchNorm { static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024; typedef uint16_t StorageType; - //typedef float StorageType; - // increasing this to 6 causes spills in fwd kernel! + // typedef float StorageType; + // increasing this to 6 causes spills in fwd kernel! 
static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5; static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3; static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10; static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5; - static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \ - PIXELS_PER_THREAD_IN_SMEM_FWD; - static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \ - PIXELS_PER_THREAD_IN_SMEM_BWD; + static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + PIXELS_PER_THREAD_IN_SMEM_FWD; + static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + PIXELS_PER_THREAD_IN_SMEM_BWD; static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4; // Derived params - static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\ - ELEMENTS_PER_LDG*sizeof(StorageType); - static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\ - ELEMENTS_PER_LDG*2*sizeof(StorageType); + static const size_t SMEM_SIZE_FWD = + PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * sizeof(StorageType); + static const size_t SMEM_SIZE_BWD = + PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * 2 * sizeof(StorageType); static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_FWD; - static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_BWD; - static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_FWD_INFERENCE; + static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD; + static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_BWD; + static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD_INFERENCE; // max grid.y in case of group bn is limited by exchange buffer size static const int MAX_GBN_BLOCK_Y = 256; @@ -256,58 +242,31 @@ class NhwcBatchNorm { // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel // version that was compiled with that occupancy in its launch bounds. This way, we avoid // needless register spills. 
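// The comment above describes the launch policy: estimate how many CTAs fit per SM from
// their shared-memory footprint, then run the kernel variant compiled with that occupancy
// in its launch bounds so registers are not squeezed into spills. A simplified sketch of
// that estimate (the same arithmetic as smem_driven_fwd_occupancy / smem_driven_bwd_occupancy
// later in this header), querying the CUDA runtime directly instead of at::cuda::utils;
// smem_bytes_per_cta stands for SMEM_SIZE_FWD/_BWD plus the per-CTA reduction buffer:
#include <algorithm>
#include <cuda_runtime.h>

static int smem_driven_occupancy(int device_id, int max_cta_per_sm, size_t smem_bytes_per_cta) {
  int max_smem_per_sm = 0;
  cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id);
  // How many CTAs of this footprint fit in one SM's shared memory?
  int occupancy = static_cast<int>(max_smem_per_sm / smem_bytes_per_cta);
  // Never exceed the occupancy the kernel was compiled for (at most 2 in this file).
  return std::min(max_cta_per_sm, occupancy);
}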
- void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, - dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) { - -#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto fwd_func = nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } else { \ - cudaLaunchKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " fwd ser coop kernel"); \ - } while (0) + void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, dim3 grid_dim, int outer_loops, + bool use_relu, const int occupancy, const bool coop) { +#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ + auto fwd_func = \ + nhwc_batch_norm_fwd; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void* params_ptr = static_cast(¶ms); \ + using FWD_FUNC = decltype(nhwc_batch_norm_fwd); \ + if (COOP) { \ + cudaLaunchCooperativeKernel(fwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_FWD, stream); \ + } else { \ + cudaLaunchKernel(fwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_FWD, stream); \ + } \ + checkCudaStatus(name_ + " fwd ser coop kernel"); \ + } while (0) // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1 && use_relu) { @@ -336,99 +295,56 @@ class NhwcBatchNorm { // Helper function to launch the backward kernel. 
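// The LAUNCH_FWD_KERNEL macro reformatted above (and the LAUNCH_BWD_* macros that follow)
// launch through cudaLaunchKernel / cudaLaunchCooperativeKernel rather than <<<>>>: the
// whole parameter struct is the kernel's single argument, so the argument array is one
// void* slot, and the cooperative variant is used when the kernel needs grid-wide sync.
// A self-contained sketch of that launch shape; toy_kernel / ToyParams are illustrative
// stand-ins, not names from this codebase, and a real cooperative launch additionally
// requires cudaDevAttrCooperativeLaunch support and a grid that is fully co-resident:
#include <cuda_runtime.h>

struct ToyParams {  // stand-in for NhwcBatchNormFwdParams / NhwcBatchNormBwdParams
  float* out;
  int n;
};

__global__ void toy_kernel(ToyParams params) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < params.n) params.out[i] = 1.0f;
}

static cudaError_t launch_toy(ToyParams params, dim3 grid_dim, int threads_per_cta,
                              size_t smem_bytes, cudaStream_t stream, int occupancy, bool coop) {
  auto func = toy_kernel;  // function pointer, as the macros capture fwd_func / bwd_func
  if (occupancy > 1) {
    // Same carveout bump as the macros: prefer shared memory over L1 when targeting >1 CTA/SM.
    cudaFuncSetAttribute(func, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
  }
  void* params_ptr = static_cast<void*>(&params);  // one struct argument -> one slot
  if (coop) {
    return cudaLaunchCooperativeKernel(func, grid_dim, dim3(threads_per_cta), &params_ptr, smem_bytes, stream);
  }
  return cudaLaunchKernel(func, grid_dim, dim3(threads_per_cta), &params_ptr, smem_bytes, stream);
}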
- void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, - dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) { -#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto bwd_func = nhwc_batch_norm_bwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(bwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - cudaLaunchKernel(bwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd coop serial kernel"); \ - } while (0) - -#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(bwd_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - cudaLaunchKernel(bwd_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \ - } while (0) + void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, dim3 grid_dim, int outer_loops, + bool use_relu, const int occupancy, const bool coop) { +#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ + auto bwd_func = nhwc_batch_norm_bwd; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void* params_ptr = static_cast(¶ms); \ + using BWD_FUNC = \ + 
decltype(nhwc_batch_norm_bwd); \ + if (COOP) { \ + cudaLaunchCooperativeKernel(bwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, stream); \ + } else { \ + cudaLaunchKernel(bwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, stream); \ + } \ + checkCudaStatus(name_ + " bwd coop serial kernel"); \ + } while (0) + +#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ + auto bwd_relu_func = \ + nhwc_batch_norm_bwd_relu; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void* params_ptr = static_cast(¶ms); \ + using BWD_RELU_FUNC = \ + decltype(nhwc_batch_norm_bwd_relu); \ + if (COOP) { \ + cudaLaunchCooperativeKernel(bwd_relu_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, \ + stream); \ + } else { \ + cudaLaunchKernel(bwd_relu_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, stream); \ + } \ + checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \ + } while (0) // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1 && use_relu) { @@ -456,11 +372,10 @@ class NhwcBatchNorm { } public: - // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage. static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; - int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float); + int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float); int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); @@ -469,7 +384,7 @@ class NhwcBatchNorm { // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage. 
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; - int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float); + int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float); int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); @@ -485,109 +400,99 @@ const std::vector NhwcBatchNorm::numWorkspaceBytes() const { int grid_x = max(grid_x_fwd, grid_x_bwd); int grid_y = div_up(c_, C_ELEMENTS_PER_CTA); - const size_t num_mean_bytes = c_ * sizeof(float); + const size_t num_mean_bytes = c_ * sizeof(float); const size_t num_variance_bytes = num_mean_bytes; - const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\ - ELEMENTS_PER_LDG*2*sizeof(float); - const size_t size_counts = grid_y*grid_x*sizeof(int); + const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2 * sizeof(float); + const size_t size_counts = grid_y * grid_x * sizeof(int); - return {num_mean_bytes, num_variance_bytes, - size_retired_ctas(grid_y), size_sums, size_counts}; + return {num_mean_bytes, num_variance_bytes, size_retired_ctas(grid_y), size_sums, size_counts}; } -void NhwcBatchNorm::setWorkspacePointers( - const std::vector& workspace, - const std::vector& num_workspace_bytes) { +void NhwcBatchNorm::setWorkspacePointers(const std::vector& workspace, + const std::vector& num_workspace_bytes) { assert(workspace.size() == 5); assert(num_workspace_bytes.size() == 5); - minibatch_mean_ = static_cast(workspace[0]); + minibatch_mean_ = static_cast(workspace[0]); minibatch_variance_ = static_cast(workspace[1]); - retired_ctas_ = static_cast(workspace[2]); - partial_sums_ = static_cast(workspace[3]); - partial_counts_ = static_cast(workspace[4]); + retired_ctas_ = static_cast(workspace[2]); + partial_sums_ = static_cast(workspace[3]); + partial_counts_ = static_cast(workspace[4]); } -void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dst = static_cast(Y_); - params->gmem_src1 = nullptr; - params->gmem_bias = bias_; - params->gmem_scale = scale_; +void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams* params) const { + params->gmem_src = static_cast(X_); + params->gmem_dst = static_cast(Y_); + params->gmem_src1 = nullptr; + params->gmem_bias = bias_; + params->gmem_scale = scale_; params->gmem_running_mean = population_mean_; - params->gmem_running_var = population_variance_; - params->gmem_saved_mean = minibatch_mean_; - params->gmem_saved_var = minibatch_variance_; + params->gmem_running_var = population_variance_; + params->gmem_saved_mean = minibatch_mean_; + params->gmem_saved_var = minibatch_variance_; params->gmem_relu_bitmask = nullptr; - params->nhw = m_; - params->c = c_; - params->svar_inv_count = svar_inv_count_; - params->rvar_inv_count = rvar_inv_count_; - params->gmem_sums = partial_sums_; - params->gmem_counts = partial_counts_; + params->nhw = m_; + params->c = c_; + params->svar_inv_count = svar_inv_count_; + params->rvar_inv_count = rvar_inv_count_; + params->gmem_sums = partial_sums_; + params->gmem_counts = partial_counts_; params->gmem_retired_ctas = retired_ctas_; - params->var_eps = eps_; - params->outer_loops = 0; - params->exp_avg_factor = static_cast(exp_avg_factor_); - params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); + params->var_eps = 
eps_; + params->outer_loops = 0; + params->exp_avg_factor = static_cast(exp_avg_factor_); + params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); } -void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams - *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dst = static_cast(Y_); - params->gmem_src1 = nullptr; - params->gmem_bias = bias_; +void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const { + params->gmem_src = static_cast(X_); + params->gmem_dst = static_cast(Y_); + params->gmem_src1 = nullptr; + params->gmem_bias = bias_; params->gmem_scale = scale_; - params->gmem_mean = population_mean_; - params->gmem_var = population_variance_; - params->nhw = m_; - params->c = c_; - params->var_eps = eps_; + params->gmem_mean = population_mean_; + params->gmem_var = population_variance_; + params->nhw = m_; + params->c = c_; + params->var_eps = eps_; } -void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dy = static_cast(dY_); - params->gmem_dst = static_cast(dX_); - params->gmem_dst1 = nullptr; +void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams* params) const { + params->gmem_src = static_cast(X_); + params->gmem_dy = static_cast(dY_); + params->gmem_dst = static_cast(dX_); + params->gmem_dst1 = nullptr; params->gmem_relu_bitmask = nullptr; - params->gmem_dscale = dscale_; - params->gmem_dbias = dbias_; - params->gmem_scale = scale_; - params->gmem_bias = bias_; - params->gmem_saved_mean = minibatch_mean_; - params->gmem_saved_var = minibatch_variance_; - params->nhw = m_; - params->c = c_; - params->svar_inv_count = svar_inv_count_; - params->gmem_sums = partial_sums_; + params->gmem_dscale = dscale_; + params->gmem_dbias = dbias_; + params->gmem_scale = scale_; + params->gmem_bias = bias_; + params->gmem_saved_mean = minibatch_mean_; + params->gmem_saved_var = minibatch_variance_; + params->nhw = m_; + params->c = c_; + params->svar_inv_count = svar_inv_count_; + params->gmem_sums = partial_sums_; params->gmem_retired_ctas = retired_ctas_; - params->outer_loops = 0; - params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); + params->outer_loops = 0; + params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); } void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) { - bool ptrs_are_set = - X_tensor_desc_ != nullptr - && Y_tensor_desc_ != nullptr - && scale_ != nullptr - && bias_ != nullptr - // && minibatch_mean_ != nullptr - // && minibatch_variance_ != nullptr - && population_mean_ != nullptr - && population_variance_ != nullptr - && X_ != nullptr - // && dX_ != nullptr - && Y_ != nullptr - // && dY_ != nullptr - // && dscale_ != nullptr - // && dbias_ != nullptr - && partial_sums_ != nullptr - && partial_counts_ != nullptr; - - if (!ptrs_are_set) - die(); + bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && + bias_ != nullptr + // && minibatch_mean_ != nullptr + // && minibatch_variance_ != nullptr + && population_mean_ != nullptr && population_variance_ != nullptr && + X_ != nullptr + // && dX_ != nullptr + && Y_ != nullptr + // && dY_ != nullptr + // && dscale_ != nullptr + // && dbias_ != nullptr + && partial_sums_ != nullptr && partial_counts_ != nullptr; + + if (!ptrs_are_set) die(); dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE); @@ -598,19 +503,17 @@ void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) { _setFwdInferenceParams(¶ms); if (use_relu) { - 
nhwc_batch_norm_fwd_inference - - <<>>(params); + nhwc_batch_norm_fwd_inference + <<>>(params); checkCudaStatus(name_ + " fwd_inference-relu kernel"); } else { - nhwc_batch_norm_fwd_inference - - <<>>(params); + nhwc_batch_norm_fwd_inference + <<>>(params); checkCudaStatus(name_ + " fwd_inference kernel"); } } -dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) { +dim3 NhwcBatchNorm::calc_fwd_grid(int* loop, const int grid_dim_x) { dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD); int c_blks = div_up(c_, C_ELEMENTS_PER_CTA); @@ -619,21 +522,21 @@ dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) { *loop = 1; if (max_grid_x / grid_dim.x > 1) { grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); - assert(grid_dim.y 1) { grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); - assert(grid_dim.y> 1); + params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1); dim3 grid_dim = calc_fwd_grid(¶ms.outer_loops, grid_dim_x); _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop); } -void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, - const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { - bool ptrs_are_set = - X_tensor_desc_ != nullptr - && Y_tensor_desc_ != nullptr - && scale_ != nullptr - && (bias_ != nullptr || !use_relu) - && minibatch_mean_ != nullptr - && minibatch_variance_ != nullptr - // && population_mean_ != nullptr - // && population_variance_ != nullptr - && X_ != nullptr - && dX_ != nullptr - // && Y_ != nullptr - && dY_ != nullptr - && dscale_ != nullptr - && dbias_ != nullptr; - - if (!ptrs_are_set) - die(); +void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, + void* pair_data3, const int bn_group, const int magic, const int occupancy, + const int grid_dim_x, const bool coop) { + bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && + (bias_ != nullptr || !use_relu) && minibatch_mean_ != nullptr && + minibatch_variance_ != nullptr + // && population_mean_ != nullptr + // && population_variance_ != nullptr + && X_ != nullptr && + dX_ != nullptr + // && Y_ != nullptr + && dY_ != nullptr && dscale_ != nullptr && dbias_ != nullptr; + + if (!ptrs_are_set) die(); // reset of retired_cta_count no longer needed @@ -725,7 +614,7 @@ void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, voi params.pair_datas[1] = pair_data2; params.pair_datas[2] = pair_data3; params.magic = magic; - params.sync_iters = (bn_group==8)?3:(bn_group >> 1); + params.sync_iters = (bn_group == 8) ? 
3 : (bn_group >> 1); params.wgrad_coeff = 1.0 / bn_group; dim3 grid_dim = calc_bwd_grid(¶ms.outer_loops, grid_dim_x); diff --git a/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu b/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu index 713ca4dbf..340c149b0 100644 --- a/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu +++ b/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu @@ -1,27 +1,22 @@ #include #include #include +#include #include "batch_norm_add_relu.h" -#include +// FIXME move the common stuff to common h file +#define cudaCheckErrors(msg) \ + do { \ + cudaError_t __err = cudaGetLastError(); \ + if (__err != cudaSuccess) { \ + fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \ + fprintf(stderr, "*** FAILED - ABORTING\n"); \ + exit(1); \ + } \ + } while (0) -//FIXME move the common stuff to common h file -#define cudaCheckErrors(msg) \ - do { \ - cudaError_t __err = cudaGetLastError(); \ - if (__err != cudaSuccess) { \ - fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \ - msg, cudaGetErrorString(__err), \ - __FILE__, __LINE__); \ - fprintf(stderr, "*** FAILED - ABORTING\n"); \ - exit(1); \ - } \ - } while (0) - -static size_t round_up_to_multiple(size_t x, int multiple) { - return ((x + multiple - 1) / multiple) * multiple; -} +static size_t round_up_to_multiple(size_t x, int multiple) { return ((x + multiple - 1) / multiple) * multiple; } struct Workspace { Workspace(size_t size) : size(size), data(NULL) { @@ -40,29 +35,14 @@ struct Workspace { }; // Return {y} -at::Tensor nhwc_bn_addrelu_fwd_train( - const at::Tensor& x, - const at::Tensor& z, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& bitmask, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop) { - +at::Tensor nhwc_bn_addrelu_fwd_train(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, + const at::Tensor& ret_cta, const float momentum, const float epsilon, + void* my_data, void* pair_data, void* pair_data2, void* pair_data3, + const int bn_group, const at::Tensor& magic_tensor, const int occupancy, + const int grid_dim_x, const bool coop) { const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); @@ -76,7 +56,7 @@ at::Tensor nhwc_bn_addrelu_fwd_train( at::Tensor y = at::empty({N, H, W, C}, x.options()); // Create wrapper - NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); + NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); @@ -84,11 +64,7 @@ at::Tensor nhwc_bn_addrelu_fwd_train( bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.data_ptr(), - nullptr, - y.data_ptr(), - nullptr, - z.data_ptr(), + bn->setInputOutputPointers(x.data_ptr(), nullptr, y.data_ptr(), nullptr, z.data_ptr(), nullptr); bn->setWeightPointers({scale.data_ptr(), 
bias.data_ptr()}, {nullptr, nullptr}); @@ -112,7 +88,7 @@ at::Tensor nhwc_bn_addrelu_fwd_train( // Allocate the workspace Workspace ws(total_workspace_bytes); - std::vector workspace; + std::vector workspace; workspace.push_back(minibatch_mean.data_ptr()); workspace.push_back(minibatch_inv_var.data_ptr()); workspace.push_back(bitmask.data_ptr()); @@ -120,12 +96,12 @@ at::Tensor nhwc_bn_addrelu_fwd_train( auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; void* retired_ctas = ret_cta.data_ptr(); - assert(ret_cta.size(0)>=retired_cta_bytes); + assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 4; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-4]; + void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 4]; workspace.push_back(ptr); } @@ -137,18 +113,10 @@ at::Tensor nhwc_bn_addrelu_fwd_train( return y; } -at::Tensor nhwc_bn_addrelu_fwd_eval( - const at::Tensor& x, - const at::Tensor& z, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& ret_cta, - const int bn_group, - const float momentum, - const float epsilon) { - +at::Tensor nhwc_bn_addrelu_fwd_eval(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& ret_cta, const int bn_group, + const float momentum, const float epsilon) { const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); @@ -158,7 +126,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( at::Tensor y = at::empty({N, H, W, C}, x.options()); // Create wrapper - NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); + NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); @@ -166,11 +134,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.data_ptr(), - nullptr, - y.data_ptr(), - nullptr, - z.data_ptr(), + bn->setInputOutputPointers(x.data_ptr(), nullptr, y.data_ptr(), nullptr, z.data_ptr(), nullptr); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {nullptr, nullptr}); @@ -194,7 +158,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( // Allocate the workspace Workspace ws(total_workspace_bytes); - std::vector workspace; + std::vector workspace; workspace.push_back(nullptr); workspace.push_back(nullptr); workspace.push_back(nullptr); @@ -202,11 +166,11 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; void* retired_ctas = ret_cta.data_ptr(); - assert(ret_cta.size(0)>=retired_cta_bytes); + assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 4; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-4]; + void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 4]; workspace.push_back(ptr); } @@ -216,31 +180,16 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( bn->fwdInference(stream); return y; - } -std::vector nhwc_bn_addrelu_bwd( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& 
running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& bitmask, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop) { +std::vector nhwc_bn_addrelu_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, + const at::Tensor& ret_cta, const float momentum, const float epsilon, + void* my_data, void* pair_data, void* pair_data2, void* pair_data3, + const int bn_group, const at::Tensor& magic_tensor, const int occupancy, + const int grid_dim_x, const bool coop) { // shape const int N = x.size(0); const int H = x.size(1); @@ -261,7 +210,7 @@ std::vector nhwc_bn_addrelu_bwd( bias_grad = at::empty_like(bias); // Create wrapper - NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); + NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); @@ -269,14 +218,11 @@ std::vector nhwc_bn_addrelu_bwd( bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.data_ptr(), - x_grad.data_ptr(), - nullptr, - dy.data_ptr(), - nullptr, - z_grad.data_ptr()); - - bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {scale_grad.data_ptr(), bias_grad.data_ptr()}); + bn->setInputOutputPointers(x.data_ptr(), x_grad.data_ptr(), nullptr, dy.data_ptr(), + nullptr, z_grad.data_ptr()); + + bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, + {scale_grad.data_ptr(), bias_grad.data_ptr()}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); // deal with workspace(s) @@ -297,7 +243,7 @@ std::vector nhwc_bn_addrelu_bwd( // Allocate the workspace Workspace ws(total_workspace_bytes); - std::vector workspace; + std::vector workspace; workspace.push_back(minibatch_mean.data_ptr()); workspace.push_back(minibatch_inv_var.data_ptr()); workspace.push_back(bitmask.data_ptr()); @@ -305,11 +251,11 @@ std::vector nhwc_bn_addrelu_bwd( auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; void* retired_ctas = ret_cta.data_ptr(); - assert(ret_cta.size(0)>=retired_cta_bytes); + assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 4; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-4]; + void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 4]; workspace.push_back(ptr); } @@ -321,18 +267,17 @@ std::vector nhwc_bn_addrelu_bwd( } int nhwc_bn_addrelu_fwd_occupancy() { - int device_id=-1; - cudaGetDevice(&device_id); - - //max occupancy supported by the code is 2 - return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2); + int device_id = -1; + cudaGetDevice(&device_id); + + // max occupancy supported by the code is 2 + return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2); } int nhwc_bn_addrelu_bwd_occupancy() { - int device_id=-1; - cudaGetDevice(&device_id); + int device_id = -1; + 
cudaGetDevice(&device_id); - //max occupancy supported by the code is 2 - return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2); + // max occupancy supported by the code is 2 + return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2); } - diff --git a/apex/contrib/csrc/groupbn/batch_norm_add_relu.h b/apex/contrib/csrc/groupbn/batch_norm_add_relu.h index 3dfe7b269..788d62ef9 100644 --- a/apex/contrib/csrc/groupbn/batch_norm_add_relu.h +++ b/apex/contrib/csrc/groupbn/batch_norm_add_relu.h @@ -22,20 +22,19 @@ * \file nhwc_batch_norm_add_relu.h * \brief CUDA NHWC Batch Normalization code with fused addition * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer -*/ + */ #ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ #define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ #include #include -#include -#include #include +#include +#include -#include "nhwc_batch_norm_kernel.h" #include "cuda_utils.h" - +#include "nhwc_batch_norm_kernel.h" #define VERBOSE_DEFAULT false @@ -57,15 +56,16 @@ class NhwcBatchNormAddRelu { exit(-1); } - void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); - void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); + void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, + const int magic, const int occupancy, const int grid_dim_x, const bool coop); + void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, + const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); void fwdInference(cudaStream_t stream); - dim3 calc_fwd_grid(int *loop, const int grid_dim_x); - dim3 calc_bwd_grid(int *loop, const int grid_dim_x); + dim3 calc_fwd_grid(int* loop, const int grid_dim_x); + dim3 calc_bwd_grid(int* loop, const int grid_dim_x); - void setInputDescriptor(const cudnnTensorFormat_t format, - const cudnnDataType_t data_type, - int n, int c, int h, int w, int bn_group) { + void setInputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, int w, + int bn_group) { m_ = n * h * w; int m_bn_adjusted = m_ * bn_group; c_ = c; @@ -78,42 +78,38 @@ class NhwcBatchNormAddRelu { setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w); } - void setOutputDescriptor(const cudnnTensorFormat_t format, - const cudnnDataType_t data_type, - int n, int c, int h, int w) { + void setOutputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, + int w) { setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w); } const std::vector numWorkspaceBytes() const; - void setWorkspacePointers( - const std::vector& workspace, - const std::vector& num_workspace_bytes); + void setWorkspacePointers(const std::vector& workspace, const std::vector& num_workspace_bytes); - void setInputOutputPointers(void* X, void* dX, void* Y, void *dY, void* addend, void* dAddend) { + void setInputOutputPointers(void* X, void* dX, void* Y, void* dY, void* addend, void* dAddend) { X_ = X; - dX_ = dX; - Y_ = Y; - dY_ = dY; - addend_ = addend; - dAddend_ = dAddend; + dX_ = dX; + Y_ = Y; + dY_ = dY; + addend_ = addend; + dAddend_ = 
dAddend; } // Sets the pointers for the scale and weight (in that order) data and derivative buffers. - void setWeightPointers(const std::vector& weight_pointers, - const std::vector& deriv_pointers) { + void setWeightPointers(const std::vector& weight_pointers, const std::vector& deriv_pointers) { assert(weight_pointers.size() == 2); - assert(deriv_pointers.size() == 2); - scale_ = static_cast(weight_pointers[0]); - bias_ = static_cast(weight_pointers[1]); + assert(deriv_pointers.size() == 2); + scale_ = static_cast(weight_pointers[0]); + bias_ = static_cast(weight_pointers[1]); dscale_ = static_cast(deriv_pointers[0]); - dbias_ = static_cast(deriv_pointers[1]); + dbias_ = static_cast(deriv_pointers[1]); } // Sets the pointers for the population mean and variance buffers, in that order. void setParameterPointers(const std::vector& param_pointers) { assert(param_pointers.size() == 2); - population_mean_ = static_cast(param_pointers[0]); + population_mean_ = static_cast(param_pointers[0]); population_variance_ = static_cast(param_pointers[1]); } @@ -122,8 +118,7 @@ class NhwcBatchNormAddRelu { eps_ = eps; } - void processCudnnStatus(const cudnnStatus_t& status, - const std::string& string = std::string(), + void processCudnnStatus(const cudnnStatus_t& status, const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { if (status != CUDNN_STATUS_SUCCESS) LOG(FATAL) << string << " " << cudnnGetErrorString(status); @@ -131,8 +126,7 @@ class NhwcBatchNormAddRelu { LOG(INFO) << string << " " << cudnnGetErrorString(status); } - void checkCudaStatus(const std::string& string = std::string(), - bool verbose = VERBOSE_DEFAULT) { + void checkCudaStatus(const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { cudaError_t status = cudaGetLastError(); if (status != cudaSuccess) LOG(FATAL) << string << " " << cudaGetErrorString(status); @@ -143,36 +137,35 @@ class NhwcBatchNormAddRelu { size_t size_retired_ctas(int grid_y) const { // Note that the value of max_grid_y to handle known GPUs is about 160. const int max_grid_y = 1024; - if (grid_y > max_grid_y) - LOG(INFO) << "GPU capabilities exceeds assumptions."; + if (grid_y > max_grid_y) LOG(INFO) << "GPU capabilities exceeds assumptions."; const int retired_cta_bytes = max_grid_y * 2 * sizeof(int); // Since the region will be initialized once and used for many kernels, // the idea is to return an ample size that will cover all uses. return retired_cta_bytes; } - cudnnTensorDescriptor_t X_tensor_desc_ = nullptr; - cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr; + cudnnTensorDescriptor_t X_tensor_desc_ = nullptr; + cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr; - void* X_ = nullptr; + void* X_ = nullptr; void* dX_ = nullptr; - void* Y_ = nullptr; + void* Y_ = nullptr; void* dY_ = nullptr; - void* addend_ = nullptr; + void* addend_ = nullptr; void* dAddend_ = nullptr; // Learned scale and bias weights. - float* scale_ = nullptr; + float* scale_ = nullptr; float* dscale_ = nullptr; - float* bias_ = nullptr; - float* dbias_ = nullptr; + float* bias_ = nullptr; + float* dbias_ = nullptr; // Computed population mean and variance parameters. - float* population_mean_ = nullptr; + float* population_mean_ = nullptr; float* population_variance_ = nullptr; // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd). - float* minibatch_mean_ = nullptr; + float* minibatch_mean_ = nullptr; float* minibatch_variance_ = nullptr; int m_ = 0; // Number of values per channel that BN is normalizing. 
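For reference, the workspace handling above follows a single-allocation pattern: numWorkspaceBytes() reports one byte count per sub-buffer, the caller rounds each count up, sums them into one device allocation, and then hands base-plus-offset pointers to setWorkspacePointers(). A minimal host-side sketch of that offset arithmetic, assuming a 512-byte alignment and the hypothetical helper name carve_workspace (neither is defined by this change):

#include <cstddef>
#include <vector>

// Round x up to the next multiple (same idea as round_up_to_multiple in the .cu files).
static size_t round_up(size_t x, size_t multiple) { return ((x + multiple - 1) / multiple) * multiple; }

// Given per-buffer byte counts, return the byte offset of each sub-buffer and the
// total size of the single allocation that backs them all.
std::vector<size_t> carve_workspace(const std::vector<size_t>& bytes, size_t* total_bytes, size_t alignment = 512) {
  std::vector<size_t> offsets;
  size_t running = 0;
  for (size_t b : bytes) {
    offsets.push_back(running);         // this sub-buffer starts here
    running += round_up(b, alignment);  // keep the next sub-buffer aligned
  }
  *total_bytes = running;
  return offsets;
}

// Usage: allocate *total_bytes once, then pass base + offsets[i] as workspace[i]
// to NhwcBatchNormAddRelu::setWorkspacePointers().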
@@ -182,20 +175,18 @@ class NhwcBatchNormAddRelu { float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance double exp_avg_factor_ = 0.; - double eps_ = 0.; + double eps_ = 0.; std::string name_; private: - void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, - cudnnTensorFormat_t format, - cudnnDataType_t data_type, + void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, cudnnTensorFormat_t format, cudnnDataType_t data_type, int n, int c, int h, int w) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w); processCudnnStatus(status, "set tensor descriptor"); } - void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) { + void createTensorDescriptor(cudnnTensorDescriptor_t* descriptor) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnCreateTensorDescriptor(descriptor); processCudnnStatus(status, "create tensor_descriptor"); @@ -208,14 +199,14 @@ class NhwcBatchNormAddRelu { } protected: - float *partial_sums_ = nullptr; - int *partial_counts_ = nullptr; - int *retired_ctas_ = nullptr; - unsigned int *relu_bitmask_ = nullptr; + float* partial_sums_ = nullptr; + int* partial_counts_ = nullptr; + int* retired_ctas_ = nullptr; + unsigned int* relu_bitmask_ = nullptr; - void _setFwdParams(NhwcBatchNormFwdParams *params) const; - void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const; - void _setBwdParams(NhwcBatchNormBwdParams *params) const; + void _setFwdParams(NhwcBatchNormFwdParams* params) const; + void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const; + void _setBwdParams(NhwcBatchNormBwdParams* params) const; // @todo: ability to configure these? // Kernel params @@ -233,24 +224,19 @@ class NhwcBatchNormAddRelu { static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10; static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5; - static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \ - PIXELS_PER_THREAD_IN_SMEM_FWD; - static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \ - PIXELS_PER_THREAD_IN_SMEM_BWD; + static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + PIXELS_PER_THREAD_IN_SMEM_FWD; + static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + PIXELS_PER_THREAD_IN_SMEM_BWD; static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4; // Derived params - static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\ - ELEMENTS_PER_LDG*sizeof(StorageType); - static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\ - ELEMENTS_PER_LDG*2*sizeof(StorageType); + static const size_t SMEM_SIZE_FWD = + PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * sizeof(StorageType); + static const size_t SMEM_SIZE_BWD = + PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * 2 * sizeof(StorageType); static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_FWD; - static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_BWD; - static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_FWD_INFERENCE; + static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD; + static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * 
PIXELS_PER_THREAD_BWD; + static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD_INFERENCE; // max grid.y in case of group bn is limited by exchange buffer size static const int MAX_GBN_BLOCK_Y = 256; @@ -260,58 +246,31 @@ class NhwcBatchNormAddRelu { // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel // version that was compiled with that occupancy in its launch bounds. This way, we avoid // needless register spills. - void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, - dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) { -#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ - "Nhwc batchnormaddrelu kernel smem too big."; \ - auto fwd_func = nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } else { \ - cudaLaunchKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " fwd ser coop kernel"); \ - } while (0) + void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, dim3 grid_dim, int outer_loops, + const int occupancy, const bool coop) { +#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnormaddrelu kernel smem too big."; \ + auto fwd_func = \ + nhwc_batch_norm_fwd; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void* params_ptr = static_cast(¶ms); \ + using FWD_FUNC = decltype(nhwc_batch_norm_fwd); \ + if (COOP) { \ + cudaLaunchCooperativeKernel(fwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_FWD, stream); \ + } else { \ + cudaLaunchKernel(fwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_FWD, stream); \ + } \ + checkCudaStatus(name_ + " fwd ser coop kernel"); \ + } while (0) // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1) { @@ -330,55 +289,33 @@ class NhwcBatchNormAddRelu { // Helper function to launch the backward kernel. 
- void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, - dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) { -#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ - "Nhwc batchnormaddrelu kernel smem too big."; \ - auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(bwd_add_relu_func, \ - cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + \ - " bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(bwd_add_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - cudaLaunchKernel(bwd_add_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \ + void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, dim3 grid_dim, int outer_loops, + const int occupancy, const bool coop) { +#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnormaddrelu kernel smem too big."; \ + auto bwd_add_relu_func = \ + nhwc_batch_norm_bwd_add_relu; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + cudaFuncSetAttribute(bwd_add_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void* params_ptr = static_cast(¶ms); \ + using BWD_ADD_RELU_FUNC = \ + decltype(nhwc_batch_norm_bwd_add_relu); \ + if (COOP) { \ + cudaLaunchCooperativeKernel(bwd_add_relu_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, \ + SMEM_SIZE_BWD, stream); \ + } else { \ + cudaLaunchKernel(bwd_add_relu_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, \ + stream); \ + } \ + checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \ } while (0) // Don't try for an occupancy > 2 as this will squeeze register use and create spills. @@ -400,7 +337,7 @@ class NhwcBatchNormAddRelu { // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage. static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; - int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float); + int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float); int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); @@ -409,7 +346,7 @@ class NhwcBatchNormAddRelu { // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage. 
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; - int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float); + int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float); int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); @@ -425,116 +362,106 @@ const std::vector NhwcBatchNormAddRelu::numWorkspaceBytes() const { int grid_x = max(grid_x_fwd, grid_x_bwd); int grid_y = div_up(c_, C_ELEMENTS_PER_CTA); - const size_t num_mean_bytes = c_ * sizeof(float); + const size_t num_mean_bytes = c_ * sizeof(float); const size_t num_variance_bytes = num_mean_bytes; int elems_per_group = ((m_ + 31) & ~31) * 2; int group_count = div_up(c_, C_ELEMENTS_PER_CTA); const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int); - const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\ - ELEMENTS_PER_LDG*2*sizeof(float); - const size_t size_counts = grid_y*grid_x*sizeof(int); + const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2 * sizeof(float); + const size_t size_counts = grid_y * grid_x * sizeof(int); - return {num_mean_bytes, num_variance_bytes, bitmask_bytes, - size_retired_ctas(grid_y), size_sums, size_counts}; + return {num_mean_bytes, num_variance_bytes, bitmask_bytes, size_retired_ctas(grid_y), size_sums, size_counts}; } -void NhwcBatchNormAddRelu::setWorkspacePointers( - const std::vector& workspace, - const std::vector& num_workspace_bytes) { +void NhwcBatchNormAddRelu::setWorkspacePointers(const std::vector& workspace, + const std::vector& num_workspace_bytes) { assert(workspace.size() == 6); assert(num_workspace_bytes.size() == 6); - minibatch_mean_ = static_cast(workspace[0]); + minibatch_mean_ = static_cast(workspace[0]); minibatch_variance_ = static_cast(workspace[1]); - relu_bitmask_ = static_cast(workspace[2]); - retired_ctas_ = static_cast(workspace[3]); - partial_sums_ = static_cast(workspace[4]); - partial_counts_ = static_cast(workspace[5]); + relu_bitmask_ = static_cast(workspace[2]); + retired_ctas_ = static_cast(workspace[3]); + partial_sums_ = static_cast(workspace[4]); + partial_counts_ = static_cast(workspace[5]); } -void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dst = static_cast(Y_); - params->gmem_src1 = static_cast(addend_); - params->gmem_bias = bias_; - params->gmem_scale = scale_; +void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams* params) const { + params->gmem_src = static_cast(X_); + params->gmem_dst = static_cast(Y_); + params->gmem_src1 = static_cast(addend_); + params->gmem_bias = bias_; + params->gmem_scale = scale_; params->gmem_running_mean = population_mean_; - params->gmem_running_var = population_variance_; - params->gmem_saved_mean = minibatch_mean_; - params->gmem_saved_var = minibatch_variance_; + params->gmem_running_var = population_variance_; + params->gmem_saved_mean = minibatch_mean_; + params->gmem_saved_var = minibatch_variance_; params->gmem_relu_bitmask = relu_bitmask_; - params->nhw = m_; - params->c = c_; - params->svar_inv_count = svar_inv_count_; - params->rvar_inv_count = rvar_inv_count_; - params->gmem_sums = partial_sums_; - params->gmem_counts = partial_counts_; + params->nhw = m_; + params->c = c_; + params->svar_inv_count 
= svar_inv_count_; + params->rvar_inv_count = rvar_inv_count_; + params->gmem_sums = partial_sums_; + params->gmem_counts = partial_counts_; params->gmem_retired_ctas = retired_ctas_; - params->var_eps = eps_; - params->outer_loops = 0; - params->exp_avg_factor = static_cast(exp_avg_factor_); - params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); + params->var_eps = eps_; + params->outer_loops = 0; + params->exp_avg_factor = static_cast(exp_avg_factor_); + params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); } -void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams - *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dst = static_cast(Y_); - params->gmem_src1 = static_cast(addend_); - params->gmem_bias = bias_; +void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const { + params->gmem_src = static_cast(X_); + params->gmem_dst = static_cast(Y_); + params->gmem_src1 = static_cast(addend_); + params->gmem_bias = bias_; params->gmem_scale = scale_; - params->gmem_mean = population_mean_; - params->gmem_var = population_variance_; - params->nhw = m_; - params->c = c_; - params->var_eps = eps_; + params->gmem_mean = population_mean_; + params->gmem_var = population_variance_; + params->nhw = m_; + params->c = c_; + params->var_eps = eps_; } -void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dy = static_cast(dY_); - params->gmem_dst = static_cast(dX_); - params->gmem_dst1 = static_cast(dAddend_); +void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams* params) const { + params->gmem_src = static_cast(X_); + params->gmem_dy = static_cast(dY_); + params->gmem_dst = static_cast(dX_); + params->gmem_dst1 = static_cast(dAddend_); params->gmem_relu_bitmask = relu_bitmask_; - params->gmem_dscale = dscale_; - params->gmem_dbias = dbias_; - params->gmem_scale = scale_; - params->gmem_bias = bias_; - params->gmem_saved_mean = minibatch_mean_; - params->gmem_saved_var = minibatch_variance_; - params->nhw = m_; - params->c = c_; - params->svar_inv_count = svar_inv_count_; - params->gmem_sums = partial_sums_; + params->gmem_dscale = dscale_; + params->gmem_dbias = dbias_; + params->gmem_scale = scale_; + params->gmem_bias = bias_; + params->gmem_saved_mean = minibatch_mean_; + params->gmem_saved_var = minibatch_variance_; + params->nhw = m_; + params->c = c_; + params->svar_inv_count = svar_inv_count_; + params->gmem_sums = partial_sums_; params->gmem_retired_ctas = retired_ctas_; - params->outer_loops = 0; - params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); + params->outer_loops = 0; + params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); } void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) { - bool ptrs_are_set = - X_tensor_desc_ != nullptr - && Y_tensor_desc_ != nullptr - && scale_ != nullptr - && bias_ != nullptr - // && minibatch_mean_ != nullptr - // && minibatch_variance_ != nullptr - && population_mean_ != nullptr - && population_variance_ != nullptr - && X_ != nullptr - // && dX_ != nullptr - && Y_ != nullptr - && addend_ != nullptr - // && dY_ != nullptr - // && dscale_ != nullptr - // && dbias_ != nullptr - && partial_sums_ != nullptr - && partial_counts_ != nullptr; - - if (!ptrs_are_set) - die(); + bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && + bias_ != nullptr + // && minibatch_mean_ != nullptr + // && minibatch_variance_ != nullptr + && population_mean_ != nullptr && 
population_variance_ != nullptr && + X_ != nullptr + // && dX_ != nullptr + && Y_ != nullptr && + addend_ != nullptr + // && dY_ != nullptr + // && dscale_ != nullptr + // && dbias_ != nullptr + && partial_sums_ != nullptr && partial_counts_ != nullptr; + + if (!ptrs_are_set) die(); dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE); @@ -544,13 +471,12 @@ void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) { NhwcBatchNormFwdInferenceParams params; _setFwdInferenceParams(¶ms); - nhwc_batch_norm_fwd_inference - - <<>>(params); + nhwc_batch_norm_fwd_inference + <<>>(params); checkCudaStatus(name_ + " fwd_inference-relu kernel"); } -dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) { +dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int* loop, const int grid_dim_x) { dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD); int c_blks = div_up(c_, C_ELEMENTS_PER_CTA); @@ -559,21 +485,21 @@ dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) { *loop = 1; if (max_grid_x / grid_dim.x > 1) { grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); - assert(grid_dim.y 1) { grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); - assert(grid_dim.y> 1); + params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1); dim3 grid_dim = calc_fwd_grid(¶ms.outer_loops, grid_dim_x); _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop); } -void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, - const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { - bool ptrs_are_set = - X_tensor_desc_ != nullptr - && Y_tensor_desc_ != nullptr - && scale_ != nullptr - && bias_ != nullptr - && minibatch_mean_ != nullptr - && minibatch_variance_ != nullptr - && relu_bitmask_ != nullptr - // && population_mean_ != nullptr - // && population_variance_ != nullptr - && X_ != nullptr - && dX_ != nullptr - // && Y_ != nullptr - && dY_ != nullptr - && dAddend_ != nullptr - && dscale_ != nullptr - && dbias_ != nullptr - && retired_ctas_ != nullptr; - - if (!ptrs_are_set) - die(); +void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, + void* pair_data3, const int bn_group, const int magic, const int occupancy, + const int grid_dim_x, const bool coop) { + bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr && + minibatch_mean_ != nullptr && minibatch_variance_ != nullptr && + relu_bitmask_ != nullptr + // && population_mean_ != nullptr + // && population_variance_ != nullptr + && X_ != nullptr && + dX_ != nullptr + // && Y_ != nullptr + && dY_ != nullptr && dAddend_ != nullptr && dscale_ != nullptr && dbias_ != nullptr && + retired_ctas_ != nullptr; + + if (!ptrs_are_set) die(); // reset of retired_cta_count no longer needed @@ -672,7 +581,7 @@ void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_ params.pair_datas[1] = pair_data2; params.pair_datas[2] = pair_data3; params.magic = magic; - params.sync_iters = (bn_group==8)?3:(bn_group >> 1); + params.sync_iters = (bn_group == 8) ? 
3 : (bn_group >> 1); params.wgrad_coeff = 1.0 / bn_group; dim3 grid_dim = calc_bwd_grid(¶ms.outer_loops, grid_dim_x); diff --git a/apex/contrib/csrc/groupbn/cuda_utils.h b/apex/contrib/csrc/groupbn/cuda_utils.h index 9f003840c..7058a17b9 100644 --- a/apex/contrib/csrc/groupbn/cuda_utils.h +++ b/apex/contrib/csrc/groupbn/cuda_utils.h @@ -8,13 +8,11 @@ namespace cuda { namespace utils { static inline int MaxSharedMemoryPerMultiprocessor(int device_id) { - return getDeviceProperties(device_id)->sharedMemPerMultiprocessor; -} - - -} -} + return getDeviceProperties(device_id)->sharedMemPerMultiprocessor; } +} // namespace utils +} // namespace cuda +} // namespace at #endif diff --git a/apex/contrib/csrc/groupbn/interface.cpp b/apex/contrib/csrc/groupbn/interface.cpp index cb0012d84..27891bce1 100644 --- a/apex/contrib/csrc/groupbn/interface.cpp +++ b/apex/contrib/csrc/groupbn/interface.cpp @@ -1,146 +1,70 @@ -#include -#include -#include - -#include #include #include #include +#include +#include +#include +#include + +#include "ATen/Generator.h" #include "ATen/Scalar.h" -#include "ATen/Tensor.h" #include "ATen/Storage.h" -#include "ATen/Generator.h" - +#include "ATen/Tensor.h" namespace py = pybind11; -int64_t get_buffer_size( - const int bn_sync_steps); - -void* get_data_ptr( - const at::Tensor& data); - -void* get_remote_data_ptr( - const at::Tensor& handle, - const int64_t offset); - -void close_remote_data( - const at::Tensor& handle); - -at::Tensor nhwc_bn_fwd_train( - const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - const bool fuse_relu, - void* my_data, - void* pair_data, - void* pair_data2, - void* pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop); - -at::Tensor nhwc_bn_fwd_eval( - const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& ret_cta, - const int bn_group, - const float momentum, - const float epsilon, - const bool fuse_relu); - -std::vector nhwc_bn_bwd( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - const bool fuse_relu, - void* my_data, - void* pair_data, - void* pair_data2, - void* pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop); - -at::Tensor nhwc_bn_addrelu_fwd_train( - const at::Tensor& x, - const at::Tensor& z, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& bitmask, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - void* my_data, - void* pair_data, - void* pair_data2, - void* pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop); - -at::Tensor nhwc_bn_addrelu_fwd_eval( - const at::Tensor& x, - const 
at::Tensor& z, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& ret_cta, - const int bn_group, - const float momentum, - const float epsilon); - -std::vector nhwc_bn_addrelu_bwd( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& bitmask, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - void* my_data, - void* pair_data, - void* pair_data2, - void* pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop); +int64_t get_buffer_size(const int bn_sync_steps); + +void* get_data_ptr(const at::Tensor& data); + +void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset); + +void close_remote_data(const at::Tensor& handle); + +at::Tensor nhwc_bn_fwd_train(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + const at::Tensor& ret_cta, const float momentum, const float epsilon, const bool fuse_relu, + void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, + const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, + const bool coop); + +at::Tensor nhwc_bn_fwd_eval(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, + const at::Tensor& ret_cta, const int bn_group, const float momentum, const float epsilon, + const bool fuse_relu); + +std::vector nhwc_bn_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, + const float momentum, const float epsilon, const bool fuse_relu, void* my_data, + void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, + const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, + const bool coop); + +at::Tensor nhwc_bn_addrelu_fwd_train(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, + const at::Tensor& ret_cta, const float momentum, const float epsilon, + void* my_data, void* pair_data, void* pair_data2, void* pair_data3, + const int bn_group, const at::Tensor& magic_tensor, const int occupancy, + const int grid_dim_x, const bool coop); + +at::Tensor nhwc_bn_addrelu_fwd_eval(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& ret_cta, const int bn_group, + const float momentum, const float epsilon); + +std::vector nhwc_bn_addrelu_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& 
bitmask, + const at::Tensor& ret_cta, const float momentum, const float epsilon, + void* my_data, void* pair_data, void* pair_data2, void* pair_data3, + const int bn_group, const at::Tensor& magic_tensor, const int occupancy, + const int grid_dim_x, const bool coop); int nhwc_bn_fwd_occupancy(); int nhwc_bn_bwd_occupancy(); @@ -149,7 +73,6 @@ int nhwc_bn_addrelu_fwd_occupancy(); int nhwc_bn_addrelu_bwd_occupancy(); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("get_buffer_size", &get_buffer_size, "get_buffer_size", py::call_guard()); m.def("get_data_ptr", &get_data_ptr, "get_data_ptr", py::call_guard()); m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr", py::call_guard()); @@ -159,14 +82,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc", py::call_guard()); m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc", py::call_guard()); - m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy", py::call_guard()); - m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy", py::call_guard()); + m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy", + py::call_guard()); + m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy", + py::call_guard()); - m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc", py::call_guard()); - m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc", py::call_guard()); + m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc", + py::call_guard()); + m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc", + py::call_guard()); m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc", py::call_guard()); - m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy", py::call_guard()); - m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy", py::call_guard()); + m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy", + py::call_guard()); + m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy", + py::call_guard()); } - diff --git a/apex/contrib/csrc/groupbn/ipc.cu b/apex/contrib/csrc/groupbn/ipc.cu index c31c094dd..699df9f37 100644 --- a/apex/contrib/csrc/groupbn/ipc.cu +++ b/apex/contrib/csrc/groupbn/ipc.cu @@ -1,28 +1,24 @@ #include #include - #include - -#define cudaCheckErrors(msg) \ - do { \ - cudaError_t __err = cudaGetLastError(); \ - if (__err != cudaSuccess) { \ - fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \ - msg, cudaGetErrorString(__err), \ - __FILE__, __LINE__); \ - fprintf(stderr, "*** FAILED - ABORTING\n"); \ - exit(1); \ - } \ - } while (0) - -template<> +#define cudaCheckErrors(msg) \ + do { \ + cudaError_t __err = cudaGetLastError(); \ + if (__err != cudaSuccess) { \ + fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \ + fprintf(stderr, "*** FAILED - ABORTING\n"); \ + exit(1); \ + } \ + } while (0) + +template <> struct std::hash { - size_t operator() (const cudaIpcMemHandle_t& handle) const { + size_t operator()(const cudaIpcMemHandle_t& handle) const { size_t hash = 0; uint8_t* ptr = (uint8_t*)&handle; assert(sizeof(uint8_t) == 1); - for (int i=0; i { } }; -template<> +template <> struct std::equal_to { - 
bool operator() (const cudaIpcMemHandle_t &lhs, - const cudaIpcMemHandle_t &rhs) const { - return (std::memcmp((void*) &lhs, - (void*) &rhs, - sizeof(cudaIpcMemHandle_t)) == 0); + bool operator()(const cudaIpcMemHandle_t& lhs, const cudaIpcMemHandle_t& rhs) const { + return (std::memcmp((void*)&lhs, (void*)&rhs, sizeof(cudaIpcMemHandle_t)) == 0); } }; namespace { namespace gpuipc { -//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h -// The number of threads per pixel. +// from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h +// The number of threads per pixel. const int THREADS_PER_PIXEL = 16; // The number of elements per ldg. const int ELEMENTS_PER_LDG = 4; @@ -52,14 +45,14 @@ const int ELEMENTS_PER_LDG = 4; const int REDUCE_OPS = 4; // Maximum block.y supported - limited due to buffer allocation const int MAX_BLOCK_Y = 256; -const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y; +const int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y; const int BYTES_PER_ELEM = 4; // Buffer size per sync step -const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM; -}; +const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET * THREADS_PER_PIXEL * 2 * ELEMENTS_PER_LDG * BYTES_PER_ELEM; +}; // namespace gpuipc class IpcMemHandleRegistry { -public: + public: void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) { if (registry_.count(handle) == 0) { registry_.insert(std::make_pair(handle, RegistryEntry())); @@ -80,15 +73,15 @@ public: struct RegistryEntry { void* dev_ptr; - int ref_count; - RegistryEntry() : dev_ptr(NULL) , ref_count(0) {} + int ref_count; + RegistryEntry() : dev_ptr(NULL), ref_count(0) {} }; -protected: + protected: std::unordered_map registry_; void* ipcOpenMem(const cudaIpcMemHandle_t& handle) { - void *data; + void* data; cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess); cudaCheckErrors("ipc init"); return data; @@ -98,30 +91,24 @@ protected: cudaIpcCloseMemHandle(dev_ptr); cudaCheckErrors("ipc close"); } - }; -} +} // namespace static IpcMemHandleRegistry ipc_mem_registry; -int64_t get_buffer_size(const int bn_sync_steps) { - return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES; -} +int64_t get_buffer_size(const int bn_sync_steps) { return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES; } void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) { cudaIpcMemHandle_t my_handle; - memcpy((unsigned char *)(&my_handle), handle.data_ptr(), sizeof(my_handle)); + memcpy((unsigned char*)(&my_handle), handle.data_ptr(), sizeof(my_handle)); return ipc_mem_registry.getPtr(my_handle, offset); } void close_remote_data(const at::Tensor& handle) { - cudaIpcMemHandle_t my_handle; - memcpy((unsigned char *)(&my_handle), handle.data_ptr(), sizeof(my_handle)); + cudaIpcMemHandle_t my_handle; + memcpy((unsigned char*)(&my_handle), handle.data_ptr(), sizeof(my_handle)); ipc_mem_registry.releasePtr(my_handle); } -void* get_data_ptr( - const at::Tensor& data) { - return data.data_ptr(); -} +void* get_data_ptr(const at::Tensor& data) { return data.data_ptr(); } diff --git a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h index 8430f3099..19f50bf31 100644 --- a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h +++ b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h @@ -22,2664 +22,2562 @@ * \file nhwc_batch_norm_kernel.h * \brief CUDA NHWC Batch Normalization code * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer -*/ + */ #ifndef 
MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_ #define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_ #include + #include #define DEVICE_FUNCTION static inline __device__ // CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN. -#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3 +#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3 #define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename T, int ELEMENTS_PER_LDG > +template struct PackedStorage { - enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG }; - typedef T Type; + enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG }; + typedef T Type; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int ELEMENTS_PER_LDG > +template struct PackedStorage { - enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG/2 }; - typedef int Type; + enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG / 2 }; + typedef int Type; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > -DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - uint16_t lo, hi; - asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0])); - asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1])); - asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi)); - } +template +DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2 * N]) { +#pragma unroll + for (int i = 0; i < N; ++i) { + uint16_t lo, hi; + asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2 * i + 0])); + asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2 * i + 1])); + asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi)); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = src[i]; - } +#pragma unroll + for (int i = 0; i < N; ++i) { + dst[i] = src[i]; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > -DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i])); - asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo)); - asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi)); - } +template +DEVICE_FUNCTION void to_float(float (&dst)[2 * N], int (&src)[N]) { +#pragma unroll + for (int i = 0; i < N; ++i) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i])); + asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2 * i + 0]) : "h"(lo)); + asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2 * i + 1]) : "h"(hi)); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template DEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = src[i]; - } +#pragma unroll + for (int i = 0; i < N; ++i) { + dst[i] = src[i]; + } } 
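The from_float()/to_float() overloads above use inline PTX (cvt.rn.f16.f32, cvt.f32.f16, mov.b32) to pack two fp32 values into a single 32-bit register holding a half2 and to unpack it again. A minimal equivalent written with the CUDA half intrinsics, shown only to clarify what the PTX does (these helper names are illustrative and not part of the kernel header):

#include <cuda_fp16.h>

// Pack two fp32 values into one 32-bit word holding a __half2
// (round-to-nearest, matching cvt.rn.f16.f32 in the PTX above).
__device__ inline int pack_floats_to_half2(float lo, float hi) {
  __half2 h = __floats2half2_rn(lo, hi);
  return *reinterpret_cast<int*>(&h);
}

// Unpack a 32-bit __half2 word back to two fp32 values (cvt.f32.f16).
__device__ inline float2 unpack_half2_to_floats(int packed) {
  __half2 h = *reinterpret_cast<__half2*>(&packed);
  return __half22float2(h);
}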
//////////////////////////////////////////////////////////////////////////////////////////////////// -DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) { - dst[0] = __ldg((const int*) gmem); -} +DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) { dst[0] = __ldg((const int *)gmem); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) { - unsigned int tmp; - asm volatile ("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l" ((const uint *)gmem)); - dst[0] = tmp; + unsigned int tmp; + asm volatile("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l"((const uint *)gmem)); + dst[0] = tmp; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) { - int2 tmp = __ldg((const int2*) gmem); - dst[0] = tmp.x; - dst[1] = tmp.y; + int2 tmp = __ldg((const int2 *)gmem); + dst[0] = tmp.x; + dst[1] = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) { - int2 tmp; - asm volatile ("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];" - : "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem)); - dst[0] = tmp.x; - dst[1] = tmp.y; + int2 tmp; + asm volatile("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];" : "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem)); + dst[0] = tmp.x; + dst[1] = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template DEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) { - int tmp[N/2]; - ldg(tmp, gmem); - to_float(dst, tmp); + int tmp[N / 2]; + ldg(tmp, gmem); + to_float(dst, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template DEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t *gmem) { - int tmp[N/2]; - ldg_stream(tmp, gmem); - to_float(dst, tmp); + int tmp[N / 2]; + ldg_stream(tmp, gmem); + to_float(dst, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// -DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) { - reinterpret_cast(gmem)[0] = src[0]; -} +DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) { reinterpret_cast(gmem)[0] = src[0]; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) { - unsigned int tmp = src[0]; - asm volatile ("st.global.cs.s32 [%0], %1;" - :: "l"((uint *)gmem) , "r"(tmp)); + unsigned int tmp = src[0]; + asm volatile("st.global.cs.s32 [%0], %1;" ::"l"((uint *)gmem), "r"(tmp)); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) { - reinterpret_cast(gmem)[0] = make_int2(src[0], src[1]); + reinterpret_cast(gmem)[0] = make_int2(src[0], src[1]); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) { - asm volatile ("st.global.cs.v2.s32 [%0], {%1,%2};" - :: "l"((uint *)gmem) , "r"(src[0]), "r"( src[1])); + asm volatile("st.global.cs.v2.s32 [%0], {%1,%2};" ::"l"((uint *)gmem), "r"(src[0]), "r"(src[1])); } 
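The ldg_stream()/stg_stream() overloads above issue ld.global.cs / st.global.cs, i.e. streaming ("evict-first") accesses for pixel data that each CTA touches only once. CUDA exposes the same cache policy through the __ldcs()/__stcs() intrinsics; a minimal equivalent for the int2 case, shown only as a reference (the helper name is illustrative):

// Streaming copy of one 8-byte element: load and store both use the .cs policy,
// so the data does not displace reusable lines in L1/L2.
__device__ inline void copy_int2_streaming(int2* dst, const int2* src) {
  int2 v = __ldcs(src);  // ld.global.cs
  __stcs(dst, v);        // st.global.cs
}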
//////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) { - int tmp[N/2]; - from_float(tmp, src); - stg(gmem, tmp); + int tmp[N / 2]; + from_float(tmp, src); + stg(gmem, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) { - int tmp[N/2]; - from_float(tmp, src); - stg_stream(gmem, tmp); + int tmp[N / 2]; + from_float(tmp, src); + stg_stream(gmem, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) { - float2 tmp = __ldg(reinterpret_cast(&gmem[2*idx])); - dst[0] = tmp.x; - dst[1] = tmp.y; + float2 tmp = __ldg(reinterpret_cast(&gmem[2 * idx])); + dst[0] = tmp.x; + dst[1] = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) { - float4 tmp = __ldg(reinterpret_cast(&gmem[4*idx])); - dst[0] = tmp.x; - dst[1] = tmp.y; - dst[2] = tmp.z; - dst[3] = tmp.w; + float4 tmp = __ldg(reinterpret_cast(&gmem[4 * idx])); + dst[0] = tmp.x; + dst[1] = tmp.y; + dst[2] = tmp.z; + dst[3] = tmp.w; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) { - float2 tmp = *(const float2*) &smem[2*idx]; - x[0] = tmp.x; - x[1] = tmp.y; + float2 tmp = *(const float2 *)&smem[2 * idx]; + x[0] = tmp.x; + x[1] = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// -DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) { - x[0] = smem[idx]; -} +DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) { x[0] = smem[idx]; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) { - float4 tmp = *(const float4*) &smem[4*idx]; - x[0] = tmp.x; - x[1] = tmp.y; - x[2] = tmp.z; - x[3] = tmp.w; + float4 tmp = *(const float4 *)&smem[4 * idx]; + x[0] = tmp.x; + x[1] = tmp.y; + x[2] = tmp.z; + x[3] = tmp.w; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) { - int2 tmp = *(const int2*) &smem[2*idx]; - x[0] = tmp.x; - x[1] = tmp.y; + int2 tmp = *(const int2 *)&smem[2 * idx]; + x[0] = tmp.x; + x[1] = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) { - reinterpret_cast(&gmem[2*idx])[0] = make_float2(src[0], src[1]); + reinterpret_cast(&gmem[2 * idx])[0] = make_float2(src[0], src[1]); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) { - reinterpret_cast(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]); + reinterpret_cast(&gmem[4 * idx])[0] = make_float4(src[0], src[1], src[2], src[3]); } 
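The read_from_*/write_to_* helpers above vectorize transfers by reinterpreting float*/int* as float2/float4/int2 pointers, which is valid here because the indices count whole vectors, keeping every access naturally aligned (8 or 16 bytes). A minimal illustration of the float4 case (hypothetical helper, not part of the header):

// Copy four consecutive floats with a single 16-byte load and store.
// Requires &src[4 * idx] and &dst[4 * idx] to be 16-byte aligned.
__device__ inline void copy_float4(float* __restrict__ dst, const float* __restrict__ src, int idx) {
  float4 v = *reinterpret_cast<const float4*>(&src[4 * idx]);
  *reinterpret_cast<float4*>(&dst[4 * idx]) = v;
}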
//////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) { - reinterpret_cast(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff); + reinterpret_cast(&gmem[4 * idx])[0] = + make_float4(src[0] * coeff, src[1] * coeff, src[2] * coeff, src[3] * coeff); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) { - reinterpret_cast(&smem[2*idx])[0] = make_float2(x[0], x[1]); + reinterpret_cast(&smem[2 * idx])[0] = make_float2(x[0], x[1]); } //////////////////////////////////////////////////////////////////////////////////////////////////// -DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) { - smem[idx] = x[0]; -} +DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) { smem[idx] = x[0]; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) { - reinterpret_cast(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]); + reinterpret_cast(&smem[4 * idx])[0] = make_float4(x[0], x[1], x[2], x[3]); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) { - reinterpret_cast(&smem[2*idx])[0] = make_int2(x[0], x[1]); + reinterpret_cast(&smem[2 * idx])[0] = make_int2(x[0], x[1]); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template DEVICE_FUNCTION void zero_array(int (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = 0; - } +#pragma unroll + for (int i = 0; i < N; ++i) { + dst[i] = 0; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int N > +template DEVICE_FUNCTION void zero_array(float (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = 0.f; - } +#pragma unroll + for (int i = 0; i < N; ++i) { + dst[i] = 0.f; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] += y[i]; - } +#pragma unroll + for (int i = 0; i < N; ++i) { + x[i] += y[i]; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] *= y[i]; - } +#pragma unroll + for (int i = 0; i < N; ++i) { + x[i] *= y[i]; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void scale_(float (&x)[N], float scalar) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] *= scalar; - } +#pragma unroll + for (int i = 0; i < N; ++i) { + x[i] *= scalar; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template -DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N], - const float (&scale)[N], const float (&m1)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] = 
bias[i] + scale[i] * (x[i] - m1[i]); - } +DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N], const float (&scale)[N], const float (&m1)[N]) { +#pragma unroll + for (int i = 0; i < N; ++i) { + x[i] = bias[i] + scale[i] * (x[i] - m1[i]); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION Storage relu(Storage in) { - Storage zero = (Storage)0.f; - return (in < zero)? zero : in; + Storage zero = (Storage)0.f; + return (in < zero) ? zero : in; } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void relu_activation(float (&x)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] = relu(x[i]); - } +#pragma unroll + for (int i = 0; i < N; ++i) { + x[i] = relu(x[i]); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int THREADS_PER_CTA > -DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, - void* params_my_data, void** params_pair_datas, int off, - const int magic, - const int sync_iters) { - // The size of a warp. - const int THREADS_PER_WARP = 32; - // The number of warps in a CTA. - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; - // The number of threads per pixel. - const int THREADS_PER_PIXEL = 16; - // The number of elements per ldg. - const int ELEMENTS_PER_LDG = 4; - // The number of reducing ops, each uses its own space : mean, var, dscale, dbias - const int REDUCE_OPS = 4; - // Maximum block.y supported - limited due to buffer allocation - const int MAX_BLOCK_Y = 256; - const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y; - // The warp decomposition. - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int lane_id = threadIdx.x % THREADS_PER_WARP; - // total size of data per sync iter - const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2; - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id); +template +DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, void *params_my_data, + void **params_pair_datas, int off, const int magic, const int sync_iters) { + // The size of a warp. + const int THREADS_PER_WARP = 32; + // The number of warps in a CTA. + const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; + // The number of threads per pixel. + const int THREADS_PER_PIXEL = 16; + // The number of elements per ldg. + const int ELEMENTS_PER_LDG = 4; + // The number of reducing ops, each uses its own space : mean, var, dscale, dbias + const int REDUCE_OPS = 4; + // Maximum block.y supported - limited due to buffer allocation + const int MAX_BLOCK_Y = 256; + const int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y; + // The warp decomposition. + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int lane_id = threadIdx.x % THREADS_PER_WARP; + // total size of data per sync iter + const int data_total = MAX_OFFSET * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2; + +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id); + } + + // The warp leaders, write to SMEM. + if (lane_id < THREADS_PER_PIXEL) { + write_to_smem(smem, warp_id * THREADS_PER_PIXEL + lane_id, x); + } + + // The data is in SMEM. Do the final reduction. + __syncthreads(); + + // The 1st warp does all the work. 
+ // We do the final reduction each half-warp sequentially reduces the final values. + if (warp_id == 0) { + read_from_smem(x, smem, threadIdx.x); + +#pragma unroll + for (int offset = 1; offset < WARPS_PER_CTA / (THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { + float y[ELEMENTS_PER_LDG]; + // Read the mean and variance from the other pixel. + read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_WARP); + // Compute the updated sum. + add(x, y); } - // The warp leaders, write to SMEM. - if (lane_id < THREADS_PER_PIXEL) { - write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x); + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id); } - // The data is in SMEM. Do the final reduction. - __syncthreads(); - - // The 1st warp does all the work. - // We do the final reduction each half-warp sequentially reduces the final values. - if (warp_id == 0) { - read_from_smem(x, smem, threadIdx.x); - - #pragma unroll - for (int offset = 1; - offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { - float y[ELEMENTS_PER_LDG]; - // Read the mean and variance from the other pixel. - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP); - // Compute the updated sum. - add(x, y); - } - - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id); - } + // Make sure the data was read from SMEM. + __syncwarp(); - // Make sure the data was read from SMEM. - __syncwarp(); + // Store the final values. + if (threadIdx.x < THREADS_PER_PIXEL) { + // probably could do it earlier, before sync - // Store the final values. - if (threadIdx.x < THREADS_PER_PIXEL) { - // probably could do it earlier, before sync + for (int sync_iter = 0; sync_iter < sync_iters; ++sync_iter) { + // float* params_pair_data = (reinterpret_cast(params_pair_datas))[sync_iter]; + void *params_pair_data = params_pair_datas[sync_iter]; - for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) { - //float* params_pair_data = (reinterpret_cast(params_pair_datas))[sync_iter]; - void* params_pair_data = params_pair_datas[sync_iter]; + // skip the space consumed by previous sync iterations + const int xbuf_offset = sync_iter * data_total; + // data starts after flags, but have to skip previous + const int data_offset = + xbuf_offset + off * ELEMENTS_PER_LDG * THREADS_PER_PIXEL * 2 + ELEMENTS_PER_LDG * threadIdx.x * 2; - // skip the space consumed by previous sync iterations - const int xbuf_offset = sync_iter*data_total; - // data starts after flags, but have to skip previous - const int data_offset = xbuf_offset - + off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2 - + ELEMENTS_PER_LDG*threadIdx.x*2; + // after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU + if (blockIdx.x == 0) { + volatile float *write_data = &((reinterpret_cast(params_pair_data))[data_offset]); - // after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU - if (blockIdx.x == 0) { - volatile float * write_data = - &((reinterpret_cast(params_pair_data))[data_offset]); + // write the data to memory region to be reflected to other GPU + asm volatile("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};" ::"l"(write_data), "f"(x[0]), "r"(magic), "f"(x[2]), + "r"(magic)); - // write the data to memory region to be reflected to other GPU - asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};" - :: "l"(write_data) , "f"(x[0]), "r"(magic), "f"(x[2]), "r"(magic)); - - asm volatile ("st.global.wt.v4.b32 
[%0], {%1,%2,%3,%4};" - :: "l"(write_data+4) , "f"(x[1]), "r"(magic), "f"(x[3]), "r"(magic)); - } - - // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU - volatile float * read_data = - &((reinterpret_cast(params_my_data))[data_offset]); + asm volatile("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};" ::"l"(write_data + 4), "f"(x[1]), "r"(magic), + "f"(x[3]), "r"(magic)); + } - float other[4]; - uint32_t other_flag_a, other_flag_b; - do { - asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];" - : "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b) : "l"(read_data)); - } while ((other_flag_a != magic) || (other_flag_b != magic)); + // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU + volatile float *read_data = &((reinterpret_cast(params_my_data))[data_offset]); - do { - asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];" - : "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b) : "l"(read_data+4)); - } while ((other_flag_a != magic) || (other_flag_b != magic)); + float other[4]; + uint32_t other_flag_a, other_flag_b; + do { + asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];" + : "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b) + : "l"(read_data)); + } while ((other_flag_a != magic) || (other_flag_b != magic)); - add(x, other); - } - // finally, after syncing up and accounting for partial sums from - // other GPUs as required, write the result + do { + asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];" + : "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b) + : "l"(read_data + 4)); + } while ((other_flag_a != magic) || (other_flag_b != magic)); + add(x, other); + } + // finally, after syncing up and accounting for partial sums from + // other GPUs as required, write the result - write_to_smem(smem, threadIdx.x, x); - } + write_to_smem(smem, threadIdx.x, x); } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int THREADS_PER_CTA > +template DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) { - // The size of a warp. - const int THREADS_PER_WARP = 32; - // The number of warps in a CTA. - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; - // The number of threads per pixel. - const int THREADS_PER_PIXEL = 8; - // The number of elements per ldg. - const int ELEMENTS_PER_LDG = 4; - // The warp decomposition. - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int lane_id = threadIdx.x % THREADS_PER_WARP; - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id); - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id); + // The size of a warp. + const int THREADS_PER_WARP = 32; + // The number of warps in a CTA. + const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; + // The number of threads per pixel. + const int THREADS_PER_PIXEL = 8; + // The number of elements per ldg. + const int ELEMENTS_PER_LDG = 4; + // The warp decomposition. + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int lane_id = threadIdx.x % THREADS_PER_WARP; + +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id); + x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL * 2 + lane_id); + } + + // The warp leaders, write to SMEM. 
+ if (lane_id < THREADS_PER_PIXEL) { + write_to_smem(smem, warp_id * THREADS_PER_PIXEL + lane_id, x); + } + + // The data is in SMEM. Do the final reduction. + __syncthreads(); + + // The 1st warp does all the work. + // We do the final reduction each half-warp sequentially reduces the final values. + if (warp_id == 0) { + read_from_smem(x, smem, threadIdx.x); + +#pragma unroll + for (int offset = 1; offset < WARPS_PER_CTA / (THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { + float y[ELEMENTS_PER_LDG]; + // Read the mean and variance from the other pixel. + read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_WARP); + // Compute the updated sum. + add(x, y); } - // The warp leaders, write to SMEM. - if (lane_id < THREADS_PER_PIXEL) { - write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x); + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id); + x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL * 2 + lane_id); } - // The data is in SMEM. Do the final reduction. - __syncthreads(); - - // The 1st warp does all the work. - // We do the final reduction each half-warp sequentially reduces the final values. - if (warp_id == 0) { - read_from_smem(x, smem, threadIdx.x); - - #pragma unroll - for (int offset = 1; - offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { - float y[ELEMENTS_PER_LDG]; - // Read the mean and variance from the other pixel. - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP); - // Compute the updated sum. - add(x, y); - } + // Make sure the data was read from SMEM. + __syncwarp(); - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id); - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id); - } - - // Make sure the data was read from SMEM. - __syncwarp(); - - // Store the final values. - if (threadIdx.x < THREADS_PER_PIXEL) { - write_to_smem(smem, threadIdx.x, x); - } + // Store the final values. + if (threadIdx.x < THREADS_PER_PIXEL) { + write_to_smem(smem, threadIdx.x, x); } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG > +template DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) { - // The size of a warp. - const int THREADS_PER_WARP = 32; - // The number of warps in a CTA. - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; - // The number of pixels computed by a single warp. - const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL; - - // The position in the warp. - const int nhw_in_warp = nhw % PIXELS_PER_WARP; - // The C in the warp. - const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL; - - // Store the values to shared memory. - write_to_smem(smem, threadIdx.x, x); - - // Compute the parallel sums. - for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) { - // NOP. - __syncwarp(); - - // Read the running sum from the other thread. - float y[ELEMENTS_PER_LDG]; - if (nhw_in_warp < offset) { - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL); - } - - // Compute the updated sum. - add(x, y); - - // NOP. - __syncwarp(); - - // Update the sum in SMEM. - if (offset > 1 && nhw_in_warp < offset) { - write_to_smem(smem, threadIdx.x, x); - } + // The size of a warp. + const int THREADS_PER_WARP = 32; + // The number of warps in a CTA. 
+ const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; + // The number of pixels computed by a single warp. + const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL; + + // The position in the warp. + const int nhw_in_warp = nhw % PIXELS_PER_WARP; + // The C in the warp. + const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL; + + // Store the values to shared memory. + write_to_smem(smem, threadIdx.x, x); + + // Compute the parallel sums. + for (int offset = PIXELS_PER_WARP / 2; offset > 0; offset /= 2) { + // NOP. + __syncwarp(); + + // Read the running sum from the other thread. + float y[ELEMENTS_PER_LDG]; + if (nhw_in_warp < offset) { + read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_PIXEL); } - // The warps are done. Do the final reduction at the CTA level. - __syncthreads(); + // Compute the updated sum. + add(x, y); - // The warp leaders, write to SMEM. - const int idx = (threadIdx.x/THREADS_PER_WARP)*THREADS_PER_PIXEL + c_in_warp; - if (nhw_in_warp == 0) { - write_to_smem(smem, idx, x); - } + // NOP. + __syncwarp(); - // The data is in SMEM. Do the final reduction. - __syncthreads(); - - // Read the 1st element to prepare the work. - if (nhw < WARPS_PER_CTA/2) { - read_from_smem(x, smem, threadIdx.x); + // Update the sum in SMEM. + if (offset > 1 && nhw_in_warp < offset) { + write_to_smem(smem, threadIdx.x, x); + } + } + + // The warps are done. Do the final reduction at the CTA level. + __syncthreads(); + + // The warp leaders, write to SMEM. + const int idx = (threadIdx.x / THREADS_PER_WARP) * THREADS_PER_PIXEL + c_in_warp; + if (nhw_in_warp == 0) { + write_to_smem(smem, idx, x); + } + + // The data is in SMEM. Do the final reduction. + __syncthreads(); + + // Read the 1st element to prepare the work. + if (nhw < WARPS_PER_CTA / 2) { + read_from_smem(x, smem, threadIdx.x); + } + + // We have the running mean and running m2. Let's build the mean/var of the CTA. + for (int offset = WARPS_PER_CTA / 2; offset > 0; offset /= 2) { + // NOP. + __syncwarp(); + + // Read the mean and variance from the other pixel. + float y[ELEMENTS_PER_LDG]; + if (nhw < offset) { + read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_PIXEL); } - // We have the running mean and running m2. Let's build the mean/var of the CTA. - for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) { - // NOP. - __syncwarp(); - - // Read the mean and variance from the other pixel. - float y[ELEMENTS_PER_LDG]; - if (nhw < offset) { - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL); - } - - // Compute the updated sum. - add(x, y); + // Compute the updated sum. + add(x, y); - // NOP. - __syncwarp(); + // NOP. + __syncwarp(); - // Store the mean/var for the different pixels. - if (nhw < offset) { - write_to_smem(smem, threadIdx.x, x); - } + // Store the mean/var for the different pixels. 
+ if (nhw < offset) { + write_to_smem(smem, threadIdx.x, x); } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG > +template struct ParallelSums { - template< int THREADS_PER_CTA > - DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) { - parallel_sums(smem, x, nhw); - } + template + DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) { + parallel_sums(smem, x, nhw); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template<> +template <> struct ParallelSums<16, 4> { - template< int THREADS_PER_CTA > - DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) { - parallel_sums_16x2(smem, x, nhw, 0, 0, 0, 0, 0); - } - - template< int THREADS_PER_CTA > - DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) { - parallel_sums_16x2(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters); - } + template + DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) { + parallel_sums_16x2(smem, x, nhw, 0, 0, 0, 0, 0); + } + + template + DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void *params_my_data, void **params_pair_datas, + int off, const int magic, const unsigned int &sync_iters) { + parallel_sums_16x2(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters); + } }; -template<> +template <> struct ParallelSums<8, 4> { - template< int THREADS_PER_CTA > - DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) { - parallel_sums_8x4(smem, x, nhw); - } + template + DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) { + parallel_sums_8x4(smem, x, nhw); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline int div_up(int m, int n) { - return (m + n - 1) / n; -} +static inline int div_up(int m, int n) { return (m + n - 1) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// // It is expected that all threads in the CTA enter this function! -DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) { +DEVICE_FUNCTION void inter_block_sync(int *gmem_retired_ctas, int expected_count, bool master) { + // Register the CTA. + if (threadIdx.x == 0) { + // Issue the membar. + __threadfence(); + // Notify that the CTA is done. + int val_to_add = 1; + if (master) { + val_to_add = -(expected_count - 1); + } + atomicAdd(gmem_retired_ctas, val_to_add); + } + + // Are all CTAs done? + if (threadIdx.x == 0) { + int retired_ctas = -1; + do { + __threadfence(); + asm volatile("ld.global.cg.b32 %0, [%1];" : "=r"(retired_ctas) : "l"(gmem_retired_ctas)); + } while (retired_ctas != 0); + } + __syncthreads(); +} - // Register the CTA. - if (threadIdx.x == 0) { - // Issue the membar. - __threadfence(); - // Notify that the CTA is done. 
- int val_to_add = 1; - if (master) { - val_to_add = -(expected_count - 1); - } - atomicAdd(gmem_retired_ctas, val_to_add); +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct NhwcBatchNormFwdInferenceParams { + // The input/output tensors. + uint16_t *gmem_src, *gmem_dst, *gmem_src1; + // the final mean and variance as calculated during the training process + float *gmem_mean, *gmem_var; + // The bias/scale. + float *gmem_bias, *gmem_scale; + // The dimensions. + int nhw, c; + // epsilon + float var_eps; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively +template +__global__ __launch_bounds__(THREADS_PER_CTA) void nhwc_batch_norm_fwd_inference( + NhwcBatchNormFwdInferenceParams params) { + // The number of pixels loaded in a single LDG. + const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; + // The number of C elements per CTA. + const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG; + + // The start position in the NHW dimension where the CTA starts. + const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG; + // Compute the NHW coordinate of the thread in the CTA. + const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; + // thread's starting point in NHW + const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG; + + // The position in the C dimension where the CTA starts. + const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA; + // Compute the C coordinate of the thread in the CTA. + const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; + // Compute the C coordinate of the thread. + const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG; + + // Is the thread working on a valid C dimension? + const int is_valid_c = thread_c < params.c; + + float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG]; + float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG]; + zero_array(mean); + zero_array(var); + zero_array(scale); + zero_array(bias); + if (is_valid_c) { + read_from_gmem(var, ¶ms.gmem_var[cta_c], thread_in_cta_c); + read_from_gmem(scale, ¶ms.gmem_scale[cta_c], thread_in_cta_c); + read_from_gmem(mean, ¶ms.gmem_mean[cta_c], thread_in_cta_c); + read_from_gmem(bias, ¶ms.gmem_bias[cta_c], thread_in_cta_c); + } + +// Update the scale with the stddev and eps. +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + scale[i] *= rsqrtf(var[i] + params.var_eps); + } + + // The base pointers for reading/writing + uint16_t *const gmem_src = ¶ms.gmem_src[thread_c]; + uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; + const uint16_t *gmem_src1 = nullptr; + if (USE_ADD_RELU) { + gmem_src1 = ¶ms.gmem_src1[thread_c]; + } + + // apply BN + for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) { + float x_math[ELEMENTS_PER_LDG]; + zero_array(x_math); + if (is_valid_c) { + ldg(x_math, &gmem_src[nhw * params.c]); } - // Are all CTAs done? 
- if (threadIdx.x == 0) { - int retired_ctas = -1; - do { - __threadfence(); - asm volatile ("ld.global.cg.b32 %0, [%1];" - : "=r"(retired_ctas) : "l"(gmem_retired_ctas)); - } while (retired_ctas != 0); + // Normalize and apply activation function + normalize(x_math, bias, scale, mean); + if (USE_ADD_RELU) { + float x1_math[ELEMENTS_PER_LDG]; + ldg(x1_math, &gmem_src1[nhw * params.c]); + add(x_math, x1_math); + relu_activation(x_math); + } else if (USE_RELU) { + relu_activation(x_math); } - __syncthreads(); + if (is_valid_c) { + stg(&gmem_dst[nhw * params.c], x_math); + } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -struct NhwcBatchNormFwdInferenceParams { - // The input/output tensors. - uint16_t *gmem_src, *gmem_dst, *gmem_src1; - // the final mean and variance as calculated during the training process - float *gmem_mean, *gmem_var; - // The bias/scale. - float *gmem_bias, *gmem_scale; - // The dimensions. - int nhw, c; - // epsilon - float var_eps; +struct NhwcBatchNormFwdParams { + // The input/output tensors. + uint16_t *gmem_src, *gmem_dst, *gmem_src1; + // The bias/scale. + float *gmem_bias, *gmem_scale; + // running mean/var (refer BN API from cudnn doc) + float *gmem_running_mean, *gmem_running_var; + // saved mean/var (refer BN API from cudnn doc) + float *gmem_saved_mean, *gmem_saved_var; + // ReLU bitmask + unsigned int *gmem_relu_bitmask; + // The dimensions. + int nhw, c; + // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. + float svar_inv_count; + // factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1). + float rvar_inv_count; + // The buffer to do the reduction for mean, stddev and count. + float *gmem_sums; + // The buffer to count items in the different CTAs. + int *gmem_counts; + // The counters of retired CTAs. + int *gmem_retired_ctas; + // The epsilon to apply to the computation of the variance. + float var_eps; + // outer loop count + int outer_loops; + // exponential average factor + float exp_avg_factor; + // number of CTAs along .x dimension + int c_blks; + + void *my_data; + void *pair_datas[4]; + int magic; + int sync_iters; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int ELEMENTS_PER_LDG, - bool USE_RELU, - bool USE_ADD_RELU -> -__global__ __launch_bounds__(THREADS_PER_CTA) - void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // The start position in the NHW dimension where the CTA starts. - const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG; - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - // thread's starting point in NHW - const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG; +template +__global__ __launch_bounds__(THREADS_PER_CTA, + DESIRED_OCCUPANCY) void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) { + // The number of pixels loaded in a single LDG. 
+ const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; + // The number of pixels computed per CTA stored in registers. + const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; + // The number of pixels computed per CTA stored in SMEM. + const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG; + // The number of C elements per CTA. + const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG; + + // Shared memory to do CTA-wide parallel sums. + __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG]; + + // Compute the NHW coordinate of the thread in the CTA. + const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; + + // The adapter for the storage. + typedef PackedStorage PackedStorage_; + // The data type for packed storage in SMEM. + typedef typename PackedStorage_::Type PackedStorageType; + // The number of elements in the packed storage. + const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; + // Registers to keep the data live for the persistent approach. + PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; + + // Shared memory buffer to store the extra pixels. + extern __shared__ PackedStorageType smem_storage_packed[]; + + for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { + // The position in the NHW dimension where the CTA starts. + int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; + // The position in the NHW dimension where the CTA starts for the portion in SMEM. + int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; // The position in the C dimension where the CTA starts. - const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA; + const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; // Compute the C coordinate of the thread in the CTA. const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; // Compute the C coordinate of the thread. - const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; + int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG; // Is the thread working on a valid C dimension? const int is_valid_c = thread_c < params.c; - float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG]; - float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG]; - zero_array(mean); - zero_array(var); - zero_array(scale); - zero_array(bias); - if (is_valid_c) { - read_from_gmem(var, ¶ms.gmem_var[cta_c], thread_in_cta_c); - read_from_gmem(scale, ¶ms.gmem_scale[cta_c], thread_in_cta_c); - read_from_gmem(mean, ¶ms.gmem_mean[cta_c], thread_in_cta_c); - read_from_gmem(bias, ¶ms.gmem_bias[cta_c], thread_in_cta_c); + // Clamp thread_c so that we load from valid locations even if we don't use the value + if (!is_valid_c) thread_c = params.c - 4; + + // Single pass numerically stable algorithm, see: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm + // + // n = 0, mean = 0.0, M2 = 0.0 + // + // for x in data: + // n += 1 + // delta = x - mean + // mean += delta/n + // delta2 = x - mean + // M2 += delta*delta2 + // + // if n < 2: + // return float('nan') + // else: + // return M2 / (n - 1) + + // Register to store the number of elements read so far. + float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG]; +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + mean[i] = 0.f; + m2[i] = 0.f; } - // Update the scale with the stddev and eps. 
- #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - scale[i] *= rsqrtf(var[i] + params.var_eps); + // The number of elements loaded by this CTA. + int cta_count = 0; + // The base pointer to load from. + const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; + + // outer loops + int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops; + // Load the batch of elements. Compute the mean/var across those elements. + const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x; + + if (OUTER_LOOPS_ != 1) { + // We cannot load everything to store persistently, so let's makes sure registers and + // smem are fully utilized, offset is evenly divisible by 32 + int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31; + cta_nhw_regs -= offset; + cta_nhw_smem -= offset; } - // The base pointers for reading/writing - uint16_t *const gmem_src = ¶ms.gmem_src[thread_c]; - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - const uint16_t *gmem_src1 = nullptr; - if (USE_ADD_RELU) { - gmem_src1 = ¶ms.gmem_src1[thread_c]; +#pragma unroll 1 + for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { + // The nhw position. + int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration; + // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! + cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) - max(nhw_regs, 0), 0); + + // Load the data and compute the local mean/sum and the variance. + if (USE_ONLINE_APPROACH) { + // Read the elements from memory. + float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; + zero_array(x_storage[i]); + is_valid[i] = 0.f; + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + if (loop_i == OUTER_LOOPS - 1) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); + } else { + ldg(x_storage[i], &gmem_src[idx * params.c]); + } + is_valid[i] = 1.f; + } + } + +// Do the math. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + // Convert to float. + float x_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); + + // Update the count. + count += is_valid[i]; + // Invert the count. + float inv_count = is_valid[i] ? 1.f / count : 0.f; + +// Update the mean and m2 using deltas. +#pragma unroll + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { + float delta0 = x_math[j] - mean[j]; + mean[j] += delta0 * inv_count; + float delta1 = x_math[j] - mean[j]; + m2[j] += delta0 * delta1 * is_valid[i]; + } + } + } else { +// Read the elements from memory. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; + zero_array(x_storage[i]); + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + if (loop_i == OUTER_LOOPS - 1) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); + } else { + ldg(x_storage[i], &gmem_src[idx * params.c]); + } + count += 1.f; + } + } + +// Sum the elements in registers. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + // Convert to float. + float x_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); + +// Update the mean and m2 using deltas. +#pragma unroll + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { + mean[j] += x_math[j]; + } + } + + // Compute the mean. 
+ float inv_count = 1.f / count; +#pragma unroll + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { + mean[j] *= inv_count; + } + +// Compute the variance. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + // Convert to float. + float x_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); + + // Is it a valid pixel? + float is_valid = i < static_cast(count) ? 1.f : 0.f; +// Update the mean and m2 using deltas. +#pragma unroll + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { + m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid; + } + } + } } - // apply BN - for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) { - float x_math[ELEMENTS_PER_LDG]; - zero_array(x_math); - if (is_valid_c) { - ldg(x_math, &gmem_src[nhw*params.c]); - } + // The elements to load and store in SMEM. + int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem; + // Load elements from SMEM, update the CTA count. + int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0); + if (pixels_in_smem > 0) { + cta_count += pixels_in_smem; + for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { + const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + float is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f; + + PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG]; + ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0) * params.c]); + + // The offset to store in SMEM. + const int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + // Store in SMEM. + write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); + // Update the count. + count += is_pixel_valid; + // Invert the count. + float inv_count = is_pixel_valid ? 1.f / count : 0.f; - // Normalize and apply activation function - normalize(x_math, bias, scale, mean); - if (USE_ADD_RELU) { - float x1_math[ELEMENTS_PER_LDG]; - ldg(x1_math, &gmem_src1[nhw*params.c]); - add(x_math, x1_math); - relu_activation(x_math); - } else if (USE_RELU) { - relu_activation(x_math); - } + float x_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage_local); +// Update the mean and m2 using deltas. +#pragma unroll + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { + float delta0 = x_math[j] - mean[j]; + mean[j] += delta0 * inv_count; + float delta1 = x_math[j] - mean[j]; + m2[j] += delta0 * delta1 * is_pixel_valid; + } + } + } - if (is_valid_c) { - stg(&gmem_dst[nhw*params.c], x_math); - } + // We scale the mean by the number of elements. It brings more stability. + float m1[ELEMENTS_PER_LDG]; +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + m1[i] = mean[i] * count; } -} -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Run the parallel sum accross the CTA to get the local sum. + ParallelSums::dispatch(smem, m1, thread_in_cta_nhw); + __syncthreads(); -struct NhwcBatchNormFwdParams { - // The input/output tensors. - uint16_t *gmem_src, *gmem_dst, *gmem_src1; - // The bias/scale. - float *gmem_bias, *gmem_scale; - // running mean/var (refer BN API from cudnn doc) - float *gmem_running_mean, *gmem_running_var; - // saved mean/var (refer BN API from cudnn doc) - float *gmem_saved_mean, *gmem_saved_var; - // ReLU bitmask - unsigned int *gmem_relu_bitmask; - // The dimensions. - int nhw, c; - // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. 
- float svar_inv_count; - // factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1). - float rvar_inv_count; - // The buffer to do the reduction for mean, stddev and count. - float *gmem_sums; - // The buffer to count items in the different CTAs. - int *gmem_counts; - // The counters of retired CTAs. - int *gmem_retired_ctas; - // The epsilon to apply to the computation of the variance. - float var_eps; - // outer loop count - int outer_loops; - // exponential average factor - float exp_avg_factor; - // number of CTAs along .x dimension - int c_blks; - - void* my_data; - void* pair_datas[4]; - int magic; - int sync_iters; -}; + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(m1, smem, thread_in_cta_c); + __syncthreads(); -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Adjust the variance. + float inv_cta_count = 1.f / static_cast(cta_count); +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + float mean_diff = m1[i] * inv_cta_count - mean[i]; + m2[i] = m2[i] + mean_diff * mean_diff * count; + } -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int PIXELS_PER_THREAD_IN_REGISTERS, - int PIXELS_PER_THREAD_IN_SMEM, - int ELEMENTS_PER_LDG, - int USE_ONLINE_APPROACH, - int OUTER_LOOPS_, - bool USE_RELU, - bool USE_ADD_RELU, - int DESIRED_OCCUPANCY -> -__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) - void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of pixels computed per CTA stored in registers. - const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; - // The number of pixels computed per CTA stored in SMEM. - const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG]; + // Run the parallel sum accross the CTA to get the local adjusted variance. + ParallelSums::dispatch(smem, m2, thread_in_cta_nhw); - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; + // The workspace in global memory is distributed across the different CTA. + int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2; - // The adapter for the storage. - typedef PackedStorage PackedStorage_; - // The data type for packed storage in SMEM. - typedef typename PackedStorage_::Type PackedStorageType; - // The number of elements in the packed storage. - const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; - // Registers to keep the data live for the persistent approach. - PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - - // Shared memory buffer to store the extra pixels. - extern __shared__ PackedStorageType smem_storage_packed[]; - - for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { - // The position in the NHW dimension where the CTA starts. - int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; - // The position in the NHW dimension where the CTA starts for the portion in SMEM. 
- int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; - - // The position in the C dimension where the CTA starts. - const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. - int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? - const int is_valid_c = thread_c < params.c; - - // Clamp thread_c so that we load from valid locations even if we don't use the value - if (!is_valid_c) - thread_c = params.c - 4; - - // Single pass numerically stable algorithm, see: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm - // - // n = 0, mean = 0.0, M2 = 0.0 - // - // for x in data: - // n += 1 - // delta = x - mean - // mean += delta/n - // delta2 = x - mean - // M2 += delta*delta2 - // - // if n < 2: - // return float('nan') - // else: - // return M2 / (n - 1) - - // Register to store the number of elements read so far. - float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG]; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - mean[i] = 0.f; - m2[i] = 0.f; - } + // Write the data for the CTA to global memory. + float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; + if (threadIdx.x < THREADS_PER_PIXEL) { + const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x; + write_to_gmem(&gmem_sums[0], idx, m1); + write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, m2); + } - // The number of elements loaded by this CTA. - int cta_count = 0; - // The base pointer to load from. - const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; - - // outer loops - int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops; - // Load the batch of elements. Compute the mean/var across those elements. - const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x; - - if (OUTER_LOOPS_ != 1) { - // We cannot load everything to store persistently, so let's makes sure registers and - // smem are fully utilized, offset is evenly divisible by 32 - int offset = (pixels_per_iteration * OUTER_LOOPS + - PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31; - cta_nhw_regs -= offset; - cta_nhw_smem -= offset; - } + // The memory location to store the number of pixels per CTA. + int *gmem_counts = ¶ms.gmem_counts[c_blk_index * gridDim.x]; + if (threadIdx.x == 0) { + gmem_counts[blockIdx.x] = cta_count; + } - #pragma unroll 1 - for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { - // The nhw position. - int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration; - // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! - cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) - - max(nhw_regs, 0), 0); - - // Load the data and compute the local mean/sum and the variance. - if (USE_ONLINE_APPROACH) { - // Read the elements from memory. - float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - is_valid[i] = 0.f; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - } - is_valid[i] = 1.f; - } - } - - // Do the math. 
- #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - - // Update the count. - count += is_valid[i]; - // Invert the count. - float inv_count = is_valid[i] ? 1.f / count : 0.f; - - // Update the mean and m2 using deltas. - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - float delta0 = x_math[j] - mean[j]; - mean[j] += delta0 * inv_count; - float delta1 = x_math[j] - mean[j]; - m2[j] += delta0 * delta1 * is_valid[i]; - } - } - } else { - // Read the elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - } - count += 1.f; - } - } - - // Sum the elements in registers. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - - // Update the mean and m2 using deltas. - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - mean[j] += x_math[j]; - } - } - - // Compute the mean. - float inv_count = 1.f / count; - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - mean[j] *= inv_count; - } - - // Compute the variance. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - - // Is it a valid pixel? - float is_valid = i < static_cast(count) ? 1.f : 0.f; - // Update the mean and m2 using deltas. - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid; - } - } - } - } + // Read the bias and scale. + float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG]; + if (is_valid_c) { + read_from_gmem(bias, ¶ms.gmem_bias[cta_c], thread_in_cta_c); + read_from_gmem(scale, ¶ms.gmem_scale[cta_c], thread_in_cta_c); + } - // The elements to load and store in SMEM. - int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem; - // Load elements from SMEM, update the CTA count. - int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0); - if (pixels_in_smem > 0) { - cta_count += pixels_in_smem; - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - float is_pixel_valid = (((unsigned int)idx < - (unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f; - - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG]; - ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0)*params.c]); - - // The offset to store in SMEM. - const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Store in SMEM. - write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); - // Update the count. - count += is_pixel_valid; - // Invert the count. - float inv_count = is_pixel_valid ? 1.f / count : 0.f; - - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - // Update the mean and m2 using deltas. 
- #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - float delta0 = x_math[j] - mean[j]; - mean[j] += delta0 * inv_count; - float delta1 = x_math[j] - mean[j]; - m2[j] += delta0 * delta1 * is_pixel_valid; - } - } - } + // The counters to count how many CTAs have retired at this point. + // A given cta uses the same counter every other time through the outer loop. + int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; + inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); - // We scale the mean by the number of elements. It brings more stability. - float m1[ELEMENTS_PER_LDG]; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m1[i] = mean[i] * count; - } +// Reset the mean to compute the global mean. +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + m1[i] = 0.f; + } - // Run the parallel sum accross the CTA to get the local sum. - ParallelSums::dispatch( - smem, m1, thread_in_cta_nhw); - __syncthreads(); - - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(m1, smem, thread_in_cta_c); - __syncthreads(); - - // Adjust the variance. - float inv_cta_count = 1.f / static_cast(cta_count); - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - float mean_diff = m1[i]*inv_cta_count - mean[i]; - m2[i] = m2[i] + mean_diff * mean_diff * count; - } +// Build the global mean. +#pragma unroll 1 + for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) { + float tmp[ELEMENTS_PER_LDG]; + read_from_gmem(tmp, gmem_sums, idx); + add(m1, tmp); + } - // Run the parallel sum accross the CTA to get the local adjusted variance. - ParallelSums::dispatch( - smem, m2, thread_in_cta_nhw); + if (params.sync_iters > 0) { + ParallelSums::dispatchX( + smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 3, params.magic, + params.sync_iters); + } else { + ParallelSums::dispatch(smem, m1, thread_in_cta_nhw); + } + __syncthreads(); - // The workspace in global memory is distributed across the different CTA. - int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2; + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(m1, smem, thread_in_cta_c); + __syncthreads(); - // Write the data for the CTA to global memory. - float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; - if (threadIdx.x < THREADS_PER_PIXEL) { - const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x; - write_to_gmem(&gmem_sums[ 0], idx, m1); - write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, m2); - } +// Normalize the mean. +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + m1[i] = m1[i] * params.svar_inv_count; + } - // The memory location to store the number of pixels per CTA. - int *gmem_counts = ¶ms.gmem_counts[c_blk_index*gridDim.x]; - if (threadIdx.x == 0) { - gmem_counts[blockIdx.x] = cta_count; - } +// Reset the variance. +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + m2[i] = 0.f; + } - // Read the bias and scale. - float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG]; - if (is_valid_c) { - read_from_gmem(bias, ¶ms.gmem_bias[cta_c], thread_in_cta_c); - read_from_gmem(scale, ¶ms.gmem_scale[cta_c], thread_in_cta_c); - } + // for add+relu fusion + const uint16_t *gmem_src1 = nullptr; + if (USE_ADD_RELU) { + gmem_src1 = ¶ms.gmem_src1[thread_c]; + } - // The counters to count how many CTAs have retired at this point. 
- // A given cta uses the same counter every other time through the outer loop. - int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; - inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); +// Build the global variance. +#pragma unroll 1 + for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) { + // Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration. + float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG]; + read_from_gmem(tmp_mean, &gmem_sums[0], idx); + read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx); + + // Read the number of pixels visited by a given CTA. + cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]); + + // Compute the diff to update the variance. + float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast(cta_count); +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + mean_diff[i] = m1[i] - tmp_mean[i] * inv_cta_count; + } + +// Update the variance. +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + m2[i] += tmp_var[i] + mean_diff[i] * mean_diff[i] * static_cast(cta_count); + } + } - // Reset the mean to compute the global mean. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m1[i] = 0.f; - } + if (params.sync_iters > 0) { + ParallelSums::dispatchX( + smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 2, params.magic, + params.sync_iters); + } else { + ParallelSums::dispatch(smem, m2, thread_in_cta_nhw); + } + __syncthreads(); - // Build the global mean. - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - float tmp[ELEMENTS_PER_LDG]; - read_from_gmem(tmp, gmem_sums, idx); - add(m1, tmp); - } + read_from_smem(m2, smem, thread_in_cta_c); - if (params.sync_iters>0) - { - ParallelSums::dispatchX( - smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters); - } else { - ParallelSums::dispatch( - smem, m1, thread_in_cta_nhw); - } - __syncthreads(); + // Finalize the stddev. + // becasue saved var and running var may have different denominator, we don't do it here + // scale_(m2, inv_count); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(m1, smem, thread_in_cta_c); - __syncthreads(); + // store the saved mean/var + float svarinv[ELEMENTS_PER_LDG]; + bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps); + } + if (is_valid_for_saving) { + write_to_gmem(params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG, m1); + write_to_gmem(params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG, svarinv); + } - // Normalize the mean. 
- #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m1[i] = m1[i] * params.svar_inv_count; - } + // store the running mean/var + float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG]; + zero_array(rmean); + zero_array(rvar); + if (params.exp_avg_factor != 1.f && is_valid_for_saving) { + read_from_gmem(rmean, params.gmem_running_mean, thread_c / ELEMENTS_PER_LDG); + read_from_gmem(rvar, params.gmem_running_var, thread_c / ELEMENTS_PER_LDG); + } +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] + params.exp_avg_factor * m1[i]; + rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] + params.exp_avg_factor * (m2[i] * params.rvar_inv_count); + } + if (is_valid_for_saving) { + write_to_gmem(params.gmem_running_mean, thread_c / ELEMENTS_PER_LDG, rmean); + write_to_gmem(params.gmem_running_var, thread_c / ELEMENTS_PER_LDG, rvar); + } - // Reset the variance. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m2[i] = 0.f; - } + // Update the scale with the stddev and eps. + multiply(scale, svarinv); - // for add+relu fusion - const uint16_t *gmem_src1 = nullptr; - if (USE_ADD_RELU) { - gmem_src1 = ¶ms.gmem_src1[thread_c]; - } + // The base pointer to write to. + uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - // Build the global variance. - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - // Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration. - float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG]; - read_from_gmem(tmp_mean, &gmem_sums[ 0], idx); - read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx); - - // Read the number of pixels visited by a given CTA. - cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]); - - // Compute the diff to update the variance. - float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast(cta_count); - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - mean_diff[i] = m1[i] - tmp_mean[i]*inv_cta_count; - } + unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask + ((params.nhw + 31) & ~31) * 2 * c_blk_index; + +// Store the elements in registers. +#pragma unroll 1 + for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) { + // The value for nhw. + int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration; + +// Normalize the elements and write to memory. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + const bool is_valid_nhw = static_cast(idx) < static_cast(params.nhw); + const bool is_valid = is_valid_nhw && is_valid_c; + // Convert to float. + float x_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); - // Update the variance. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m2[i] += tmp_var[i] + mean_diff[i]*mean_diff[i]*static_cast(cta_count); + // Normalize and apply activation function + normalize(x_math, bias, scale, m1); + if (USE_ADD_RELU) { + float x1_math[ELEMENTS_PER_LDG]; + ldg_stream(x1_math, &gmem_src1[(is_valid ? 
idx : 0) * params.c]); + add(x_math, x1_math); + unsigned int relu_mask; + int lane_id = threadIdx.x & 31; +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + bool rectified = x_math[i] < 0.0F; + unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified); + if (lane_id == i) { + // Thread 0 remembers the relu_mask from the first time through this + // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last. + relu_mask = local_relu_mask; } + if (rectified) { + x_math[i] = 0.0F; + } + } + if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { + gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask; + } + } else if (USE_RELU) { + relu_activation(x_math); } - if (params.sync_iters>0) - { - ParallelSums::dispatchX( - smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters); - } else { - ParallelSums::dispatch( - smem, m2, thread_in_cta_nhw); + // Write back. + if (is_valid) { + stg_stream(&gmem_dst[idx * params.c], x_math); } - __syncthreads(); + } - read_from_smem(m2, smem, thread_in_cta_c); + // The next value of nhw. + out_nhw -= pixels_per_iteration; - // Finalize the stddev. - // becasue saved var and running var may have different denominator, we don't do it here - // scale_(m2, inv_count); - - // store the saved mean/var - float svarinv[ELEMENTS_PER_LDG]; - bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps); - } - if (is_valid_for_saving) { - write_to_gmem(params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG, m1); - write_to_gmem(params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG, svarinv); +// Read the next elements from memory. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); } + } + } - // store the running mean/var - float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG]; - zero_array(rmean); - zero_array(rvar); - if (params.exp_avg_factor != 1.f && is_valid_for_saving) { - read_from_gmem(rmean, params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG); - read_from_gmem(rvar, params.gmem_running_var, thread_c/ELEMENTS_PER_LDG); - } - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] + \ - params.exp_avg_factor * m1[i]; - rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] + \ - params.exp_avg_factor * (m2[i] * params.rvar_inv_count); - } - if (is_valid_for_saving) { - write_to_gmem(params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG, rmean); - write_to_gmem(params.gmem_running_var, thread_c/ELEMENTS_PER_LDG, rvar); - } + // Normalize the elements from SMEM and write them out. + if (pixels_in_smem > 0) { +#pragma unroll 2 + for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { + const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + const bool is_valid_nhw = static_cast(idx) < static_cast(params.nhw); + const bool is_valid = is_valid_nhw && is_valid_c; + + // Read from SMEM. 
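      // Each thread reads back the SMEM slot indexed by its own threadIdx.x, at a stride of
      // THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG per pixel chunk -- the same slot it populated
      // during the accumulation pass earlier in this c-blk iteration -- so this store/load pair
      // needs no additional synchronization of its own.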
+ const int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG]; + read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); + float x_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage_local); - // Update the scale with the stddev and eps. - multiply(scale, svarinv); - - // The base pointer to write to. - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - - unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask + - ((params.nhw + 31) & ~31) * 2 * c_blk_index; - - // Store the elements in registers. - #pragma unroll 1 - for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) { - // The value for nhw. - int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration; - - // Normalize the elements and write to memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid_nhw = - static_cast(idx) < static_cast(params.nhw); - const bool is_valid = is_valid_nhw && is_valid_c; - // Convert to float. - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - - // Normalize and apply activation function - normalize(x_math, bias, scale, m1); - if (USE_ADD_RELU) { - float x1_math[ELEMENTS_PER_LDG]; - ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]); - add(x_math, x1_math); - unsigned int relu_mask; - int lane_id = threadIdx.x & 31; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - bool rectified = x_math[i] < 0.0F; - unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified); - if (lane_id == i) { - // Thread 0 remembers the relu_mask from the first time through this - // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last. - relu_mask = local_relu_mask; - } - if (rectified) { - x_math[i] = 0.0F; - } - } - if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { - gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask; - } - } else if (USE_RELU) { - relu_activation(x_math); - } - - // Write back. - if (is_valid) { - stg_stream(&gmem_dst[idx*params.c], x_math); - } + // Normalize and apply activation function + normalize(x_math, bias, scale, m1); + if (USE_ADD_RELU) { + float x1_math[ELEMENTS_PER_LDG]; + ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0) * params.c]); + add(x_math, x1_math); + unsigned int relu_mask; + int lane_id = threadIdx.x & 31; +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + bool rectified = x_math[i] < 0.0F; + unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified); + if (lane_id == i) { + relu_mask = local_relu_mask; } - - // The next value of nhw. - out_nhw -= pixels_per_iteration; - - // Read the next elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - } + if (rectified) { + x_math[i] = 0.0F; } + } + if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { + gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask; + } + } else if (USE_RELU) { + relu_activation(x_math); } - // Normalize the elements from SMEM and write them out. 
- if (pixels_in_smem > 0) { - #pragma unroll 2 - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid_nhw = - static_cast(idx) < static_cast(params.nhw); - const bool is_valid = is_valid_nhw && is_valid_c; - - // Read from SMEM. - const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG]; - read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - - // Normalize and apply activation function - normalize(x_math, bias, scale, m1); - if (USE_ADD_RELU) { - float x1_math[ELEMENTS_PER_LDG]; - ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]); - add(x_math, x1_math); - unsigned int relu_mask; - int lane_id = threadIdx.x & 31; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - bool rectified = x_math[i] < 0.0F; - unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified); - if (lane_id == i) { - relu_mask = local_relu_mask; - } - if (rectified) { - x_math[i] = 0.0F; - } - } - if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { - gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask; - } - } else if (USE_RELU) { - relu_activation(x_math); - } - - // Write back. - if (is_valid) { - stg_stream(&gmem_dst[idx*params.c], x_math); - } - } + // Write back. + if (is_valid) { + stg_stream(&gmem_dst[idx * params.c], x_math); } - // We're about to start on the next c-blk. Needed? - __syncthreads(); + } } + // We're about to start on the next c-blk. Needed? + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// struct NhwcBatchNormBwdParams { - // The input/output tensors. - uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1; - // dscale/dbias - float *gmem_dscale, *gmem_dbias; - // The scale and bias. - float *gmem_scale, *gmem_bias; - // The mean/inv-var saved from fwd pass - float *gmem_saved_mean, *gmem_saved_var; - // ReLU bitmask - unsigned int *gmem_relu_bitmask; - // The dimensions. - int nhw, c; - // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. - float svar_inv_count; - // The buffer to do the reduction for dscale and dbias - float *gmem_sums; - // The counters of retired CTAs. - int *gmem_retired_ctas; - // outer loop count - int outer_loops; - // number of CTAs along .x dimension - int c_blks; - - void* my_data; - void* pair_datas[4]; - int magic; - int sync_iters; - float wgrad_coeff; + // The input/output tensors. + uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1; + // dscale/dbias + float *gmem_dscale, *gmem_dbias; + // The scale and bias. + float *gmem_scale, *gmem_bias; + // The mean/inv-var saved from fwd pass + float *gmem_saved_mean, *gmem_saved_var; + // ReLU bitmask + unsigned int *gmem_relu_bitmask; + // The dimensions. + int nhw, c; + // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. + float svar_inv_count; + // The buffer to do the reduction for dscale and dbias + float *gmem_sums; + // The counters of retired CTAs. 
+ int *gmem_retired_ctas; + // outer loop count + int outer_loops; + // number of CTAs along .x dimension + int c_blks; + + void *my_data; + void *pair_datas[4]; + int magic; + int sync_iters; + float wgrad_coeff; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template -DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N], - const float (&mean_var_scale_bias)[N], +DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N], const float (&mean_var_scale_bias)[N], const float (&var_scale)[N], bool valid_data) { - #pragma unroll - for (int j = 0; j < N; ++j) { - float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j]; - if ((y <= 0.f) && valid_data) { - dy[j] = 0.f; - } +#pragma unroll + for (int j = 0; j < N; ++j) { + float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j]; + if ((y <= 0.f) && valid_data) { + dy[j] = 0.f; } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) { - #pragma unroll - for (int j = 0; j < N; ++j) { - if ((y[j] <= 0.f) && valid_data) { - dy[j] = 0.f; - } +#pragma unroll + for (int j = 0; j < N; ++j) { + if ((y[j] <= 0.f) && valid_data) { + dy[j] = 0.f; } + } } template DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) { - #pragma unroll - for (int j = 0; j < N; ++j) { - if (rectified[j] && valid_data) { - dy[j] = 0.f; - } +#pragma unroll + for (int j = 0; j < N; ++j) { + if (rectified[j] && valid_data) { + dy[j] = 0.f; } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template -DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], - const float (&x)[N], - const float (&mean_var_scale_bias)[N], +DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&x)[N], const float (&mean_var_scale_bias)[N], const float (&var_scale)[N]) { - #pragma unroll - for (int j = 0; j < N; ++j) { - float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j]; - if (y <= 0.f) { - dy[j] = 0.f; - } +#pragma unroll + for (int j = 0; j < N; ++j) { + float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j]; + if (y <= 0.f) { + dy[j] = 0.f; } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) { - #pragma unroll - for (int j = 0; j < N; ++j) { - if (y[j] <= 0.f) { - dy[j] = 0.f; - } +#pragma unroll + for (int j = 0; j < N; ++j) { + if (y[j] <= 0.f) { + dy[j] = 0.f; } + } } //////////////////////////////////////////////////////////////////////////////////////////////////// template -DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N], - const float (&dy)[N], const float (&x)[N], +DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N], const float (&dy)[N], const float (&x)[N], const float (&mean)[N], float inv_count) { - #pragma unroll - for (int j = 0; j < N; ++j) { - float delta0 = dy[j] - dbias[j]; - dbias[j] += delta0 * inv_count; - delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j]; - dscale[j] += delta0 * inv_count; - } +#pragma unroll + for (int j = 0; j < N; ++j) { + float delta0 = dy[j] - dbias[j]; + dbias[j] += delta0 * inv_count; + delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j]; + dscale[j] += delta0 * inv_count; + } } 
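// bwd_update above keeps dbias and dscale as running means of dy and dy * (x - mean), in the
// same streaming style as the forward statistics; multiplying by `count` afterwards (as the
// kernels below do) recovers the plain sums. A minimal host-side sketch of that equivalence for
// a single channel lane, with made-up illustrative values only (not part of this change):

#include <cstdio>
#include <vector>

int main() {
  // Illustrative data; mean is the average of x.
  std::vector<float> x = {1.f, 2.f, 4.f, 8.f};
  std::vector<float> dy = {0.5f, -1.f, 2.f, 0.25f};
  const float mean = 3.75f;

  float dbias = 0.f, dscale = 0.f, count = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    count += 1.f;
    const float inv_count = 1.f / count;
    dbias += (dy[i] - dbias) * inv_count;                    // running mean of dy
    dscale += (dy[i] * (x[i] - mean) - dscale) * inv_count;  // running mean of dy * (x - mean)
  }
  dbias *= count;   // now equals sum(dy)              -> gradient w.r.t. bias
  dscale *= count;  // now equals sum(dy * (x - mean)) -> gradient w.r.t. scale, up to the
                    //    1/stddev factor applied later via multiply(dscale, var)
  std::printf("dbias=%f dscale=%f\n", dbias, dscale);
  return 0;
}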
//////////////////////////////////////////////////////////////////////////////////////////////////// template -DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N], - const float (&var)[N], const float (&x)[N], const float (&mean)[N], - const float (&dscale)[N], const float (&dbias)[N], float inv_count) { - #pragma unroll - for (int j = 0; j < N; ++j) { - float tmp1 = dy[j] - (dbias[j]* inv_count); - float tmp2 = dscale[j] * inv_count; - float tmp3 = x[j] - mean[j]; - dx[j] = var[j] * (tmp1 - (tmp2 * tmp3)); - } +DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N], const float (&var)[N], const float (&x)[N], + const float (&mean)[N], const float (&dscale)[N], const float (&dbias)[N], + float inv_count) { +#pragma unroll + for (int j = 0; j < N; ++j) { + float tmp1 = dy[j] - (dbias[j] * inv_count); + float tmp2 = dscale[j] * inv_count; + float tmp3 = x[j] - mean[j]; + dx[j] = var[j] * (tmp1 - (tmp2 * tmp3)); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int PIXELS_PER_THREAD_IN_REGISTERS, - int PIXELS_PER_THREAD_IN_SMEM, - int ELEMENTS_PER_LDG, - int USE_ONLINE_APPROACH, - int OUTER_LOOPS_, - int DESIRED_OCCUPANCY -> -__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) - void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of pixels computed per CTA stored in registers. - const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; - // The number of pixels computed per CTA stored in SMEM. - const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG]; - - // The adapter for the storage. - typedef PackedStorage PackedStorage_; - // The data type for packed storage in SMEM. - typedef typename PackedStorage_::Type PackedStorageType; - // The number of elements in the packed storage. - const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; - // Registers to keep the data live for the persistent approach. - PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - - // Shared memory buffer to store the extra pixels. - extern __shared__ PackedStorageType smem_storage_packed[]; - - for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { - // The position in the NHW dimension where the CTA starts. - int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; - // The position in the NHW dimension where the CTA starts for the portion in SMEM. - int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - - // The position in the C dimension where the CTA starts. - const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. 
- const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? - const int is_valid_c = thread_c < params.c; - - // Registers to store the mean used for entire duration - float mean[ELEMENTS_PER_LDG]; - zero_array(mean); - if (is_valid_c) { - read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG); - } +template +__global__ __launch_bounds__(THREADS_PER_CTA, + DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) { + // The number of pixels loaded in a single LDG. + const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; + // The number of pixels computed per CTA stored in registers. + const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; + // The number of pixels computed per CTA stored in SMEM. + const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG; + // The number of C elements per CTA. + const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG; + + // Shared memory to do CTA-wide parallel sums. + __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG]; + + // The adapter for the storage. + typedef PackedStorage PackedStorage_; + // The data type for packed storage in SMEM. + typedef typename PackedStorage_::Type PackedStorageType; + // The number of elements in the packed storage. + const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; + // Registers to keep the data live for the persistent approach. + PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; + PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; + + // Shared memory buffer to store the extra pixels. + extern __shared__ PackedStorageType smem_storage_packed[]; + + for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { + // The position in the NHW dimension where the CTA starts. + int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; + // The position in the NHW dimension where the CTA starts for the portion in SMEM. + int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; + // Compute the NHW coordinate of the thread in the CTA. + const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - // accumulation related registers - float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; - zero_array(dscale); - zero_array(dbias); - - // The number of elements loaded by this CTA. - int cta_count = 0; - // The base pointers to load from. - const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; - const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; - - // outer loops - int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops; - // Load the batch of elements. Compute sum across them - const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x; - - if (OUTER_LOOPS_ != 1) { - // We cannot load everything to store persistently, so let's makes sure registers and - // smem are fully utilized - int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - - PIXELS_PER_CTA_IN_SMEM * gridDim.x; - cta_nhw_regs += offset; - cta_nhw_smem += offset; - } + // The position in the C dimension where the CTA starts. + const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; + // Compute the C coordinate of the thread in the CTA. + const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; + // Compute the C coordinate of the thread. 
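    // Each group of THREADS_PER_PIXEL consecutive threads covers one pixel, and each thread in
    // the group owns ELEMENTS_PER_LDG adjacent channels, so a CTA covers
    // C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG channels per c_blk_index.
    // For illustration only (the real template parameters are fixed at the launch site): with
    // THREADS_PER_PIXEL = 16 and ELEMENTS_PER_LDG = 4, thread_in_cta_c = 5 in c-block 2 works
    // on channels 2*64 + 5*4 .. 2*64 + 5*4 + 3, i.e. 148..151.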
+ const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG; - #pragma unroll 1 - for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { - // The nhw position. - int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration; - // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! - cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs)); - - // Read the elements from memory. - float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - zero_array(dy_storage[i]); - is_valid[i] = 0.f; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - ldg(dy_storage[i], &gmem_dy[idx*params.c]); - } - is_valid[i] = 1.f; - } - } + // Is the thread working on a valid C dimension? + const int is_valid_c = thread_c < params.c; - // Do the math. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float and update - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); + // Registers to store the mean used for entire duration + float mean[ELEMENTS_PER_LDG]; + zero_array(mean); + if (is_valid_c) { + read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG); + } - // Update the count. - count += is_valid[i]; - // Invert the count. - float inv_count = is_valid[i] ? 1.f / count : 0.f; + // accumulation related registers + float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; + zero_array(dscale); + zero_array(dbias); + + // The number of elements loaded by this CTA. + int cta_count = 0; + // The base pointers to load from. + const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; + const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; + + // outer loops + int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops; + // Load the batch of elements. Compute sum across them + const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x; + + if (OUTER_LOOPS_ != 1) { + // We cannot load everything to store persistently, so let's makes sure registers and + // smem are fully utilized + int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - PIXELS_PER_CTA_IN_SMEM * gridDim.x; + cta_nhw_regs += offset; + cta_nhw_smem += offset; + } - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - } - } +#pragma unroll 1 + for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { + // The nhw position. + int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration; + // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! + cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs)); + + // Read the elements from memory. 
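      // Out-of-range pixels (and threads past the last channel) are zero-filled and get
      // is_valid = 0, so in the accumulation below their inv_count is forced to 0 and they
      // leave the running dbias/dscale means untouched.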
+ float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; + zero_array(x_storage[i]); + zero_array(dy_storage[i]); + is_valid[i] = 0.f; + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + if (loop_i == OUTER_LOOPS - 1) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); + ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]); + } else { + ldg(x_storage[i], &gmem_src[idx * params.c]); + ldg(dy_storage[i], &gmem_dy[idx * params.c]); + } + is_valid[i] = 1.f; + } + } + +// Do the math. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + // Convert to float and update + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); + to_float(dy_math, dy_storage[i]); + + // Update the count. + count += is_valid[i]; + // Invert the count. + float inv_count = is_valid[i] ? 1.f / count : 0.f; + + bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); + } + } - // The elements to load and store in SMEM. - int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem; - // Load elements from SMEM, update the CTA count. - int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw); - if (pixels_in_smem > 0) { - cta_count += pixels_in_smem; - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - bool is_pixel_valid = (((unsigned int)idx < - (unsigned int)params.nhw) && is_valid_c); - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - zero_array(x_storage_local); - zero_array(dy_storage_local); - if (is_pixel_valid) { - ldg_stream(x_storage_local, &gmem_src[idx*params.c]); - ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]); - } - - // The offset to store in SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Store in SMEM. - write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); - // Update the count. - count += is_pixel_valid; - // Invert the count. - float inv_count = is_pixel_valid ? 1.f / count : 0.f; - - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - } - } + // The elements to load and store in SMEM. + int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem; + // Load elements from SMEM, update the CTA count. + int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw); + if (pixels_in_smem > 0) { + cta_count += pixels_in_smem; + for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { + const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + bool is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c); + PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; + zero_array(x_storage_local); + zero_array(dy_storage_local); + if (is_pixel_valid) { + ldg_stream(x_storage_local, &gmem_src[idx * params.c]); + ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]); + } + + // The offset to store in SMEM. + int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + // Store in SMEM. 
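        // The dynamic SMEM buffer is split in two halves: the first
        // PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG entries hold x
        // and the second half holds dy, so the dx writeback pass below can re-read both
        // without another trip to global memory.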
+ write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); + offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); + // Update the count. + count += is_pixel_valid; + // Invert the count. + float inv_count = is_pixel_valid ? 1.f / count : 0.f; + + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage_local); + to_float(dy_math, dy_storage_local); + + bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); + } + } - // We scale the mean by the number of elements. It brings more stability. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dbias[i] *= count; - dscale[i] *= count; - } +// We scale the mean by the number of elements. It brings more stability. +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + dbias[i] *= count; + dscale[i] *= count; + } - // dscale parallel sum - ParallelSums::dispatch( - smem, dscale, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum - ParallelSums::dispatch( - smem, dbias, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - __syncthreads(); - - // The workspace in global memory is distributed across the different CTA. - int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2; - // Write the data for the CTA to global memory. - float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; - if (threadIdx.x < THREADS_PER_PIXEL) { - const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x; - write_to_gmem(&gmem_sums[ 0], idx, dscale); - write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias); - } + // dscale parallel sum + ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw); + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dscale, smem, thread_in_cta_c); + __syncthreads(); - // The counters to count how many CTAs have retired at this point. - // A given cta uses the same counter every other time through the outer loop. - int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; - inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); - - // Reset the accumulators for global summation - zero_array(dscale); - zero_array(dbias); - - // Build the global accumulation - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; - read_from_gmem(tmp1, gmem_sums, idx); - read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx); - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dscale[i] += tmp1[i]; - dbias[i] += tmp2[i]; - } - } + // dbias parallel sum + ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw); + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. 
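    // The same smem staging buffer is reused for the dscale and dbias reductions, which is why
    // each read of the CTA-wide result is bracketed by __syncthreads().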
+ read_from_smem(dbias, smem, thread_in_cta_c); + __syncthreads(); - // dscale parallel sum - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); - } else { - ParallelSums::dispatch( - smem, dscale, thread_in_cta_nhw); - } + // The workspace in global memory is distributed across the different CTA. + int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2; + // Write the data for the CTA to global memory. + float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; + if (threadIdx.x < THREADS_PER_PIXEL) { + const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x; + write_to_gmem(&gmem_sums[0], idx, dscale); + write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias); + } - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); - } else { - ParallelSums::dispatch( - smem, dbias, thread_in_cta_nhw); - } + // The counters to count how many CTAs have retired at this point. + // A given cta uses the same counter every other time through the outer loop. + int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; + inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); + + // Reset the accumulators for global summation + zero_array(dscale); + zero_array(dbias); + +// Build the global accumulation +#pragma unroll 1 + for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) { + float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; + read_from_gmem(tmp1, gmem_sums, idx); + read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx); + +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + dscale[i] += tmp1[i]; + dbias[i] += tmp2[i]; + } + } - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); + // dscale parallel sum + if (params.sync_iters > 0) { + ParallelSums::dispatchX( + smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic, + params.sync_iters); + } else { + ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw); + } - // inv-var - float var[ELEMENTS_PER_LDG]; - zero_array(var); - if (is_valid_c) { - read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG); - } + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dscale, smem, thread_in_cta_c); + __syncthreads(); - // Normalize the dscale. 
- multiply(dscale, var); + // dbias parallel sum + if (params.sync_iters > 0) { + ParallelSums::dispatchX( + smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic, + params.sync_iters); + } else { + ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw); + } - // store dscale/dbias - bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; - if (is_valid_for_saving) { - if (params.sync_iters>0) - { - scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); - scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); - } else { - write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale); - write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias); - } - } + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dbias, smem, thread_in_cta_c); - // scale - float scale[ELEMENTS_PER_LDG]; - zero_array(scale); - if (is_valid_c) { - read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG); - } + // inv-var + float var[ELEMENTS_PER_LDG]; + zero_array(var); + if (is_valid_c) { + read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG); + } - // Further normalize the dscale to be used in dx calculation - multiply(dscale, var); - // scale the inv-var as well, afterwards - multiply(var, scale); - - // inverse count - float inv_count = params.svar_inv_count; - - // The base pointer to write to. - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - - // Store the elements in registers. - #pragma unroll 1 - for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) { - // The value for nhw. - int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration; - - // Normalize the elements and write to memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - stg_stream(&gmem_dst[idx*params.c], dx); - } - } + // Normalize the dscale. + multiply(dscale, var); + + // store dscale/dbias + bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; + if (is_valid_for_saving) { + if (params.sync_iters > 0) { + scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); + scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); + } else { + write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale); + write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias); + } + } - // The next value of nhw. - out_nhw -= pixels_per_iteration; - - // Read the next elements from memory. 
- #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } - } - } + // scale + float scale[ELEMENTS_PER_LDG]; + zero_array(scale); + if (is_valid_c) { + read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG); + } - // Normalize the elements from SMEM and write them out. - if (pixels_in_smem > 0) { - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; - if (is_valid) { - // Read from SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - stg_stream(&gmem_dst[idx*params.c], dx); - } - } - } - // We're about to start on the next c-blk. Needed? - __syncthreads(); + // Further normalize the dscale to be used in dx calculation + multiply(dscale, var); + // scale the inv-var as well, afterwards + multiply(var, scale); + + // inverse count + float inv_count = params.svar_inv_count; + + // The base pointer to write to. + uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; + +// Store the elements in registers. +#pragma unroll 1 + for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) { + // The value for nhw. + int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration; + +// Normalize the elements and write to memory. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + // Convert to float. + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); + to_float(dy_math, dy_storage[i]); + + float dx[ELEMENTS_PER_LDG]; + bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); + + // Write back. + const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + stg_stream(&gmem_dst[idx * params.c], dx); + } + } + + // The next value of nhw. + out_nhw -= pixels_per_iteration; + +// Read the next elements from memory. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); + ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]); + } + } } + + // Normalize the elements from SMEM and write them out. + if (pixels_in_smem > 0) { + for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { + const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; + if (is_valid) { + // Read from SMEM. 
+ int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; + read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); + offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage_local); + to_float(dy_math, dy_storage_local); + + float dx[ELEMENTS_PER_LDG]; + bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); + + // Write back. + stg_stream(&gmem_dst[idx * params.c], dx); + } + } + } + // We're about to start on the next c-blk. Needed? + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int PIXELS_PER_THREAD_IN_REGISTERS, - int PIXELS_PER_THREAD_IN_SMEM, - int ELEMENTS_PER_LDG, - int USE_ONLINE_APPROACH, - int OUTER_LOOPS_, - int DESIRED_OCCUPANCY -> -__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) - void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of pixels computed per CTA stored in registers. - const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; - // The number of pixels computed per CTA stored in SMEM. - const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG]; - - // The adapter for the storage. - typedef PackedStorage PackedStorage_; - // The data type for packed storage in SMEM. - typedef typename PackedStorage_::Type PackedStorageType; - // The number of elements in the packed storage. - const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; - // Registers to keep the data live for the persistent approach. - PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - - // Shared memory buffer to store the extra pixels. - extern __shared__ PackedStorageType smem_storage_packed[]; - - for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { - // The position in the NHW dimension where the CTA starts. - int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; - // The position in the NHW dimension where the CTA starts for the portion in SMEM. - int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - - // The position in the C dimension where the CTA starts. - const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. - const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? 
- const int is_valid_c = thread_c < params.c; - - - // Registers to store the mean/var/scale/bias used for the entire duration - // Register usage optimizations: - // 1. Can combine bias - (mean * var * scale) into a single register - // 2. Can combine var * scale into a single register - float varscale[ELEMENTS_PER_LDG]; - zero_array(varscale); - if (is_valid_c) { - read_from_gmem(varscale, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG); - } - float tmp[ELEMENTS_PER_LDG]; - zero_array(tmp); - if (is_valid_c) { - read_from_gmem(tmp, params.gmem_scale, thread_c/ELEMENTS_PER_LDG); - } - multiply(varscale, tmp); - float mean[ELEMENTS_PER_LDG]; - zero_array(mean); - if (is_valid_c) { - read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG); - } - zero_array(tmp); - if (is_valid_c) { - read_from_gmem(tmp, params.gmem_bias, thread_c/ELEMENTS_PER_LDG); - } - float mean_var_scale_bias[ELEMENTS_PER_LDG]; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]); - } +template +__global__ __launch_bounds__(THREADS_PER_CTA, + DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) { + // The number of pixels loaded in a single LDG. + const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; + // The number of pixels computed per CTA stored in registers. + const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; + // The number of pixels computed per CTA stored in SMEM. + const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG; + // The number of C elements per CTA. + const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG; + + // Shared memory to do CTA-wide parallel sums. + __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG]; + + // The adapter for the storage. + typedef PackedStorage PackedStorage_; + // The data type for packed storage in SMEM. + typedef typename PackedStorage_::Type PackedStorageType; + // The number of elements in the packed storage. + const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; + // Registers to keep the data live for the persistent approach. + PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; + PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; + + // Shared memory buffer to store the extra pixels. + extern __shared__ PackedStorageType smem_storage_packed[]; + + for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { + // The position in the NHW dimension where the CTA starts. + int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; + // The position in the NHW dimension where the CTA starts for the portion in SMEM. + int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; + // Compute the NHW coordinate of the thread in the CTA. + const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - // accumulation related registers - float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; - zero_array(dscale); - zero_array(dbias); - - // The number of elements loaded by this CTA. - int cta_count = 0; - // The base pointers to load from. - const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; - const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; - - // outer loops - int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops; - // Load the batch of elements. 
Compute sum across them - const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x; - - if (OUTER_LOOPS_ != 1) { - // We cannot load everything to store persistently, so let's makes sure registers and - // smem are fully utilized - int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - - PIXELS_PER_CTA_IN_SMEM * gridDim.x; - cta_nhw_regs += offset; - cta_nhw_smem += offset; - } + // The position in the C dimension where the CTA starts. + const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; + // Compute the C coordinate of the thread in the CTA. + const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; + // Compute the C coordinate of the thread. + const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG; - #pragma unroll 1 - for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { - // The nhw position. - int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration; - // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! - cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs)); - - // Read the elements from memory. - float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - zero_array(dy_storage[i]); - is_valid[i] = 0.f; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - ldg(dy_storage[i], &gmem_dy[idx*params.c]); - } - is_valid[i] = 1.f; - } - } + // Is the thread working on a valid C dimension? + const int is_valid_c = thread_c < params.c; - // Do the math. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float and update - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - // Update the count. - count += is_valid[i]; - // Invert the count. - float inv_count = is_valid[i] ? 1.f / count : 0.f; - - relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]); - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - } - } + // Registers to store the mean/var/scale/bias used for the entire duration + // Register usage optimizations: + // 1. Can combine bias - (mean * var * scale) into a single register + // 2. Can combine var * scale into a single register + float varscale[ELEMENTS_PER_LDG]; + zero_array(varscale); + if (is_valid_c) { + read_from_gmem(varscale, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG); + } + float tmp[ELEMENTS_PER_LDG]; + zero_array(tmp); + if (is_valid_c) { + read_from_gmem(tmp, params.gmem_scale, thread_c / ELEMENTS_PER_LDG); + } + multiply(varscale, tmp); + float mean[ELEMENTS_PER_LDG]; + zero_array(mean); + if (is_valid_c) { + read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG); + } + zero_array(tmp); + if (is_valid_c) { + read_from_gmem(tmp, params.gmem_bias, thread_c / ELEMENTS_PER_LDG); + } + float mean_var_scale_bias[ELEMENTS_PER_LDG]; +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]); + } - // The elements to load and store in SMEM. - int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem; - // Load elements from SMEM, update the CTA count. 
- int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw); - if (pixels_in_smem > 0) { - cta_count += pixels_in_smem; - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - bool is_pixel_valid = (((unsigned int)idx < - (unsigned int)params.nhw) && is_valid_c); - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - zero_array(x_storage_local); - zero_array(dy_storage_local); - if (is_pixel_valid) { - ldg_stream(x_storage_local, &gmem_src[idx*params.c]); - ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]); - } - - // The offset to store in SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Store in SMEM. - write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); - // Update the count. - count += is_pixel_valid; - // Invert the count. - float inv_count = is_pixel_valid ? 1.f / count : 0.f; - - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid); - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - } - } + // accumulation related registers + float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; + zero_array(dscale); + zero_array(dbias); + + // The number of elements loaded by this CTA. + int cta_count = 0; + // The base pointers to load from. + const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; + const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; + + // outer loops + int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops; + // Load the batch of elements. Compute sum across them + const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x; + + if (OUTER_LOOPS_ != 1) { + // We cannot load everything to store persistently, so let's makes sure registers and + // smem are fully utilized + int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - PIXELS_PER_CTA_IN_SMEM * gridDim.x; + cta_nhw_regs += offset; + cta_nhw_smem += offset; + } - // We scale the mean by the number of elements. It brings more stability. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dbias[i] *= count; - dscale[i] *= count; - } +#pragma unroll 1 + for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { + // The nhw position. + int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration; + // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! + cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs)); + + // Read the elements from memory. + float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; + zero_array(x_storage[i]); + zero_array(dy_storage[i]); + is_valid[i] = 0.f; + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + if (loop_i == OUTER_LOOPS - 1) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); + ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]); + } else { + ldg(x_storage[i], &gmem_src[idx * params.c]); + ldg(dy_storage[i], &gmem_dy[idx * params.c]); + } + is_valid[i] = 1.f; + } + } + +// Do the math. 
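      // relu_bwd recomputes the forward output on the fly as
      //   y = x * (scale / stddev) + (bias - mean * scale / stddev)
      // from the precomputed varscale and mean_var_scale_bias registers, and zeroes dy wherever
      // y <= 0, so no stored activation mask is needed on this fused path.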
+#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + // Convert to float and update + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); + to_float(dy_math, dy_storage[i]); + + // Update the count. + count += is_valid[i]; + // Invert the count. + float inv_count = is_valid[i] ? 1.f / count : 0.f; + + relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]); + bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); + } + } - // dscale parallel sum - ParallelSums::dispatch( - smem, dscale, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum - ParallelSums::dispatch( - smem, dbias, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - __syncthreads(); - - // The workspace in global memory is distributed across the different CTA. - int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2; - // Write the data for the CTA to global memory. - float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; - if (threadIdx.x < THREADS_PER_PIXEL) { - const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x; - write_to_gmem(&gmem_sums[ 0], idx, dscale); - write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias); - } + // The elements to load and store in SMEM. + int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem; + // Load elements from SMEM, update the CTA count. + int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw); + if (pixels_in_smem > 0) { + cta_count += pixels_in_smem; + for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { + const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + bool is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c); + PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; + zero_array(x_storage_local); + zero_array(dy_storage_local); + if (is_pixel_valid) { + ldg_stream(x_storage_local, &gmem_src[idx * params.c]); + ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]); + } + + // The offset to store in SMEM. + int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + // Store in SMEM. + write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); + offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); + // Update the count. + count += is_pixel_valid; + // Invert the count. + float inv_count = is_pixel_valid ? 1.f / count : 0.f; + + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage_local); + to_float(dy_math, dy_storage_local); + + relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid); + bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); + } + } - // The counters to count how many CTAs have retired at this point. - // A given cta uses the same counter every other time through the outer loop. 
- int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; - inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); - - // Reset the accumulators for global summation - zero_array(dscale); - zero_array(dbias); - - // Build the global accumulation - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; - read_from_gmem(tmp1, gmem_sums, idx); - read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx); - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dscale[i] += tmp1[i]; - dbias[i] += tmp2[i]; - } - } +// We scale the mean by the number of elements. It brings more stability. +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + dbias[i] *= count; + dscale[i] *= count; + } - // dscale parallel sum - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); - } else { - ParallelSums::dispatch( - smem, dscale, thread_in_cta_nhw); - } + // dscale parallel sum + ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw); + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dscale, smem, thread_in_cta_c); + __syncthreads(); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); - } else { - ParallelSums::dispatch( - smem, dbias, thread_in_cta_nhw); - } + // dbias parallel sum + ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw); + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dbias, smem, thread_in_cta_c); + __syncthreads(); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); + // The workspace in global memory is distributed across the different CTA. + int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2; + // Write the data for the CTA to global memory. + float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; + if (threadIdx.x < THREADS_PER_PIXEL) { + const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x; + write_to_gmem(&gmem_sums[0], idx, dscale); + write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias); + } - // Normalize the dscale. - float var[ELEMENTS_PER_LDG]; - zero_array(var); - if (is_valid_c) { - read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG); - } - multiply(dscale, var); - - // store dscale/dbias - bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; - if (is_valid_for_saving) { - if (params.sync_iters>0) - { - scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); - scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); - } else { - write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale); - write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias); - } - } + // The counters to count how many CTAs have retired at this point. + // A given cta uses the same counter every other time through the outer loop. 
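Note: inter_block_sync is also defined outside this hunk. Its call site below, together with the double-buffered counter index c_blk_index % (2 * gridDim.y), suggests the usual retired-CTA barrier: each CTA publishes its partial sums to params.gmem_sums, bumps a per-channel-block counter, and waits until all gridDim.x CTAs have done so before re-reading the workspace; keeping two counters per y-slot lets the next channel block start without racing against this one's reset. A rough sketch of that pattern (the name and the details are assumptions, not the actual helper):

    // Sketch of a retired-CTA barrier; the real inter_block_sync may differ.
    __device__ inline void inter_block_sync_sketch(int *retired_ctas, int expected, bool is_master) {
      __syncthreads();
      if (threadIdx.x == 0) {
        __threadfence();                 // make this CTA's partial sums visible first
        atomicAdd(retired_ctas, 1);      // announce arrival
        if (is_master) {
          while (atomicAdd(retired_ctas, 0) < expected) { }  // wait for every CTA
          atomicExch(retired_ctas, 0);                       // reset for the next use
        } else {
          while (atomicAdd(retired_ctas, 0) != 0) { }        // released by the master's reset
        }
      }
      __syncthreads();
    }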
+ int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; + inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); + + // Reset the accumulators for global summation + zero_array(dscale); + zero_array(dbias); + +// Build the global accumulation +#pragma unroll 1 + for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) { + float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; + read_from_gmem(tmp1, gmem_sums, idx); + read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx); + +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + dscale[i] += tmp1[i]; + dbias[i] += tmp2[i]; + } + } - // Further normalize the dscale to be used in dx calculation - float scale[ELEMENTS_PER_LDG]; - zero_array(scale); - if (is_valid_c) { - read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG); - } - multiply(dscale, var); - // scale the inv-var as well, afterwards - multiply(var, scale); - - // inverse count - float inv_count = params.svar_inv_count; - - // The base pointer to write to. - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - - // Store the elements in registers. - #pragma unroll 1 - for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) { - // The value for nhw. - int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration; - - // Normalize the elements and write to memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - stg_stream(&gmem_dst[idx*params.c], dx); - } - } + // dscale parallel sum + if (params.sync_iters > 0) { + ParallelSums::dispatchX( + smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic, + params.sync_iters); + } else { + ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw); + } - // The next value of nhw. - out_nhw -= pixels_per_iteration; - - // Read the next elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } - } - } + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dscale, smem, thread_in_cta_c); + __syncthreads(); - // Normalize the elements from SMEM and write them out. - if (pixels_in_smem > 0) { - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; - if (is_valid) { - // Read from SMEM. 
- int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - stg_stream(&gmem_dst[idx*params.c], dx); - } - } - } - // We're about to start on the next c-blk. Needed? - __syncthreads(); + // dbias parallel sum + if (params.sync_iters > 0) { + ParallelSums::dispatchX( + smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic, + params.sync_iters); + } else { + ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw); + } + + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dbias, smem, thread_in_cta_c); + + // Normalize the dscale. + float var[ELEMENTS_PER_LDG]; + zero_array(var); + if (is_valid_c) { + read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG); + } + multiply(dscale, var); + + // store dscale/dbias + bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; + if (is_valid_for_saving) { + if (params.sync_iters > 0) { + scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); + scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); + } else { + write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale); + write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias); + } + } + + // Further normalize the dscale to be used in dx calculation + float scale[ELEMENTS_PER_LDG]; + zero_array(scale); + if (is_valid_c) { + read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG); } + multiply(dscale, var); + // scale the inv-var as well, afterwards + multiply(var, scale); + + // inverse count + float inv_count = params.svar_inv_count; + + // The base pointer to write to. + uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; + +// Store the elements in registers. +#pragma unroll 1 + for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) { + // The value for nhw. + int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration; + +// Normalize the elements and write to memory. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + // Convert to float. + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); + to_float(dy_math, dy_storage[i]); + relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var); + + float dx[ELEMENTS_PER_LDG]; + bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); + + // Write back. + const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + stg_stream(&gmem_dst[idx * params.c], dx); + } + } + + // The next value of nhw. + out_nhw -= pixels_per_iteration; + +// Read the next elements from memory. 
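Note on the reversed write-back pass above and the re-read below: only the final outer-loop tile is still live in x_storage/dy_storage, so the dx pass walks the tiles from last to first and, after storing each tile, streams the previous tile's x and dy back in from global memory; every tile except the last is therefore read twice, once for the sums and once for dx. The ldg_stream/stg_stream helpers used for this traffic are presumably the streaming (evict-first) counterparts of the plain ldg/stg helpers, since the re-read data is not touched again.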
+#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); + ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]); + } + } + } + + // Normalize the elements from SMEM and write them out. + if (pixels_in_smem > 0) { + for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { + const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; + if (is_valid) { + // Read from SMEM. + int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; + read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); + offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage_local); + to_float(dy_math, dy_storage_local); + relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var); + + float dx[ELEMENTS_PER_LDG]; + bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); + + // Write back. + stg_stream(&gmem_dst[idx * params.c], dx); + } + } + } + // We're about to start on the next c-blk. Needed? + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int PIXELS_PER_THREAD_IN_REGISTERS, - int PIXELS_PER_THREAD_IN_SMEM, - int ELEMENTS_PER_LDG, - int USE_ONLINE_APPROACH, - int OUTER_LOOPS_, - int DESIRED_OCCUPANCY -> -__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) - void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of pixels computed per CTA stored in registers. - const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; - // The number of pixels computed per CTA stored in SMEM. - const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG]; - - // The adapter for the storage. - typedef PackedStorage PackedStorage_; - // The data type for packed storage in SMEM. - typedef typename PackedStorage_::Type PackedStorageType; - // The number of elements in the packed storage. - const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; - // Registers to keep the data live for the persistent approach. - PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - - // Shared memory buffer to store the extra pixels. - extern __shared__ PackedStorageType smem_storage_packed[]; - - for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { - // The position in the NHW dimension where the CTA starts. 
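Note: the index arithmetic in this prologue (identical in the removed and reformatted versions) maps threads onto an NHWC tile: each group of THREADS_PER_PIXEL consecutive threads shares one pixel, each thread within the group owns ELEMENTS_PER_LDG consecutive channels, and c_blk_index selects which C_ELEMENTS_PER_CTA-wide channel slab the pass covers. A small host-side sketch of the same mapping; the constants are illustrative stand-ins for the kernel's template parameters.

    #include <cstdio>

    // Prints which pixel and channel range a few sample threads would own.
    int main() {
      const int THREADS_PER_CTA = 512, THREADS_PER_PIXEL = 16, ELEMENTS_PER_LDG = 4;
      const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;  // channels per c-blk
      const int c_blk_index = 0;
      for (int tid = 0; tid < THREADS_PER_CTA; tid += 137) {
        const int thread_in_cta_nhw = tid / THREADS_PER_PIXEL;  // pixel within the CTA tile
        const int thread_in_cta_c = tid % THREADS_PER_PIXEL;    // channel group within the pixel
        const int thread_c = c_blk_index * C_ELEMENTS_PER_CTA + thread_in_cta_c * ELEMENTS_PER_LDG;
        printf("tid %3d -> pixel %2d, channels [%2d, %2d)\n", tid, thread_in_cta_nhw, thread_c,
               thread_c + ELEMENTS_PER_LDG);
      }
      return 0;
    }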
- int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; - // The position in the NHW dimension where the CTA starts for the portion in SMEM. - int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - - // The position in the C dimension where the CTA starts. - const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. - const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? - const int is_valid_c = thread_c < params.c; - - float mean[ELEMENTS_PER_LDG]; - zero_array(mean); - if (is_valid_c) { - read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG); - } +template +__global__ __launch_bounds__(THREADS_PER_CTA, + DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) { + // The number of pixels loaded in a single LDG. + const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; + // The number of pixels computed per CTA stored in registers. + const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; + // The number of pixels computed per CTA stored in SMEM. + const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG; + // The number of C elements per CTA. + const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG; + + // Shared memory to do CTA-wide parallel sums. + __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG]; + + // The adapter for the storage. + typedef PackedStorage PackedStorage_; + // The data type for packed storage in SMEM. + typedef typename PackedStorage_::Type PackedStorageType; + // The number of elements in the packed storage. + const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; + // Registers to keep the data live for the persistent approach. + PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; + PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; + + // Shared memory buffer to store the extra pixels. + extern __shared__ PackedStorageType smem_storage_packed[]; + + for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { + // The position in the NHW dimension where the CTA starts. + int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; + // The position in the NHW dimension where the CTA starts for the portion in SMEM. + int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; + // Compute the NHW coordinate of the thread in the CTA. + const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - // accumulation related registers - float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; - zero_array(dscale); - zero_array(dbias); - - // The number of elements loaded by this CTA. - int cta_count = 0; - // The base pointers to load from. - const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; - const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; - uint16_t *gmem_dst1 = ¶ms.gmem_dst1[thread_c]; - - // outer loops - int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops; - // Load the batch of elements. 
Compute sum across them - const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x; - - if (OUTER_LOOPS_ != 1) { - // We cannot load everything to store persistently, so let's makes sure registers and - // smem are fully utilized, offset is evenly divisible by 32 - int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x - - params.nhw) & ~31; - cta_nhw_regs -= offset; - cta_nhw_smem -= offset; - } + // The position in the C dimension where the CTA starts. + const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; + // Compute the C coordinate of the thread in the CTA. + const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; + // Compute the C coordinate of the thread. + const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG; - const unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask + - ((params.nhw + 31) & ~31) * 2 * c_blk_index; - - #pragma unroll 1 - for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { - // The nhw position. - int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration; - // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! - cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs)); - - int lane_id = threadIdx.x & 31; - - // Read the elements from memory. - float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS]; - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - zero_array(dy_storage[i]); - is_valid[i] = 0.f; - const bool is_valid_nhw = - static_cast(idx) < static_cast(params.nhw); - if (is_valid_nhw) { - if (is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - ldg(dy_storage[i], &gmem_dy[idx*params.c]); - } - is_valid[i] = 1.f; - } - - if (lane_id < ELEMENTS_PER_LDG) { - relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id]; - } - } - } + // Is the thread working on a valid C dimension? + const int is_valid_c = thread_c < params.c; - // Do the math. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - // Convert to float and update - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - bool rectified[ELEMENTS_PER_LDG]; - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) & - (1U << lane_id)) != 0); - } - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - // Update the count. - count += is_valid[i]; - // Invert the count. - float inv_count = is_valid[i] ? 
1.f / count : 0.f; - - relu_bwd(dy_math, rectified, is_valid[i]); - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - - // Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version - from_float(dy_storage[i], dy_math); - - // dZ for elementwise add - if (is_valid[i]) { - if (loop_i == OUTER_LOOPS - 1) { - stg_stream(&gmem_dst1[idx*params.c], dy_storage[i]); - } else { - stg(&gmem_dst1[idx*params.c], dy_storage[i]); - } - } - } - } + float mean[ELEMENTS_PER_LDG]; + zero_array(mean); + if (is_valid_c) { + read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG); + } + + // accumulation related registers + float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; + zero_array(dscale); + zero_array(dbias); + + // The number of elements loaded by this CTA. + int cta_count = 0; + // The base pointers to load from. + const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; + const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; + uint16_t *gmem_dst1 = ¶ms.gmem_dst1[thread_c]; + + // outer loops + int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops; + // Load the batch of elements. Compute sum across them + const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x; + + if (OUTER_LOOPS_ != 1) { + // We cannot load everything to store persistently, so let's makes sure registers and + // smem are fully utilized, offset is evenly divisible by 32 + int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31; + cta_nhw_regs -= offset; + cta_nhw_smem -= offset; + } - // The elements to load and store in SMEM. - int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem; - // Load elements from SMEM, update the CTA count. - int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw); - if (pixels_in_smem > 0) { - cta_count += pixels_in_smem; - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_pixel_valid_nhw = - static_cast(idx) < static_cast(params.nhw); - const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - unsigned int relu_mask; - int lane_id = threadIdx.x & 31; - zero_array(x_storage_local); - zero_array(dy_storage_local); - if (is_pixel_valid_nhw) { - if (is_valid_c) { - ldg_stream(x_storage_local, &gmem_src[idx*params.c]); - ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]); - } - if (lane_id < ELEMENTS_PER_LDG) { - relu_mask = gmem_relu_bitmask[idx * 2 + lane_id]; - } - } - bool rectified[ELEMENTS_PER_LDG]; - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) & - (1U << lane_id)) != 0); - } - - // The offset to store in SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Store in SMEM. - write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Update the count. - count += is_pixel_valid; - // Invert the count. - float inv_count = is_pixel_valid ? 
1.f / count : 0.f; - - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - relu_bwd(dy_math, rectified, is_pixel_valid); - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - - from_float(dy_storage_local, dy_math); - // dZ for elementwise add - if (is_pixel_valid) { - stg_stream(&gmem_dst1[idx*params.c], dy_storage_local); - } - // only store the 'relu-dgrad'ed version! - write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); + const unsigned int *const gmem_relu_bitmask = + params.gmem_relu_bitmask + ((params.nhw + 31) & ~31) * 2 * c_blk_index; + +#pragma unroll 1 + for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { + // The nhw position. + int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration; + // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! + cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs)); + + int lane_id = threadIdx.x & 31; + + // Read the elements from memory. + float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; + unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS]; +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; + zero_array(x_storage[i]); + zero_array(dy_storage[i]); + is_valid[i] = 0.f; + const bool is_valid_nhw = static_cast(idx) < static_cast(params.nhw); + if (is_valid_nhw) { + if (is_valid_c) { + if (loop_i == OUTER_LOOPS - 1) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); + ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]); + } else { + ldg(x_storage[i], &gmem_src[idx * params.c]); + ldg(dy_storage[i], &gmem_dy[idx * params.c]); } - } + is_valid[i] = 1.f; + } + + if (lane_id < ELEMENTS_PER_LDG) { + relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id]; + } + } + } + +// Do the math. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; + // Convert to float and update + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + bool rectified[ELEMENTS_PER_LDG]; +#pragma unroll + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { + rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) & (1U << lane_id)) != 0); + } + to_float(x_math, x_storage[i]); + to_float(dy_math, dy_storage[i]); + + // Update the count. + count += is_valid[i]; + // Invert the count. + float inv_count = is_valid[i] ? 1.f / count : 0.f; + + relu_bwd(dy_math, rectified, is_valid[i]); + bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); + + // Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version + from_float(dy_storage[i], dy_math); + + // dZ for elementwise add + if (is_valid[i]) { + if (loop_i == OUTER_LOOPS - 1) { + stg_stream(&gmem_dst1[idx * params.c], dy_storage[i]); + } else { + stg(&gmem_dst1[idx * params.c], dy_storage[i]); + } + } + } + } - // We scale the mean by the number of elements. It brings more stability. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dbias[i] *= count; - dscale[i] *= count; - } + // The elements to load and store in SMEM. + int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem; + // Load elements from SMEM, update the CTA count. 
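Note on the SMEM block below (and the register loop above): the forward pass of this fused add+ReLU variant recorded which activations were positive as single bits in params.gmem_relu_bitmask. Lanes 0..ELEMENTS_PER_LDG-1 of each warp fetch one 32-bit mask word apiece, and __shfl_sync then broadcasts word j to the whole warp so that every thread can test the bit at its own lane position for channel j. A distilled sketch of just that decode step, with the mask word passed in directly and ELEMENTS_PER_LDG assumed to be 4:

    // Sketch of the bitmask decode feeding relu_bwd in this kernel.
    // `relu_mask` holds a meaningful word only on lanes 0..3; the shuffle spreads it.
    __device__ inline void decode_relu_bitmask_sketch(bool (&rectified)[4], unsigned int relu_mask) {
      const int lane_id = threadIdx.x & 31;
    #pragma unroll
      for (int j = 0; j < 4; ++j) {
        const unsigned int word_j = __shfl_sync(0xFFFFFFFFU, relu_mask, j);  // channel j's word
        rectified[j] = (word_j & (1U << lane_id)) != 0;  // this lane's activation was positive
      }
    }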
+ int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw); + if (pixels_in_smem > 0) { + cta_count += pixels_in_smem; + for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { + const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + const bool is_pixel_valid_nhw = static_cast(idx) < static_cast(params.nhw); + const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c; + PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; + unsigned int relu_mask; + int lane_id = threadIdx.x & 31; + zero_array(x_storage_local); + zero_array(dy_storage_local); + if (is_pixel_valid_nhw) { + if (is_valid_c) { + ldg_stream(x_storage_local, &gmem_src[idx * params.c]); + ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]); + } + if (lane_id < ELEMENTS_PER_LDG) { + relu_mask = gmem_relu_bitmask[idx * 2 + lane_id]; + } + } + bool rectified[ELEMENTS_PER_LDG]; +#pragma unroll + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { + rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) & (1U << lane_id)) != 0); + } + + // The offset to store in SMEM. + int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + // Store in SMEM. + write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); + offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + // Update the count. + count += is_pixel_valid; + // Invert the count. + float inv_count = is_pixel_valid ? 1.f / count : 0.f; + + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage_local); + to_float(dy_math, dy_storage_local); + + relu_bwd(dy_math, rectified, is_pixel_valid); + bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); + + from_float(dy_storage_local, dy_math); + // dZ for elementwise add + if (is_pixel_valid) { + stg_stream(&gmem_dst1[idx * params.c], dy_storage_local); + } + // only store the 'relu-dgrad'ed version! + write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); + } + } - // dscale parallel sum - ParallelSums::dispatch( - smem, dscale, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum - ParallelSums::dispatch( - smem, dbias, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - __syncthreads(); - - // The workspace in global memory is distributed across the different CTA. - int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2; - // Write the data for the CTA to global memory. - float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; - if (threadIdx.x < THREADS_PER_PIXEL) { - const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x; - write_to_gmem(&gmem_sums[ 0], idx, dscale); - write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias); - } +// We scale the mean by the number of elements. It brings more stability. +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + dbias[i] *= count; + dscale[i] *= count; + } - // The counters to count how many CTAs have retired at this point. - // A given cta uses the same counter every other time through the outer loop. 
- int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; - inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); - - // Reset the accumulators for global summation - zero_array(dscale); - zero_array(dbias); - - // Build the global accumulation - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; - read_from_gmem(tmp1, gmem_sums, idx); - read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx); - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dscale[i] += tmp1[i]; - dbias[i] += tmp2[i]; - } - } + // dscale parallel sum + ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw); + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dscale, smem, thread_in_cta_c); + __syncthreads(); - // dscale parallel sum - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); - } else { - ParallelSums::dispatch( - smem, dscale, thread_in_cta_nhw); - } + // dbias parallel sum + ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw); + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dbias, smem, thread_in_cta_c); + __syncthreads(); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); - } else { - ParallelSums::dispatch( - smem, dbias, thread_in_cta_nhw); - } + // The workspace in global memory is distributed across the different CTA. + int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2; + // Write the data for the CTA to global memory. + float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; + if (threadIdx.x < THREADS_PER_PIXEL) { + const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x; + write_to_gmem(&gmem_sums[0], idx, dscale); + write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias); + } - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); + // The counters to count how many CTAs have retired at this point. + // A given cta uses the same counter every other time through the outer loop. + int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; + inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); + + // Reset the accumulators for global summation + zero_array(dscale); + zero_array(dbias); + +// Build the global accumulation +#pragma unroll 1 + for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) { + float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; + read_from_gmem(tmp1, gmem_sums, idx); + read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx); + +#pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + dscale[i] += tmp1[i]; + dbias[i] += tmp2[i]; + } + } - // Normalize the dscale. 
- float var[ELEMENTS_PER_LDG]; - zero_array(var); - if (is_valid_c) { - read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG); - } - multiply(dscale, var); - - // store dscale/dbias - bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; - if (is_valid_for_saving) { - if (params.sync_iters>0) - { - scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); - scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); - } else { - write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale); - write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias); - } - } + // dscale parallel sum + if (params.sync_iters > 0) { + ParallelSums::dispatchX( + smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic, + params.sync_iters); + } else { + ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw); + } - // Further normalize the dscale to be used in dx calculation - float scale[ELEMENTS_PER_LDG]; - zero_array(scale); - if (is_valid_c) { - read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG); - } - multiply(dscale, var); - // scale the inv-var as well, afterwards - multiply(var, scale); - - // inverse count - float inv_count = params.svar_inv_count; - - // The base pointer to write to. - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - - // Store the elements in registers. - #pragma unroll 1 - for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) { - // The value for nhw. - int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration; - - // Normalize the elements and write to memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; - // Convert to float. - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - if (is_valid) { - stg_stream(&gmem_dst[idx*params.c], dx); - } - } + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dscale, smem, thread_in_cta_c); + __syncthreads(); - // The next value of nhw. - out_nhw -= pixels_per_iteration; - - // Read the next elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - float y[ELEMENTS_PER_LDG]; - zero_array(y); - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dst1[idx*params.c]); - } - } - } + // dbias parallel sum + if (params.sync_iters > 0) { + ParallelSums::dispatchX( + smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic, + params.sync_iters); + } else { + ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw); + } - // Normalize the elements from SMEM and write them out. - if (pixels_in_smem > 0) { - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; - if (is_valid) { - // Read from SMEM. 
- int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - stg_stream(&gmem_dst[idx*params.c], dx); - } - } + __syncthreads(); + // The values in shared memory correspond to the CTA-wide sums. + read_from_smem(dbias, smem, thread_in_cta_c); + + // Normalize the dscale. + float var[ELEMENTS_PER_LDG]; + zero_array(var); + if (is_valid_c) { + read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG); + } + multiply(dscale, var); + + // store dscale/dbias + bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; + if (is_valid_for_saving) { + if (params.sync_iters > 0) { + scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); + scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); + } else { + write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale); + write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias); + } + } + + // Further normalize the dscale to be used in dx calculation + float scale[ELEMENTS_PER_LDG]; + zero_array(scale); + if (is_valid_c) { + read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG); + } + multiply(dscale, var); + // scale the inv-var as well, afterwards + multiply(var, scale); + + // inverse count + float inv_count = params.svar_inv_count; + + // The base pointer to write to. + uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; + +// Store the elements in registers. +#pragma unroll 1 + for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) { + // The value for nhw. + int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration; + +// Normalize the elements and write to memory. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; + // Convert to float. + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage[i]); + to_float(dy_math, dy_storage[i]); + + float dx[ELEMENTS_PER_LDG]; + bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); + + // Write back. + if (is_valid) { + stg_stream(&gmem_dst[idx * params.c], dx); + } + } + + // The next value of nhw. + out_nhw -= pixels_per_iteration; + +// Read the next elements from memory. +#pragma unroll + for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { + const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + float y[ELEMENTS_PER_LDG]; + zero_array(y); + if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { + ldg_stream(x_storage[i], &gmem_src[idx * params.c]); + ldg_stream(dy_storage[i], &gmem_dst1[idx * params.c]); } - // We're about to start on the next c-blk. Needed? - __syncthreads(); + } } + + // Normalize the elements from SMEM and write them out. 
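Note: bwd_dx is declared outside this hunk. At its call sites above and below, var has already been folded to scale * inv_std (assuming, as the helper names suggest, that gmem_saved_var stores the saved inverse standard deviation), dscale has been folded to sum(dy * (x - mean)) times the squared inverse std, dbias is the plain sum of dy, and inv_count is 1/NHW, so bwd_dx presumably evaluates the standard batch-norm input gradient. A sketch of that formula under those assumptions (the name bwd_dx_sketch and the array size 4, i.e. ELEMENTS_PER_LDG, are placeholders):

    // dx = gamma * inv_std * (dy - mean(dy) - (x - mean) * inv_var * mean(dy * (x - mean)))
    __device__ inline void bwd_dx_sketch(float (&dx)[4], const float (&dy)[4], const float (&var)[4],
                                         const float (&x)[4], const float (&mean)[4],
                                         const float (&dscale)[4], const float (&dbias)[4],
                                         float inv_count) {
    #pragma unroll
      for (int i = 0; i < 4; ++i) {
        dx[i] = var[i] * (dy[i] - inv_count * (dbias[i] + dscale[i] * (x[i] - mean[i])));
      }
    }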
+ if (pixels_in_smem > 0) { + for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { + const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; + const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; + if (is_valid) { + // Read from SMEM. + int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; + read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); + offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; + read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); + float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; + to_float(x_math, x_storage_local); + to_float(dy_math, dy_storage_local); + + float dx[ELEMENTS_PER_LDG]; + bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); + + // Write back. + stg_stream(&gmem_dst[idx * params.c], dx); + } + } + } + // We're about to start on the next c-blk. Needed? + __syncthreads(); + } } #endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_ diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp index c8b045b1a..4f2ae68df 100644 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp @@ -1,145 +1,81 @@ #include -#include #include +#include -void index_mul_2d_float_foward_cuda(at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); +void index_mul_2d_float_foward_cuda(at::Tensor &out, const at::Tensor &in1, const at::Tensor &in2, + const at::Tensor &idx1); -void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); +void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out, + const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1); -void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); +void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2, + const at::Tensor &grad_out, const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, const at::Tensor &in1, + const at::Tensor &in2, const at::Tensor &idx1); -void index_mul_2d_half_foward_cuda(at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); +void index_mul_2d_half_foward_cuda(at::Tensor &out, const at::Tensor &in1, const at::Tensor &in2, + const at::Tensor &idx1); -void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); +void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out, + const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1); -void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const 
at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); +void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2, + const at::Tensor &grad_out, const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, const at::Tensor &in1, + const at::Tensor &in2, const at::Tensor &idx1); #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) -void index_mul_2d_float_forward( - at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ +void index_mul_2d_float_forward(at::Tensor &out, const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) { return index_mul_2d_float_foward_cuda(out, in1, in2, idx1); } -void index_mul_2d_float_backward( - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ +void index_mul_2d_float_backward(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out, + const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) { return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1); } -void index_mul_2d_float_backwrad_backward( - at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ - return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1); +void index_mul_2d_float_backwrad_backward(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2, + const at::Tensor &grad_out, const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, const at::Tensor &in1, const at::Tensor &in2, + const at::Tensor &idx1) { + return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, + grad_grad_in2, in1, in2, idx1); } -void index_mul_2d_half_forward( - at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ +void index_mul_2d_half_forward(at::Tensor &out, const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) { return index_mul_2d_half_foward_cuda(out, in1, in2, idx1); } -void index_mul_2d_half_backward( - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ +void index_mul_2d_half_backward(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out, + const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) { return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1); } -void index_mul_2d_half_backwrad_backward( - at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ - return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, 
grad_grad_in2, in1, in2, idx1); +void index_mul_2d_half_backwrad_backward(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2, + const at::Tensor &grad_out, const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, const at::Tensor &in1, const at::Tensor &in2, + const at::Tensor &idx1) { + return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, + grad_grad_in2, in1, in2, idx1); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("float_forward", &index_mul_2d_float_forward, - "index mul float calculation forward (CUDA)", + m.def("float_forward", &index_mul_2d_float_forward, "index mul float calculation forward (CUDA)", py::call_guard()); - m.def("float_backward", &index_mul_2d_float_backward, - "index mul float calculation backward (CUDA)", + m.def("float_backward", &index_mul_2d_float_backward, "index mul float calculation backward (CUDA)", py::call_guard()); m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward, - "index mul float calculation backward backward (CUDA)", + "index mul float calculation backward backward (CUDA)", py::call_guard()); + m.def("half_forward", &index_mul_2d_half_forward, "index mul half calculation forward (CUDA)", py::call_guard()); - m.def("half_forward", &index_mul_2d_half_forward, - "index mul half calculation forward (CUDA)", - py::call_guard()); - m.def("half_backward", &index_mul_2d_half_backward, - "index mul half calculation backward (CUDA)", + m.def("half_backward", &index_mul_2d_half_backward, "index mul half calculation backward (CUDA)", py::call_guard()); m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward, - "index mul half calculation backward backward (CUDA)", - py::call_guard()); + "index mul half calculation backward backward (CUDA)", py::call_guard()); } - diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu index 072f124dc..5c0a27849 100644 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu @@ -1,479 +1,391 @@ #include #include #include -#include +#include -__global__ void index_mul_2d_float_dim64( - float *out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - constexpr int fea_dim = 64; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; - int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; - - float4 res, src1, src2; - src1 = reinterpret_cast(in1)[vec_idx1]; - src2 = reinterpret_cast(in2)[vec_idx2]; - res.x = src1.x * src2.x; - res.y = src1.y * src2.y; - res.z = src1.z * src2.z; - res.w = src1.w * src2.w; - reinterpret_cast(out)[vec_idx2] = res; - } +__global__ void index_mul_2d_float_dim64(float *out, const float *in1, const float *in2, const int64_t *idx1, + const int64_t size) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + constexpr int fea_dim = 64; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; + int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; + + float4 res, src1, src2; + src1 = reinterpret_cast(in1)[vec_idx1]; + src2 = reinterpret_cast(in2)[vec_idx2]; + res.x = src1.x * src2.x; + res.y = src1.y * 
src2.y; + res.z = src1.z * src2.z; + res.w = src1.w * src2.w; + reinterpret_cast(out)[vec_idx2] = res; + } } -__global__ void index_mul_2d_float( - float *out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim); - int64_t vec_idx2 = (start_idx * fea_dim); - - for (int i = tidx; i < fea_dim; i += stride) { - out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i]; - } - } -} +__global__ void index_mul_2d_float(float *out, const float *in1, const float *in2, const int64_t *idx1, + const int64_t size, const int64_t fea_dim) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; -__global__ void index_mul_2d_half( - at::Half *out, - const at::Half *in1, - const at::Half *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim); - int64_t vec_idx2 = (start_idx * fea_dim); - - for (int i = tidx; i < fea_dim; i += stride) { - out[vec_idx2 + i] = at::Half(static_cast(in1[vec_idx1 + i]) * static_cast(in2[vec_idx2 + i])); - } - } -} + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim); + int64_t vec_idx2 = (start_idx * fea_dim); -__global__ void index_mul_2d_grad_float_dim64( - float *grad_in1, - float *grad_in2, - const float *grad_out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - constexpr int fea_dim = 64; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; - int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; - - float4 src_in1, src_in2, src_grad_out, dst_grad_in2; - src_grad_out = reinterpret_cast(grad_out)[vec_idx2]; - src_in1 = reinterpret_cast(in1)[vec_idx1]; - src_in2 = reinterpret_cast(in2)[vec_idx2]; - int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4; - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w); - dst_grad_in2.x = src_grad_out.x * src_in1.x; - dst_grad_in2.y = src_grad_out.y * src_in1.y; - dst_grad_in2.z = src_grad_out.z * src_in1.z; - dst_grad_in2.w = src_grad_out.w * src_in1.w; - reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; + for (int i = tidx; i < fea_dim; i += stride) { + out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i]; } + } } -__global__ void index_mul_2d_grad_float( - float *grad_in1, - float *grad_in2, - const float *grad_out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - 
const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = idx1[start_idx] * fea_dim; - int64_t vec_idx2 = start_idx * fea_dim; - - for (int i = tidx; i < fea_dim; i += stride) { - float src_in1 = in1[vec_idx1 + i]; - float src_in2 = in2[vec_idx2 + i]; - float src_grad_out = grad_out[vec_idx2 + i]; - grad_in2[vec_idx2 + i] = src_grad_out * src_in1; - gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2); - } +__global__ void index_mul_2d_half(at::Half *out, const at::Half *in1, const at::Half *in2, const int64_t *idx1, + const int64_t size, const int64_t fea_dim) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim); + int64_t vec_idx2 = (start_idx * fea_dim); + + for (int i = tidx; i < fea_dim; i += stride) { + out[vec_idx2 + i] = at::Half(static_cast(in1[vec_idx1 + i]) * static_cast(in2[vec_idx2 + i])); } + } } -__global__ void index_mul_2d_grad_half( - at::Half *grad_in1, - at::Half *grad_in2, - const at::Half *grad_out, - const at::Half *in1, - const at::Half *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = idx1[start_idx] * fea_dim; - int64_t vec_idx2 = start_idx * fea_dim; - - for (int i = tidx; i < fea_dim; i += stride) { - float src_in1 = static_cast(in1[vec_idx1 + i]); - float src_in2 = static_cast(in2[vec_idx2 + i]); - float src_grad_out = static_cast(grad_out[vec_idx2 + i]); - grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1); - gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2)); - } - } +__global__ void index_mul_2d_grad_float_dim64(float *grad_in1, float *grad_in2, const float *grad_out, const float *in1, + const float *in2, const int64_t *idx1, const int64_t size) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + constexpr int fea_dim = 64; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; + int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; + + float4 src_in1, src_in2, src_grad_out, dst_grad_in2; + src_grad_out = reinterpret_cast(grad_out)[vec_idx2]; + src_in1 = reinterpret_cast(in1)[vec_idx1]; + src_in2 = reinterpret_cast(in2)[vec_idx2]; + int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4; + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w); + dst_grad_in2.x = src_grad_out.x * src_in1.x; + dst_grad_in2.y = src_grad_out.y * src_in1.y; + dst_grad_in2.z = src_grad_out.z * src_in1.z; + dst_grad_in2.w = src_grad_out.w * src_in1.w; + reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; + } } -__global__ void index_mul_2d_grad_grad_float_dim64( - float *grad_grad_out, - float *grad_in1, - float *grad_in2, - const float *grad_out, - const float *grad_grad_in1, - const float *grad_grad_in2, - const float *in1, - const 
float *in2, - const int64_t *idx1, - const int64_t size) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - constexpr int fea_dim = 64; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; - int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; - - float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out; - float4 dst_grad_grad_out, dst_grad_in2; - src_grad_grad_in1 = reinterpret_cast(grad_grad_in1)[vec_idx1]; - src_in1 = reinterpret_cast(in1)[vec_idx1]; - src_grad_grad_in2 = reinterpret_cast(grad_grad_in2)[vec_idx2]; - src_in2 = reinterpret_cast(in2)[vec_idx2]; - dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x; - dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y; - dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z; - dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w; - reinterpret_cast(grad_grad_out)[vec_idx2] = dst_grad_grad_out; - src_grad_out = reinterpret_cast(grad_out)[vec_idx2]; - int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4; - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w); - dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x; - dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y; - dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z; - dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w; - reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; +__global__ void index_mul_2d_grad_float(float *grad_in1, float *grad_in2, const float *grad_out, const float *in1, + const float *in2, const int64_t *idx1, const int64_t size, + const int64_t fea_dim) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = idx1[start_idx] * fea_dim; + int64_t vec_idx2 = start_idx * fea_dim; + + for (int i = tidx; i < fea_dim; i += stride) { + float src_in1 = in1[vec_idx1 + i]; + float src_in2 = in2[vec_idx2 + i]; + float src_grad_out = grad_out[vec_idx2 + i]; + grad_in2[vec_idx2 + i] = src_grad_out * src_in1; + gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2); } + } } -__global__ void index_mul_2d_grad_grad_float( - float *grad_grad_out, - float *grad_in1, - float *grad_in2, - const float *grad_out, - const float *grad_grad_in1, - const float *grad_grad_in2, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = idx1[start_idx] * fea_dim; - int64_t vec_idx2 = start_idx * fea_dim; - - for (int i = tidx; i < fea_dim; i += stride) { - float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i]; - float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i]; - float src_in1 = in1[vec_idx1 + i]; - float src_in2 = in2[vec_idx2 + i]; - float 
src_grad_out = grad_out[vec_idx2 + i]; - grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1; - grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out; - gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out); - } +__global__ void index_mul_2d_grad_half(at::Half *grad_in1, at::Half *grad_in2, const at::Half *grad_out, + const at::Half *in1, const at::Half *in2, const int64_t *idx1, + const int64_t size, const int64_t fea_dim) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = idx1[start_idx] * fea_dim; + int64_t vec_idx2 = start_idx * fea_dim; + + for (int i = tidx; i < fea_dim; i += stride) { + float src_in1 = static_cast(in1[vec_idx1 + i]); + float src_in2 = static_cast(in2[vec_idx2 + i]); + float src_grad_out = static_cast(grad_out[vec_idx2 + i]); + grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1); + gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2)); } + } } -__global__ void index_mul_2d_grad_grad_half( - at::Half *grad_grad_out, - at::Half *grad_in1, - at::Half *grad_in2, - const at::Half *grad_out, - const at::Half *grad_grad_in1, - const at::Half *grad_grad_in2, - const at::Half *in1, - const at::Half *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = idx1[start_idx] * fea_dim; - int64_t vec_idx2 = start_idx * fea_dim; - - for (int i = tidx; i < fea_dim; i += stride) { - float src_grad_grad_in1 = static_cast(grad_grad_in1[vec_idx1 + i]); - float src_grad_grad_in2 = static_cast(grad_grad_in2[vec_idx2 + i]); - float src_in1 = static_cast(in1[vec_idx1 + i]); - float src_in2 = static_cast(in2[vec_idx2 + i]); - float src_grad_out = static_cast(grad_out[vec_idx2 + i]); - grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1); - grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out); - gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out)); - } - } +__global__ void index_mul_2d_grad_grad_float_dim64(float *grad_grad_out, float *grad_in1, float *grad_in2, + const float *grad_out, const float *grad_grad_in1, + const float *grad_grad_in2, const float *in1, const float *in2, + const int64_t *idx1, const int64_t size) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + constexpr int fea_dim = 64; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; + int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; + + float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out; + float4 dst_grad_grad_out, dst_grad_in2; + src_grad_grad_in1 = reinterpret_cast(grad_grad_in1)[vec_idx1]; + src_in1 = reinterpret_cast(in1)[vec_idx1]; + src_grad_grad_in2 = reinterpret_cast(grad_grad_in2)[vec_idx2]; + src_in2 = reinterpret_cast(in2)[vec_idx2]; + dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x; + dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y; + dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + 
src_grad_grad_in2.z * src_in1.z; + dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w; + reinterpret_cast(grad_grad_out)[vec_idx2] = dst_grad_grad_out; + src_grad_out = reinterpret_cast(grad_out)[vec_idx2]; + int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4; + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w); + dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x; + dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y; + dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z; + dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w; + reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; + } } -void index_mul_2d_float_foward_cuda(at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; +__global__ void index_mul_2d_grad_grad_float(float *grad_grad_out, float *grad_in1, float *grad_in2, + const float *grad_out, const float *grad_grad_in1, + const float *grad_grad_in2, const float *in1, const float *in2, + const int64_t *idx1, const int64_t size, const int64_t fea_dim) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = idx1[start_idx] * fea_dim; + int64_t vec_idx2 = start_idx * fea_dim; + + for (int i = tidx; i < fea_dim; i += stride) { + float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i]; + float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i]; + float src_in1 = in1[vec_idx1 + i]; + float src_in2 = in2[vec_idx2 + i]; + float src_grad_out = grad_out[vec_idx2 + i]; + grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1; + grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out; + gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out); } + } +} - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (fea_dim == 64) { - const int BLOCK_THREADS_DIMX = 16; - const int BLOCK_THREADS_DIMY = 16; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - - index_mul_2d_float_dim64<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), - idx1.data_ptr(), size); - } else { - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - - index_mul_2d_float<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), - idx1.data_ptr(), size, fea_dim); +__global__ void index_mul_2d_grad_grad_half(at::Half *grad_grad_out, at::Half *grad_in1, at::Half *grad_in2, + const at::Half *grad_out, const at::Half *grad_grad_in1, + const at::Half *grad_grad_in2, const at::Half *in1, const at::Half *in2, + const int64_t *idx1, const int64_t size, const int64_t fea_dim) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = idx1[start_idx] * fea_dim; + int64_t vec_idx2 = start_idx * fea_dim; + + for (int i = tidx; i < 
fea_dim; i += stride) { + float src_grad_grad_in1 = static_cast(grad_grad_in1[vec_idx1 + i]); + float src_grad_grad_in2 = static_cast(grad_grad_in2[vec_idx2 + i]); + float src_in1 = static_cast(in1[vec_idx1 + i]); + float src_in2 = static_cast(in2[vec_idx2 + i]); + float src_grad_out = static_cast(grad_out[vec_idx2 + i]); + grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1); + grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out); + gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out)); } - - AT_CUDA_CHECK(cudaGetLastError()); + } } -void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } +void index_mul_2d_float_foward_cuda(at::Tensor &out, const at::Tensor &in1, const at::Tensor &in2, + const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0) { + return; + } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (fea_dim == 64) { - const int BLOCK_THREADS_DIMX = 16; - const int BLOCK_THREADS_DIMY = 16; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + if (fea_dim == 64) { + const int BLOCK_THREADS_DIMX = 16; + const int BLOCK_THREADS_DIMY = 16; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - index_mul_2d_grad_float_dim64<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); + index_mul_2d_float_dim64<<>>( + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); + } else { + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - AT_CUDA_CHECK(cudaGetLastError()); - } else { - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + index_mul_2d_float<<>>( + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + } - index_mul_2d_grad_float<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); - } + AT_CUDA_CHECK(cudaGetLastError()); } -void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } +void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out, + const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0) { + return; + } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (fea_dim == 64) { - const int BLOCK_THREADS_DIMX = 16; - const int BLOCK_THREADS_DIMY = 16; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - - 
index_mul_2d_grad_grad_float_dim64<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); - } else { - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - - index_mul_2d_grad_grad_float<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); - } + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (fea_dim == 64) { + const int BLOCK_THREADS_DIMX = 16; + const int BLOCK_THREADS_DIMY = 16; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + + index_mul_2d_grad_float_dim64<<>>( + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), in1.data_ptr(), + in2.data_ptr(), idx1.data_ptr(), size); AT_CUDA_CHECK(cudaGetLastError()); + } else { + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + + index_mul_2d_grad_float<<>>( + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), in1.data_ptr(), + in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + } } -void index_mul_2d_half_foward_cuda(at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } +void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2, + const at::Tensor &grad_out, const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, const at::Tensor &in1, + const at::Tensor &in2, const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0) { + return; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (fea_dim == 64) { + const int BLOCK_THREADS_DIMX = 16; + const int BLOCK_THREADS_DIMY = 16; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - + index_mul_2d_grad_grad_float_dim64<<>>( + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); + } else { const int BLOCK_THREADS_DIMX = 32; const int BLOCK_THREADS_DIMY = 8; const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - index_mul_2d_half<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), - idx1.data_ptr(), size, fea_dim); + index_mul_2d_grad_grad_float<<>>( + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + } - AT_CUDA_CHECK(cudaGetLastError()); + AT_CUDA_CHECK(cudaGetLastError()); } -void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } +void index_mul_2d_half_foward_cuda(at::Tensor &out, 
const at::Tensor &in1, const at::Tensor &in2, + const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0) { + return; + } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + + index_mul_2d_half<<>>( + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, + fea_dim); - index_mul_2d_grad_half<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + AT_CUDA_CHECK(cudaGetLastError()); } -void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } +void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out, + const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0) { + return; + } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - index_mul_2d_grad_grad_half<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + index_mul_2d_grad_half<<>>( + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); +} - AT_CUDA_CHECK(cudaGetLastError()); +void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2, + const at::Tensor &grad_out, const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, const at::Tensor &in1, + const at::Tensor &in2, const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0) { + return; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + + index_mul_2d_grad_grad_half<<>>( + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + + AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/apex/contrib/csrc/layer_norm/ln.h b/apex/contrib/csrc/layer_norm/ln.h index cf0355c07..adf24969f 100644 --- a/apex/contrib/csrc/layer_norm/ln.h +++ 
b/apex/contrib/csrc/layer_norm/ln.h @@ -1,105 +1,99 @@ #pragma once +#include +#include #include #include + #include -#include -#include namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct LaunchParams{ +template +struct LaunchParams { + size_t workspace_bytes; + size_t barrier_size; - size_t workspace_bytes; - size_t barrier_size; + cudaDeviceProp *props; - cudaDeviceProp * props; - - cudaStream_t stream; - - Params params; + cudaStream_t stream; + Params params; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct FwdParams{ - FwdParams() - : ctas_per_col(0) - , rows(0) - , cols(0) - , x(nullptr) - , z(nullptr) - , mu(nullptr) - , rs(nullptr) - , gamma(nullptr) - , beta(nullptr) - , workspace(nullptr) - , barrier(nullptr) - , epsilon(0.f) - { - } - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. - int ctas_per_col; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x; - void *z; - void *mu; - void *rs; - void *gamma; - void *beta; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. - int *barrier; - - // Output of LN FWD. - float epsilon; +struct FwdParams { + FwdParams() + : ctas_per_col(0), + rows(0), + cols(0), + x(nullptr), + z(nullptr), + mu(nullptr), + rs(nullptr), + gamma(nullptr), + beta(nullptr), + workspace(nullptr), + barrier(nullptr), + epsilon(0.f) {} + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. + void *x; + void *z; + void *mu; + void *rs; + void *gamma; + void *beta; + + // Multi-CTA workspace in gmem. + void *workspace; + + // Multi-CTA sync barriers in gmem. + int *barrier; + + // Output of LN FWD. + float epsilon; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct BwdParams : public FwdParams{ - BwdParams() - : FwdParams() - , dz(nullptr) - , dbeta_part(nullptr) - , dgamma_part(nullptr) - , dx(nullptr) - , dbeta(nullptr) - , dgamma(nullptr) - { - } - // Input: gradient wrt. LN FWD output. - void *dz; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - - // Output: Dgrad. - void *dx; - // Output: Wgrad. - void *dbeta; - void *dgamma; - +struct BwdParams : public FwdParams { + BwdParams() + : FwdParams(), + dz(nullptr), + dbeta_part(nullptr), + dgamma_part(nullptr), + dx(nullptr), + dbeta(nullptr), + dgamma(nullptr) {} + // Input: gradient wrt. LN FWD output. + void *dz; + + // Workspace for Wgrad pre-reduction. + void *dbeta_part; + void *dgamma_part; + + // Output: Dgrad. + void *dx; + // Output: Wgrad. 
+ void *dbeta; + void *dgamma; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -using FwdFunction = std::function&, const bool)>; -using BwdFunction = std::function&, const bool)>; +using FwdFunction = std::function &, const bool)>; +using BwdFunction = std::function &, const bool)>; using FunctionKey = uint64_t; using FwdRegistry = std::unordered_map; using BwdRegistry = std::unordered_map; @@ -115,77 +109,77 @@ using bf16 = nv_bfloat16; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct TypeId{}; +template +struct TypeId {}; -template<> -struct TypeId{ - constexpr static uint32_t Value = 0; +template <> +struct TypeId { + constexpr static uint32_t Value = 0; }; -template<> -struct TypeId{ - constexpr static uint32_t Value = 1; +template <> +struct TypeId { + constexpr static uint32_t Value = 1; }; -template<> -struct TypeId{ - constexpr static uint32_t Value = 2; +template <> +struct TypeId { + constexpr static uint32_t Value = 2; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Type2Key{ - constexpr static uint32_t Value = TypeId::Value << S; +template +struct Type2Key { + constexpr static uint32_t Value = TypeId::Value << S; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct WeightType2Key : public Type2Key{}; +template +struct WeightType2Key : public Type2Key {}; -template -struct InputType2Key : public Type2Key{}; +template +struct InputType2Key : public Type2Key {}; -template -struct OutputType2Key : public Type2Key{}; +template +struct OutputType2Key : public Type2Key {}; -template -struct ComputeType2Key : public Type2Key{}; +template +struct ComputeType2Key : public Type2Key {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Types2Key{ - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size){ - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } +template +struct Types2Key { + constexpr static uint32_t Value = + WeightType2Key::Value | InputType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; + constexpr static inline uint64_t get(const uint64_t hidden_size) { + constexpr uint64_t type_key = Value; + return (type_key << 32) | hidden_size; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct FwdRegistrar{ - FwdRegistrar(FwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_FUNCS.insert({ key, f }); - } +template +struct FwdRegistrar { + FwdRegistrar(FwdFunction f) { + uint64_t key = Types2Key::get(HIDDEN_SIZE); + FWD_FUNCS.insert({key, f}); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct BwdRegistrar{ - BwdRegistrar(BwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_FUNCS.insert({ key, f }); - } +template +struct BwdRegistrar { + BwdRegistrar(BwdFunction f) { + uint64_t key = Types2Key::get(HIDDEN_SIZE); + BWD_FUNCS.insert({key, f}); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layer_norm - diff --git 
a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 4ca810976..6e68b59c0 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -1,19 +1,19 @@ #include -#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/CUDAContext.h" #include "ln.h" /* Supported Type combinations: -input compute weights output +input compute weights output ======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp16 fp16 -fp32 fp32 bf16 bf16 +fp32 fp32 fp32 fp32 +fp16 fp32 fp16 fp16 +bf16 fp32 bf16 bf16 +fp32 fp32 fp16 fp16 +fp32 fp32 bf16 bf16 Remarks: Output type = Weight type @@ -30,126 +30,128 @@ BwdRegistry BWD_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// -uint32_t get_type_id(torch::Dtype dtype){ - if( dtype == torch::kFloat16 ) { - return TypeId::Value; - } else if( dtype == torch::kBFloat16 ) { - return TypeId::Value; - } else if( dtype == torch::kFloat32 ) { - return TypeId::Value; - } else { - TORCH_CHECK(false, "Type not supported: ", dtype); - } +uint32_t get_type_id(torch::Dtype dtype) { + if (dtype == torch::kFloat16) { + return TypeId::Value; + } else if (dtype == torch::kBFloat16) { + return TypeId::Value; + } else if (dtype == torch::kFloat32) { + return TypeId::Value; + } else { + TORCH_CHECK(false, "Type not supported: ", dtype); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { - using namespace layer_norm; - uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; + using namespace layer_norm; + uint64_t type_key = + get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6); + uint64_t launcher_key = (type_key << 32) | hidden_size; + return launcher_key; } } // namespace layer_norm //////////////////////////////////////////////////////////////////////////////////////////////////// -layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { - auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); - if( iter != layer_norm::FWD_FUNCS.end() ) { - return iter->second; - } else { - TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype); - } +layer_norm::FwdFunction &get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, + torch::Dtype ctype, uint32_t hidden_size) { + auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); + if (iter != layer_norm::FWD_FUNCS.end()) { + return iter->second; + } else { + TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { - auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); - if( iter != 
layer_norm::BWD_FUNCS.end() ) { - return iter->second; - } else { - TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype); - } +layer_norm::BwdFunction &get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, + torch::Dtype ctype, uint32_t hidden_size) { + auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); + if (iter != layer_norm::BWD_FUNCS.end()) { + return iter->second; + } else { + TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// std::vector ln_fwd(const at::Tensor &x, // BxSxhidden_size - const at::Tensor &gamma, // hidden_size + const at::Tensor &gamma, // hidden_size const at::Tensor &beta, // hidden_size - const float epsilon -) { - auto itype = x.scalar_type(); - auto wtype = gamma.scalar_type(); - auto otype = wtype; - auto ctype = torch::kFloat32; + const float epsilon) { + auto itype = x.scalar_type(); + auto wtype = gamma.scalar_type(); + auto otype = wtype; + auto ctype = torch::kFloat32; - TORCH_CHECK(beta.scalar_type() == wtype); + TORCH_CHECK(beta.scalar_type() == wtype); - TORCH_CHECK(x.is_cuda()) - TORCH_CHECK(gamma.is_cuda()) - TORCH_CHECK(beta.is_cuda()) + TORCH_CHECK(x.is_cuda()) + TORCH_CHECK(gamma.is_cuda()) + TORCH_CHECK(beta.is_cuda()) - TORCH_CHECK(x.is_contiguous()); - auto sizes = x.sizes(); - TORCH_CHECK(sizes.size() == 2); + TORCH_CHECK(x.is_contiguous()); + auto sizes = x.sizes(); + TORCH_CHECK(sizes.size() == 2); - const int rows = sizes[0]; - const int cols = sizes[1]; - auto hidden_size = gamma.numel(); + const int rows = sizes[0]; + const int cols = sizes[1]; + auto hidden_size = gamma.numel(); - TORCH_CHECK(gamma.sizes() == beta.sizes()); - TORCH_CHECK(hidden_size == cols); + TORCH_CHECK(gamma.sizes() == beta.sizes()); + TORCH_CHECK(hidden_size == cols); - TORCH_CHECK(epsilon >= 0.f); + TORCH_CHECK(epsilon >= 0.f); - auto opts = x.options(); + auto opts = x.options(); - auto z = torch::empty(sizes, opts.dtype(otype)); + auto z = torch::empty(sizes, opts.dtype(otype)); - auto mu = torch::empty({ rows }, opts.dtype(ctype)); - auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); + auto mu = torch::empty({rows}, opts.dtype(ctype)); + auto rsigma = torch::empty({rows}, opts.dtype(ctype)); - layer_norm::LaunchParams launch_params; + layer_norm::LaunchParams launch_params; - launch_params.props = at::cuda::getCurrentDeviceProperties(); - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + launch_params.props = at::cuda::getCurrentDeviceProperties(); + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - // Request the kernel launcher. - auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size); + // Request the kernel launcher. + auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size); - // Query the kernel-specific launch parameters. - launcher(launch_params, true); + // Query the kernel-specific launch parameters. + launcher(launch_params, true); - at::Tensor workspace, barrier; + at::Tensor workspace, barrier; - // Set the kernel runtime parameters. 
- layer_norm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.z = z.data_ptr(); - params.mu = mu.data_ptr(); - params.rs = rsigma.data_ptr(); - params.gamma = gamma.data_ptr(); - params.beta = beta.data_ptr(); - params.x = x.data_ptr(); - params.epsilon = epsilon; + // Set the kernel runtime parameters. + layer_norm::FwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.z = z.data_ptr(); + params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.beta = beta.data_ptr(); + params.x = x.data_ptr(); + params.epsilon = epsilon; - if( launch_params.barrier_size > 0 ) { - auto options = x.options(); - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } + if (launch_params.barrier_size > 0) { + auto options = x.options(); + barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } - // Launch the kernel. - launcher(launch_params, false); + // Launch the kernel. + launcher(launch_params, false); - return { z, mu, rsigma }; + return {z, mu, rsigma}; } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -158,100 +160,98 @@ std::vector ln_bwd(const at::Tensor &dz, // BxSxh c10::optional &mu_, // BxS, FP32! const at::Tensor &rsigma, // BxS, FP32! const at::Tensor &gamma, // hidden_size - c10::optional&beta_, // hidden_size - bool memory_efficient -) { - - auto itype = x_or_z.scalar_type(); - auto wtype = gamma.scalar_type(); - auto otype = wtype; - auto ctype = torch::kFloat32; - - TORCH_CHECK(dz.dtype() == otype); - TORCH_CHECK(rsigma.dtype() == ctype); - if (mu_.has_value()) { - TORCH_CHECK(mu_.value().dtype() == ctype); - } - - TORCH_CHECK(x_or_z.is_cuda()); - TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(rsigma.is_cuda()); - TORCH_CHECK(gamma.is_cuda()); - if (beta_.has_value()) { - TORCH_CHECK(beta_.value().is_cuda()); - TORCH_CHECK(beta_.value().dtype() == wtype); - } - - TORCH_CHECK(x_or_z.is_contiguous()); - TORCH_CHECK(dz.is_contiguous()); - - auto sizes = x_or_z.sizes(); - TORCH_CHECK(sizes.size() == 2); - TORCH_CHECK(dz.sizes() == sizes); - auto rows = sizes[0]; - auto cols = sizes[1]; - - auto hidden_size = gamma.numel(); - - TORCH_CHECK(gamma.numel() == cols); - if (beta_.has_value()) { - TORCH_CHECK(beta_.value().numel() == cols); - } - - auto options = x_or_z.options(); - - auto dx = torch::empty_like(x_or_z); - auto dgamma = torch::empty_like(gamma); - auto dbeta = torch::empty_like(gamma); - - layer_norm::LaunchParams launch_params; - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - launch_params.props = at::cuda::getCurrentDeviceProperties(); - - auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size); - - launcher(launch_params, true); - - auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype)); - auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype)); - at::Tensor workspace, barrier; - - layer_norm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - 
if (memory_efficient) { - params.z = x_or_z.data_ptr(); - params.beta = beta_.value().data_ptr(); - } else { - params.x = x_or_z.data_ptr(); - params.mu = mu_.value().data_ptr(); - } - params.rs = rsigma.data_ptr(); - params.gamma = gamma.data_ptr(); - params.dz = dz.data_ptr(); - params.dx = dx.data_ptr(); - params.dbeta = dbeta.data_ptr(); - params.dgamma = dgamma.data_ptr(); - params.dbeta_part = dbeta_part.data_ptr(); - params.dgamma_part = dgamma_part.data_ptr(); - - if( launch_params.barrier_size > 0 ) { - // TODO Any way to avoid this? - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } - - launcher(launch_params, false); - - return { dx, dgamma, dbeta, dgamma_part, dbeta_part }; + c10::optional &beta_, // hidden_size + bool memory_efficient) { + auto itype = x_or_z.scalar_type(); + auto wtype = gamma.scalar_type(); + auto otype = wtype; + auto ctype = torch::kFloat32; + + TORCH_CHECK(dz.dtype() == otype); + TORCH_CHECK(rsigma.dtype() == ctype); + if (mu_.has_value()) { + TORCH_CHECK(mu_.value().dtype() == ctype); + } + + TORCH_CHECK(x_or_z.is_cuda()); + TORCH_CHECK(dz.is_cuda()); + TORCH_CHECK(rsigma.is_cuda()); + TORCH_CHECK(gamma.is_cuda()); + if (beta_.has_value()) { + TORCH_CHECK(beta_.value().is_cuda()); + TORCH_CHECK(beta_.value().dtype() == wtype); + } + + TORCH_CHECK(x_or_z.is_contiguous()); + TORCH_CHECK(dz.is_contiguous()); + + auto sizes = x_or_z.sizes(); + TORCH_CHECK(sizes.size() == 2); + TORCH_CHECK(dz.sizes() == sizes); + auto rows = sizes[0]; + auto cols = sizes[1]; + + auto hidden_size = gamma.numel(); + + TORCH_CHECK(gamma.numel() == cols); + if (beta_.has_value()) { + TORCH_CHECK(beta_.value().numel() == cols); + } + + auto options = x_or_z.options(); + + auto dx = torch::empty_like(x_or_z); + auto dgamma = torch::empty_like(gamma); + auto dbeta = torch::empty_like(gamma); + + layer_norm::LaunchParams launch_params; + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + launch_params.props = at::cuda::getCurrentDeviceProperties(); + + auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size); + + launcher(launch_params, true); + + auto dgamma_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); + auto dbeta_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); + at::Tensor workspace, barrier; + + layer_norm::BwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + if (memory_efficient) { + params.z = x_or_z.data_ptr(); + params.beta = beta_.value().data_ptr(); + } else { + params.x = x_or_z.data_ptr(); + params.mu = mu_.value().data_ptr(); + } + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.dz = dz.data_ptr(); + params.dx = dx.data_ptr(); + params.dbeta = dbeta.data_ptr(); + params.dgamma = dgamma.data_ptr(); + params.dbeta_part = dbeta_part.data_ptr(); + params.dgamma_part = dgamma_part.data_ptr(); + + if (launch_params.barrier_size > 0) { + // TODO Any way to avoid this? 
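+    // The sizes used here (barrier_size, workspace_bytes) were filled in by the configure pass
+    // `launcher(launch_params, true)` above; the barrier is zero-initialized so the inter-CTA
+    // sync counters start in a known state, while the workspace can remain uninitialized.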
+ barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + launcher(launch_params, false); + + return {dx, dgamma, dbeta, dgamma_part, dbeta_part}; } //////////////////////////////////////////////////////////////////////////////////////////////////// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "CUDA LayerNorm"; + m.doc() = "CUDA LayerNorm"; m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel", py::call_guard()); m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel", py::call_guard()); } diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh index 019764a38..ba307b559 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh @@ -1,324 +1,317 @@ #pragma once +#include "ln_utils.cuh" + namespace layer_norm { -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_bwd_kernel(layer_norm::BwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { COLS = Ktraits::COLS }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using compute_t = typename Ktraits::compute_t; - using index_t = typename Ktraits::index_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Reducer = typename Ktraits::Reducer; - using reduce_t = typename Reducer::Type; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / Ktraits::WARPS_N; - const index_t warp_n = warp % Ktraits::WARPS_N; - const index_t tid_r = warp_n * THREADS_PER_WARP + lane; - - const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); - - Cvec dzy_sum[LDGS]; - Cvec dz_sum[LDGS]; - - memset(dzy_sum, 0, sizeof(dzy_sum)); - memset(dz_sum, 0, sizeof(dz_sum)); - - compute_t * smem_wgrad = reinterpret_cast(smem_); - char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; - - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); - - Sum sum; - - constexpr float rn = 1.f / float(COLS); - Wvec gamma[LDGS]; - Wvec beta[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - gamma[it].load_from(params.gamma, idx); - if (params.z != nullptr) { - beta[it].load_from(params.beta, idx); - } - idx += Ktraits::VEC_COLS_PER_LDG; +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) { + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { COLS = 
Ktraits::COLS }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using compute_t = typename Ktraits::compute_t; + using index_t = typename Ktraits::index_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Reducer = typename Ktraits::Reducer; + using reduce_t = typename Reducer::Type; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / Ktraits::WARPS_N; + const index_t warp_n = warp % Ktraits::WARPS_N; + const index_t tid_r = warp_n * THREADS_PER_WARP + lane; + + const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + + Cvec dzy_sum[LDGS]; + Cvec dz_sum[LDGS]; + + memset(dzy_sum, 0, sizeof(dzy_sum)); + memset(dz_sum, 0, sizeof(dz_sum)); + + compute_t *smem_wgrad = reinterpret_cast(smem_); + char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; + + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); + + Sum sum; + + constexpr float rn = 1.f / float(COLS); + Wvec gamma[LDGS]; + Wvec beta[LDGS]; + index_t idx = c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + gamma[it].load_from(params.gamma, idx); + if (params.z != nullptr) { + beta[it].load_from(params.beta, idx); + } + idx += Ktraits::VEC_COLS_PER_LDG; + } +// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the +// last blocks with syncthreads! +// grid stride over rows +#pragma unroll 1 + for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { + const compute_t mu_r = params.z == nullptr ? static_cast(params.mu)[row] : 0.f; + const compute_t rs_r = static_cast(params.rs)[row]; + Ivec x_or_z[LDGS]; + Ovec dz[LDGS]; + index_t idx = row * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + dz[it].load_from(params.dz, idx); + if (params.z != nullptr) { + x_or_z[it].load_from(params.z, idx); + } else { + x_or_z[it].load_from(params.x, idx); + } + idx += Ktraits::VEC_COLS_PER_LDG; + } + + compute_t dy[LDGS * NUM_ELTS]; + compute_t y[LDGS * NUM_ELTS]; + + compute_t mdy_local = 0.f; + compute_t mdyy_local = 0.f; +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t gamma_tmp = compute_t(gamma[it].data.elt[jt]); + compute_t beta_tmp = compute_t(beta[it].data.elt[jt]); + compute_t x_or_z_tmp = compute_t(x_or_z[it].data.elt[jt]); + compute_t y_tmp = params.z != nullptr ? (x_or_z_tmp - beta_tmp) / gamma_tmp : rs_r * (x_or_z_tmp - mu_r); + compute_t dy_tmp = compute_t(dz[it].data.elt[jt]) * gamma_tmp; + compute_t dz_tmp = dz[it].data.elt[jt]; + + mdy_local += dy_tmp; + mdyy_local += dy_tmp * y_tmp; + + dy[it * NUM_ELTS + jt] = dy_tmp; + y[it * NUM_ELTS + jt] = y_tmp; + + dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; + dz_sum[it].data.elt[jt] += dz_tmp; + } } - // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the - // last blocks with syncthreads! 
- // grid stride over rows - #pragma unroll 1 - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t mu_r = params.z == nullptr ? static_cast(params.mu)[row] : 0.f; - const compute_t rs_r = static_cast(params.rs)[row]; - Ivec x_or_z[LDGS]; - Ovec dz[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz[it].load_from(params.dz, idx); - if (params.z != nullptr) { - x_or_z[it].load_from(params.z, idx); - } else { - x_or_z[it].load_from(params.x, idx); - } - idx += Ktraits::VEC_COLS_PER_LDG; - } - - compute_t dy[LDGS * NUM_ELTS]; - compute_t y[LDGS * NUM_ELTS]; - - compute_t mdy_local = 0.f; - compute_t mdyy_local = 0.f; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t gamma_tmp = compute_t(gamma[it].data.elt[jt]); - compute_t beta_tmp = compute_t(beta[it].data.elt[jt]); - compute_t x_or_z_tmp = compute_t(x_or_z[it].data.elt[jt]); - compute_t y_tmp = params.z != nullptr ? (x_or_z_tmp - beta_tmp) / gamma_tmp : rs_r * (x_or_z_tmp - mu_r); - compute_t dy_tmp = compute_t(dz[it].data.elt[jt]) * gamma_tmp; - compute_t dz_tmp = dz[it].data.elt[jt]; - - mdy_local += dy_tmp; - mdyy_local += dy_tmp * y_tmp; - - dy[it * NUM_ELTS + jt] = dy_tmp; - y[it * NUM_ELTS + jt] = y_tmp; - - dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; - dz_sum[it].data.elt[jt] += dz_tmp; - } - } - - reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - mdy_local = layer_norm::Get<0>::of(result) * rn; - mdyy_local = layer_norm::Get<1>::of(result) * rn; - - Ivec dx[LDGS]; - idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t dy_tmp = dy[it * NUM_ELTS + jt]; - compute_t y_tmp = y[it * NUM_ELTS + jt]; - compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); - dx[it].data.elt[jt] = dx_tmp; - } - dx[it].store_to(params.dx, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - - } // end: grid stride loop - - if( WARPS_M == 1 ) { - idx = r * Ktraits::VEC_COLS + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz_sum[it].store_to(params.dbeta_part, idx); - dzy_sum[it].store_to(params.dgamma_part, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - } else { - static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); - // Finalize reduction of part dgamma and dbeta for this CTA - // by reducing over the rows held across the WARPS_M warps - - // Assumption: blockSize divides hidden size. 
- enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; - static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dz_sum[NUM_RES]; - memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - __syncthreads(); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dzy_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dzy_sum[NUM_RES]; - memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - - compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; - for( int jt = 0; jt < NUM_RES; jt++ ) { - *dgamma_part = cta_dzy_sum[jt]; - dgamma_part += Ktraits::THREADS_PER_CTA; - } - - compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * COLS + tidx; - for( int jt = 0; jt < NUM_RES; jt++ ) { - *dbeta_part = cta_dz_sum[jt]; - dbeta_part += Ktraits::THREADS_PER_CTA; - } + + reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); + mdy_local = layer_norm::Get<0>::of(result) * rn; + mdyy_local = layer_norm::Get<1>::of(result) * rn; + + Ivec dx[LDGS]; + idx = row * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t dy_tmp = dy[it * NUM_ELTS + jt]; + compute_t y_tmp = y[it * NUM_ELTS + jt]; + compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); + dx[it].data.elt[jt] = dx_tmp; + } + dx[it].store_to(params.dx, idx); + idx += Ktraits::VEC_COLS_PER_LDG; } + + } // end: grid stride loop + + if (WARPS_M == 1) { + idx = r * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + dz_sum[it].store_to(params.dbeta_part, idx); + dzy_sum[it].store_to(params.dgamma_part, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + } else { + static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); + // Finalize reduction of part dgamma and dbeta for this CTA + // by reducing over the rows held across the WARPS_M warps + + // Assumption: blockSize divides hidden size. 
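+    // e.g. a 1024-wide hidden size with THREADS_PER_CTA = 512 gives each thread NUM_RES = 2 columns
+    // of the partial dgamma/dbeta sums; the static_assert below rejects sizes that do not divide evenly.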
+ enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; + static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + dz_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dz_sum[NUM_RES]; + memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); + for (int it = 0; it < ROWS_PER_CTA; it++) { + for (int jt = 0; jt < NUM_RES; jt++) { + cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + __syncthreads(); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + dzy_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dzy_sum[NUM_RES]; + memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); + for (int it = 0; it < ROWS_PER_CTA; it++) { + for (int jt = 0; jt < NUM_RES; jt++) { + cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + + compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; + for (int jt = 0; jt < NUM_RES; jt++) { + *dgamma_part = cta_dzy_sum[jt]; + dgamma_part += Ktraits::THREADS_PER_CTA; + } + + compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * COLS + tidx; + for (int jt = 0; jt < NUM_RES; jt++) { + *dbeta_part = cta_dz_sum[jt]; + dbeta_part += Ktraits::THREADS_PER_CTA; + } + } } -template -__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) -void ln_bwd_finalize_kernel(BwdParams params) -{ - - using compute_t = typename Kernel_traits::compute_t; - using weight_t = typename Kernel_traits::weight_t; - using index_t = typename Kernel_traits::index_t; - using Reducer = typename Kernel_traits::Reducer; - using reduce_t = typename Reducer::Type; - - Sum sum; - enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; - - __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; - - constexpr uint32_t bidm = 0; - - const uint32_t bidn = blockIdx.x; - const uint32_t tidx = threadIdx.x; - const uint32_t warp = tidx / THREADS_PER_WARP; - const uint32_t lane = tidx % THREADS_PER_WARP; - - Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); - - const uint32_t c = bidn * THREADS_PER_WARP + lane; - const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { - // Each thread sums over NUM_ELT columns. 
- Vec dbeta_local, dgamma_local; - memset(&dgamma_local, 0, sizeof(dgamma_local)); - memset(&dbeta_local, 0, sizeof(dbeta_local)); - for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { - index_t idx = row * Kernel_traits::COLS + col; - - Vec dbeta_part, dgamma_part; - dbeta_part.load_from(params.dbeta_part, idx); - dgamma_part.load_from(params.dgamma_part, idx); - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; - dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; - } - } - - void * smem_gamma = smem_; - void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; - - const int write_row = warp; - const int write_col = lane ^ write_row; - const int write_idx = write_row * THREADS_PER_WARP + write_col; - - dgamma_local.store_to(smem_gamma, write_idx); - dbeta_local.store_to(smem_beta, write_idx); - - __syncthreads(); - - // It would be probably safe to reuse the first row of smem_beta and smem_gamma - void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; - - - // More than one iter iff ROWS_PER_CTA < 32. - for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { - const int read_row = lane; - const int read_col = w ^ read_row; - const int read_idx = read_row * THREADS_PER_WARP + read_col; - - memset(&dbeta_local, 0, sizeof(dbeta_local)); - memset(&dgamma_local, 0, sizeof(dgamma_local)); - - // Load beta and gamma transposed - if(read_row < Kernel_traits::ROWS_PER_CTA){ - dbeta_local.load_from(smem_beta, read_idx); - dgamma_local.load_from(smem_gamma, read_idx); - } - - // Call reducer on the loaded value(s) and convert. - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - compute_t b_i = dbeta_local.data.elt[it]; - compute_t g_i = dgamma_local.data.elt[it]; - b_i = reducer.allreduce(b_i, sum); - g_i = reducer.allreduce(g_i, sum); - - dgamma_local.data.elt[it] = g_i; - dbeta_local.data.elt[it] = b_i; - } - - // Leader stores the result at the current column. - if(lane == 0){ - dgamma_local.store_to(smem_gamma_out, w); - dbeta_local.store_to(smem_beta_out, w); - } - - } - - // All writes done. - __syncthreads(); - - // Pack and store: 2-wide stores with half the threads. 
- if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { - - using src_t = typename TypeToVec2::Type; - using dst_t = typename TypeToVec2::Type; - Vec dbeta_vec2, dgamma_vec2; - Vec dbeta_out2, dgamma_out2; - - dgamma_vec2.load_from(smem_gamma_out, lane); - dbeta_vec2.load_from(smem_beta_out, lane); - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); - dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); - } - dgamma_out2.store_to(params.dgamma, col_out); - dbeta_out2.store_to(params.dbeta, col_out); - - } +template +__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(BwdParams params) { + using compute_t = typename Kernel_traits::compute_t; + using weight_t = typename Kernel_traits::weight_t; + using index_t = typename Kernel_traits::index_t; + using Reducer = typename Kernel_traits::Reducer; + using reduce_t = typename Reducer::Type; + + Sum sum; + enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; + + __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; + + constexpr uint32_t bidm = 0; + + const uint32_t bidn = blockIdx.x; + const uint32_t tidx = threadIdx.x; + const uint32_t warp = tidx / THREADS_PER_WARP; + const uint32_t lane = tidx % THREADS_PER_WARP; + + Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); + + const uint32_t c = bidn * THREADS_PER_WARP + lane; + const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; + constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; + for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2) { + // Each thread sums over NUM_ELT columns. + Vec dbeta_local, dgamma_local; + memset(&dgamma_local, 0, sizeof(dgamma_local)); + memset(&dbeta_local, 0, sizeof(dbeta_local)); + for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) { + index_t idx = row * Kernel_traits::COLS + col; + + Vec dbeta_part, dgamma_part; + dbeta_part.load_from(params.dbeta_part, idx); + dgamma_part.load_from(params.dgamma_part, idx); +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; + dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; + } + } + + void *smem_gamma = smem_; + void *smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; + + const int write_row = warp; + const int write_col = lane ^ write_row; + const int write_idx = write_row * THREADS_PER_WARP + write_col; + + dgamma_local.store_to(smem_gamma, write_idx); + dbeta_local.store_to(smem_beta, write_idx); + + __syncthreads(); + + // It would be probably safe to reuse the first row of smem_beta and smem_gamma + void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void *smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + + // More than one iter iff ROWS_PER_CTA < 32. 
+ for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) { + const int read_row = lane; + const int read_col = w ^ read_row; + const int read_idx = read_row * THREADS_PER_WARP + read_col; + + memset(&dbeta_local, 0, sizeof(dbeta_local)); + memset(&dgamma_local, 0, sizeof(dgamma_local)); + + // Load beta and gamma transposed + if (read_row < Kernel_traits::ROWS_PER_CTA) { + dbeta_local.load_from(smem_beta, read_idx); + dgamma_local.load_from(smem_gamma, read_idx); + } + +// Call reducer on the loaded value(s) and convert. +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + compute_t b_i = dbeta_local.data.elt[it]; + compute_t g_i = dgamma_local.data.elt[it]; + b_i = reducer.allreduce(b_i, sum); + g_i = reducer.allreduce(g_i, sum); + + dgamma_local.data.elt[it] = g_i; + dbeta_local.data.elt[it] = b_i; + } + + // Leader stores the result at the current column. + if (lane == 0) { + dgamma_local.store_to(smem_gamma_out, w); + dbeta_local.store_to(smem_beta_out, w); + } + } + + // All writes done. + __syncthreads(); + + // Pack and store: 2-wide stores with half the threads. + if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) { + using src_t = typename TypeToVec2::Type; + using dst_t = typename TypeToVec2::Type; + Vec dbeta_vec2, dgamma_vec2; + Vec dbeta_out2, dgamma_out2; + + dgamma_vec2.load_from(smem_gamma_out, lane); + dbeta_vec2.load_from(smem_beta_out, lane); +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); + dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); + } + dgamma_out2.store_to(params.dgamma, col_out); + dbeta_out2.store_to(params.dbeta, col_out); } + } } } // namespace layer_norm diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index 3893d4e0c..405363982 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -1,152 +1,125 @@ #include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" #include "ln_bwd_kernels.cuh" +#include "ln_kernel_traits.h" +#include "ln_utils.cuh" using namespace layer_norm; -template< - typename weight_t, - typename input_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG_MAIN, - int BYTES_PER_LDG_FINAL -> -void launch_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_kernel; - - if( configure_params ) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) - * 2; - } - return; +template +void launch_(LaunchParams &launch_params, const bool configure_params) { + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_kernel; + + if (configure_params) { + int ctas_per_sm; + 
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_col = + launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::reduce_t) * 2; } - - if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; - kernel_f<<>>(launch_params.params); + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if (Kernel_traits::CTAS_PER_ROW == 1) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); + } + + using Kernel_traits_f = + layer_norm::Kernel_traits_finalize; + + auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; + kernel_f<<>>(launch_params.params); } // Create backward launch function and register. 
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL -REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 
-REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_BWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_BWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_BWD_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_BWD_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_BWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_BWD_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_BWD_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_BWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(5120, bf16, bf16, bf16, 
fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +REGISTER_BWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); @@ -161,22 +134,22 @@ REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); +REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); +REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); @@ -185,9 +158,9 @@ REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 
1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); @@ -237,4 +210,3 @@ REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index dc4e89cf5..232de003e 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -1,140 +1,119 @@ #include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" #include "ln_fwd_kernels.cuh" +#include "ln_kernel_traits.h" +#include "ln_utils.cuh" using namespace layer_norm; -template< - typename weight_t, - typename input_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG -> -void launch_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_kernel; - - if( configure_params ) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) - * 2; - } - return; +template +void launch_(LaunchParams &launch_params, const bool configure_params) { + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_col = + launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::Stats::stats_t) * 2; } - - if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, 
(void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); - } - + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if (Kernel_traits::CTAS_PER_ROW == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); + } } // Create forward launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG -REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4); - -REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 
4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4); + +REGISTER_FWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 
+REGISTER_FWD_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); @@ -148,23 +127,23 @@ REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); @@ -190,17 +169,17 @@ REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16); 
REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); +REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4); REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); @@ -225,4 +204,3 @@ REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16); REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16); - diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh index 64e72974f..08f1e5ee0 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh @@ -1,110 +1,109 @@ #pragma once #include "ln.h" +#include "ln_utils.cuh" namespace layer_norm { -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_fwd_kernel(FwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - - using Stats = typename Ktraits::Stats; - using stats_t = typename Stats::stats_t; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t r = bidm * ROWS_PER_CTA + 
warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); - - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); - - Wvec gamma[LDGS]; - Wvec beta[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - gamma[it].load_from(params.gamma, idx); - beta[it].load_from(params.beta, idx); - idx += VEC_COLS_PER_LDG; +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) { + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + Wvec gamma[LDGS]; + Wvec beta[LDGS]; + index_t idx = c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + gamma[it].load_from(params.gamma, idx); + beta[it].load_from(params.beta, idx); + idx += VEC_COLS_PER_LDG; + } + + constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); + + for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { + Ivec x[LDGS]; + index_t idx = row * Ktraits::VEC_COLS + c; + compute_t xf[LDGS * NUM_ELTS]; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + x[it].load_from(params.x, idx); +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t x_ij = compute_t(x[it].data.elt[jt]); + xf[it * NUM_ELTS + jt] = x_ij; + } + idx += VEC_COLS_PER_LDG; } - constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); - - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - Ivec x[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; - compute_t xf[LDGS * NUM_ELTS]; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - x[it].load_from(params.x, idx); - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_ij = compute_t(x[it].data.elt[jt]); - xf[it * NUM_ELTS + jt] = x_ij; - } - idx += VEC_COLS_PER_LDG; - } - - stats_t s = stats.compute(xf, rn); - - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - mu_ptr[row] = mu; - } - - compute_t rs = rsqrtf(rn * m2 + 
params.epsilon); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - rs_ptr[row] = rs; - } - - Ovec z[LDGS]; - idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu)); - output_t g_ij = gamma[it].data.elt[jt]; - output_t b_ij = beta[it].data.elt[jt]; - z[it].data.elt[jt] = (g_ij * y_ij + b_ij); - } - z[it].store_to(params.z, idx); - idx += VEC_COLS_PER_LDG; - } + stats_t s = stats.compute(xf, rn); + compute_t mu = layer_norm::Get<0>::of(s); + compute_t m2 = layer_norm::Get<1>::of(s); + + if (bidn == 0 && warp_n == 0 && lane == 0) { + mu_ptr[row] = mu; + } + + compute_t rs = rsqrtf(rn * m2 + params.epsilon); + + if (bidn == 0 && warp_n == 0 && lane == 0) { + rs_ptr[row] = rs; + } + + Ovec z[LDGS]; + idx = row * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu)); + output_t g_ij = gamma[it].data.elt[jt]; + output_t b_ij = beta[it].data.elt[jt]; + z[it].data.elt[jt] = (g_ij * y_ij + b_ij); + } + z[it].store_to(params.z, idx); + idx += VEC_COLS_PER_LDG; } + } } } // namespace layer_norm diff --git a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h index ed745c5ee..d8bb8a9ac 100644 --- a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h +++ b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h @@ -3,155 +3,114 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// namespace layer_norm { -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t THREADS_PER_CTA_ -> +template struct Kernel_traits_base { - - using weight_t = weight_t_; - using input_t = input_t_; - using output_t = output_t_; - using compute_t = compute_t_; - using index_t = index_t_; - - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; - enum { THREADS_PER_WARP = 32 }; - + using weight_t = weight_t_; + using input_t = input_t_; + using output_t = output_t_; + using compute_t = compute_t_; + using index_t = index_t_; + + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; + enum { THREADS_PER_WARP = 32 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t THREADS_PER_CTA_, - uint32_t BYTES_PER_LDG_, - typename Base = Kernel_traits_base -> +template > struct Kernel_traits_finalize : public Base { - enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; - static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); - // Bytes per global load from the input. - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - // Number of elements fetched by a global load. - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; - // Bytes per global store of the weights. 
- enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; - static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); - static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); - // The total number of BYTES_PER_LDG-wide words in a hidden vector. - enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; - static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); - - // Shared memory size to transpose the CTA result. - enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; - // Shared memory size to coalsece the CTA result. - enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; - // Shared memory requirement per CTA. - enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT }; - - // The type of the reducer. - using Reducer = layer_norm::Reducer; - - // Condition for the whole CTA to participate in syncthreads. - static_assert(COLS % Base::THREADS_PER_WARP == 0); - enum { CTAS = COLS / Base::THREADS_PER_WARP }; -}; + enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; + static_assert((int)ROWS_PER_CTA <= (int)Base::THREADS_PER_WARP); + // Bytes per global load from the input. + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + // Number of elements fetched by a global load. + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; + // Bytes per global store of the weights. + enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; + static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); + static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); + // The total number of BYTES_PER_LDG-wide words in a hidden vector. + enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; + static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); + + // Shared memory size to transpose the CTA result. + enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; + // Shared memory size to coalsece the CTA result. + enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; + // Shared memory requirement per CTA. + enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT }; + + // The type of the reducer. + using Reducer = layer_norm::Reducer; + + // Condition for the whole CTA to participate in syncthreads. 
+ static_assert(COLS % Base::THREADS_PER_WARP == 0); + enum { CTAS = COLS / Base::THREADS_PER_WARP }; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t HIDDEN_SIZE_, - uint32_t CTAS_PER_ROW_, - uint32_t WARPS_M_, - uint32_t WARPS_N_, - uint32_t BYTES_PER_LDG_ = 16, - typename Base = Kernel_traits_base< - HIDDEN_SIZE_, - weight_t_, - input_t_, - output_t_, - compute_t_, - index_t_, - WARPS_M_*WARPS_N_*THREADS_PER_WARP - > -> +template > struct Kernel_traits : public Base { - - using input_t = typename Base::input_t; - using weight_t = typename Base::weight_t; - using compute_t = typename Base::compute_t; - using output_t = typename Base::output_t; - using index_t = typename Base::index_t; - - enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; - enum { WARPS_M = WARPS_M_ }; - enum { WARPS_N = WARPS_N_ }; - enum { COLS = HIDDEN_SIZE_ }; - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; - - enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; - enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; - enum { ROWS_PER_CTA = WARPS_M }; - - enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; - enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; - // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed - enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; - static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); - - using reduce_t = typename layer_norm::TypeToVec2::Type; - using Reducer = layer_norm::Reducer; - - enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; - enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; - - using Ivec = layer_norm::Vec; - using Ovec = layer_norm::Vec; - using Wvec = layer_norm::Vec; - using Cvec = layer_norm::Vec; - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; - - // Assume that each thread can handle the same number of elements in the output and weights as in the input. - static_assert(sizeof(input_t) >= sizeof(output_t)); - static_assert(sizeof(input_t) >= sizeof(weight_t)); - // The number of columns fetched per load from input: one per thread. - enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; - // The total number of vectorized loads/stores per hidden vector. - enum { VEC_COLS = COLS / ELTS_PER_LDG }; - // The number of loads per thread for the input. 
- enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; - static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); - //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); - - using Stats = layer_norm::Stats; - enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; - + using input_t = typename Base::input_t; + using weight_t = typename Base::weight_t; + using compute_t = typename Base::compute_t; + using output_t = typename Base::output_t; + using index_t = typename Base::index_t; + + enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; + enum { WARPS_M = WARPS_M_ }; + enum { WARPS_N = WARPS_N_ }; + enum { COLS = HIDDEN_SIZE_ }; + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; + + enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; + enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; + enum { ROWS_PER_CTA = WARPS_M }; + + enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; + enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; + // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed + enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA* COLS * sizeof(compute_t) }; + static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); + + using reduce_t = typename layer_norm::TypeToVec2::Type; + using Reducer = layer_norm::Reducer; + + enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; + enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; + + using Ivec = layer_norm::Vec; + using Ovec = layer_norm::Vec; + using Wvec = layer_norm::Vec; + using Cvec = layer_norm::Vec; + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; + + // Assume that each thread can handle the same number of elements in the output and weights as in the input. + static_assert(sizeof(input_t) >= sizeof(output_t)); + static_assert(sizeof(input_t) >= sizeof(weight_t)); + // The number of columns fetched per load from input: one per thread. + enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; + // The total number of vectorized loads/stores per hidden vector. + enum { VEC_COLS = COLS / ELTS_PER_LDG }; + // The number of loads per thread for the input. 
+ enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; + static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); + // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); + + using Stats = layer_norm::Stats; + enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/layer_norm/ln_utils.cuh b/apex/contrib/csrc/layer_norm/ln_utils.cuh index e18d36de7..5bea8e1cb 100644 --- a/apex/contrib/csrc/layer_norm/ln_utils.cuh +++ b/apex/contrib/csrc/layer_norm/ln_utils.cuh @@ -1,10 +1,10 @@ #pragma once -#include - #include #include +#include + #include "ln.h" //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -14,16 +14,16 @@ constexpr uint32_t THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// inline void check_cuda_(cudaError_t status, const char *file, int line) { - if( status != cudaSuccess ) { - fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); - exit(status); - } + if (status != cudaSuccess) { + fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); + exit(status); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -#define CHECK_CUDA(ans) \ - { check_cuda_((ans), __FILE__, __LINE__); } +#define CHECK_CUDA(ans) \ + { check_cuda_((ans), __FILE__, __LINE__); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -31,79 +31,68 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) { //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_( \ - launch_params, configure_params); \ - } \ - static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) +#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_( \ + launch_params, configure_params); \ + } \ + static FwdRegistrar \ + reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_BWD_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_(launch_params, configure_params); \ - } \ - static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) +#define REGISTER_BWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \ + BYTES_PER_LDG_FINALIZE) \ + void 
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_(launch_params, configure_params); \ + } \ + static BwdRegistrar \ + reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float2 operator+(const float2 & a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; -} +inline __device__ float2 operator+(const float2 &a, const float2 &b) { return {a.x + b.x, a.y + b.y}; } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void operator+=(float2 & a, const float2 & b){ - a.x += b.x; - a.y += b.y; +inline __device__ void operator+=(float2 &a, const float2 &b) { + a.x += b.x; + a.y += b.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Sum { - inline __device__ Sum(){} - inline __device__ T operator()(const T &a, const T &b){ - return a + b; - } + inline __device__ Sum() {} + inline __device__ T operator()(const T &a, const T &b) { return a + b; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ - return __shfl_xor_sync(uint32_t(-1), x, idx); +template +inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) { + return __shfl_xor_sync(uint32_t(-1), x, idx); } -template<> -inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ - return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; +template <> +inline __device__ float2 warp_shuffle_xor(const float2 &x, uint32_t idx) { + return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)}; } -template -inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ - return __shfl_down_sync(uint32_t(-1), x, idx); +template +inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) { + return __shfl_down_sync(uint32_t(-1), x, idx); } -template<> -inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ - return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; +template <> +inline __device__ float2 warp_shuffle_down(const float2 &x, uint32_t idx) { + return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)}; } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -113,619 +102,598 @@ namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// struct uint16 { - uint4 u; - uint4 v; - uint4 s; - uint4 t; + uint4 u; + uint4 v; + uint4 s; + uint4 t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct uint8 { - uint4 u; - uint4 v; + uint4 u; + uint4 v; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct BytesToType {}; -template<> +template <> struct BytesToType<64> { - using Type = uint16; - static_assert(sizeof(Type) == 64); + using Type = uint16; + static_assert(sizeof(Type) == 64); }; -template<> +template <> struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); + using Type = uint8; + static_assert(sizeof(Type) == 32); }; -template<> +template <> 
struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template<> +template <> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template<> +template <> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template<> +template <> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template<> +template <> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct TypeToVec2 {}; -template<> +template <> struct TypeToVec2 { - using Type = float2; + using Type = float2; }; -template<> +template <> struct TypeToVec2 { - using Type = half2; + using Type = half2; }; -template<> +template <> struct TypeToVec2 { - using Type = nv_bfloat162; + using Type = nv_bfloat162; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Get { - template - static inline __device__ R of(const T &vec); + template + static inline __device__ R of(const T &vec); }; -template<> -template +template <> +template inline __device__ R Get<0>::of(const T &vec) { - return vec.x; + return vec.x; } -template<> -template +template <> +template inline __device__ R Get<1>::of(const T &vec) { - return vec.y; + return vec.y; } -template<> -template +template <> +template inline __device__ R Get<2>::of(const T &vec) { - return vec.z; + return vec.z; } -template<> -template +template <> +template inline __device__ R Get<3>::of(const T &vec) { - return vec.w; + return vec.w; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Converter{ - static inline __device__ Dst convert(const Src &from) { - return Dst(from); - } +template +struct Converter { + static inline __device__ Dst convert(const Src &from) { return Dst(from); } }; -template<> -struct Converter{ - static inline __device__ half2 convert(const float2 &x) { - return __float22half2_rn(x); - } +template <> +struct Converter { + static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); } }; -template<> -struct Converter{ - static inline __device__ nv_bfloat162 convert(const float2 &x) { +template <> +struct Converter { + static inline __device__ nv_bfloat162 convert(const float2 &x) { #if __CUDA_ARCH__ >= 800 - return __float22bfloat162_rn(x); + return __float22bfloat162_rn(x); #else - union { - nv_bfloat162 raw; - nv_bfloat16 x; - nv_bfloat16 y; - } tmp; - tmp.x = __float2bfloat16_rn(x.x); - tmp.y = __float2bfloat16_rn(x.y); - return tmp.raw; + union { + nv_bfloat162 raw; + nv_bfloat16 x; + nv_bfloat16 y; + } tmp; + tmp.x = __float2bfloat16_rn(x.x); + tmp.y = __float2bfloat16_rn(x.y); + return tmp.raw; #endif - } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Zeros{ - static inline __device__ T get() { - return T(0.f); - } +template +struct Zeros { + static inline __device__ T get() { return T(0.f); } }; -template<> -struct Zeros{ - static 
inline __device__ float2 get() { - return make_float2(0.f, 0.f); - } +template <> +struct Zeros { + static inline __device__ float2 get() { return make_float2(0.f, 0.f); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Vec { + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - enum { BYTES = NUM_ELT * sizeof(Elt_type) }; + using Vec_type = typename BytesToType::Type; - using Vec_type = typename BytesToType::Type; + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; - using Alias_type = union { - Vec_type vec; - Elt_type elt[NUM_ELT]; - }; + Alias_type data; - Alias_type data; - - template - inline __device__ void to(Vec &other) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - other.data.elt[it] = S(this->data.elt[it]); - } + template + inline __device__ void to(Vec &other) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + other.data.elt[it] = S(this->data.elt[it]); } + } - template - inline __device__ void assign(const Op &op) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = op(it); - } + template + inline __device__ void assign(const Op &op) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = op(it); } + } - inline __device__ void load_from(const void *base_ptr, const size_t idx) { - this->data.vec = static_cast(base_ptr)[idx]; - } + inline __device__ void load_from(const void *base_ptr, const size_t idx) { + this->data.vec = static_cast(base_ptr)[idx]; + } - inline __device__ void store_to(void *base_ptr, const size_t idx) { - static_cast(base_ptr)[idx] = this->data.vec; - } + inline __device__ void store_to(void *base_ptr, const size_t idx) { + static_cast(base_ptr)[idx] = this->data.vec; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct InterCTASync { - - template - inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) - : phase_counter_(0) - , b0_(params.barrier + bidm) // The barrier for this group of CTAs. - , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. - { - // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + template + inline __device__ InterCTASync(Params ¶ms, uint32_t bidm, uint32_t bidn) + : phase_counter_(0), + b0_(params.barrier + bidm) // The barrier for this group of CTAs. + , + b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. + { + // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + } + + inline __device__ void spin_wait_(int *barrier, int step, int expected) { + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); + for (int found = -1; found != expected;) { + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); } - - inline __device__ void spin_wait_(int *barrier, int step, int expected) { - asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); - for( int found = -1; found != expected; ) { - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); - } - } - - inline __device__ void sync(){ - // ALL THREADS MUST ENTER! - - // We switch barrier every iteration. - int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; - // We decrement every other iteration. - bool dec = phase_counter_ & 0x2; - int step = dec ? -1 : 1; - int expected = dec ? 
0 : CTAS_PER_ROW; - // There are only 4 phases: up/down for b0/b1. - phase_counter_ = (phase_counter_ + 1) & 0x3; - - if( threadIdx.x == 0 ) { - spin_wait_(barrier, step, expected); - } - // CTA waits for thread 0 - __syncthreads(); + } + + inline __device__ void sync() { + // ALL THREADS MUST ENTER! + + // We switch barrier every iteration. + int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; + // We decrement every other iteration. + bool dec = phase_counter_ & 0x2; + int step = dec ? -1 : 1; + int expected = dec ? 0 : CTAS_PER_ROW; + // There are only 4 phases: up/down for b0/b1. + phase_counter_ = (phase_counter_ + 1) & 0x3; + + if (threadIdx.x == 0) { + spin_wait_(barrier, step, expected); } + // CTA waits for thread 0 + __syncthreads(); + } - int phase_counter_; - int * b0_; - int * b1_; + int phase_counter_; + int *b0_; + int *b1_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Reducer : public Reducer { - - using InterCTASync = InterCTASync; - using Base = Reducer; - using Type = typename Base::Type; - - enum { SMEM_BYTES = Base::SMEM_BYTES }; - - enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; - enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; - - // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) - enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , inter_cta_(params, bidm, bidn) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - { + using InterCTASync = InterCTASync; + using Base = Reducer; + using Type = typename Base::Type; + + enum { SMEM_BYTES = Base::SMEM_BYTES }; + + enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; + enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; + + // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) + enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; + + template + inline __device__ Reducer(Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, + uint32_t lane, void *smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), + inter_cta_(params, bidm, bidn), + bidn_(bidn) // CTA id within the group. + , + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {} + + template + inline __device__ T allreduce(T data, Op &op) { + data = Base::reduce(data, op); + // We switch workspace every iteration. + T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. + if (this->warp_n_ == 0 && this->lane_ == 0) { + workspace[bidn_] = data; } - - template - inline __device__ T allreduce(T data, Op &op) { - data = Base::reduce(data, op); - // We switch workspace every iteration. - T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - // Warp leaders 0 hold the CTA-local results. 
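As a reference aside: the sync() above never resets its counters explicitly. It alternates between two barrier words and counts up to CTAS_PER_ROW on one pass, then back down to zero on a later pass, cycling through four phases, so the words return to their initial state on their own. A host-side analogue of spin_wait_/sync() using C++ atomics, assuming both counters start at zero; Arrival and arrive_and_wait are illustrative names, not part of this patch.

#include <atomic>

// Each participant owns one Arrival (its own phase counter) while the two
// counters are shared, mirroring how every CTA keeps its own phase_counter_
// but all CTAs in a group point at the same pair of barrier words.
struct Arrival {
  std::atomic<int>* b0;  // shared counter, assumed initialized to 0
  std::atomic<int>* b1;  // shared counter, assumed initialized to 0
  int phase = 0;         // private to this participant

  void arrive_and_wait(int group_size) {
    std::atomic<int>* bar = (phase & 0x1) ? b1 : b0;  // switch barrier every call
    const bool dec = phase & 0x2;                     // count back down every other pass
    const int step = dec ? -1 : 1;
    const int expected = dec ? 0 : group_size;
    phase = (phase + 1) & 0x3;                        // 4 phases: up/down for b0/b1

    bar->fetch_add(step, std::memory_order_release);  // cf. "red.release.gpu.global.add.s32"
    while (bar->load(std::memory_order_acquire) != expected) {
      // spin, as in spin_wait_
    }
  }
};

// Usage per participant: Arrival a{&shared_b0, &shared_b1}; a.arrive_and_wait(kGroupSize);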
- if( this->warp_n_ == 0 && this->lane_ == 0 ) { - workspace[bidn_] = data; - } - inter_cta_.sync(); - static_assert(CTAS_PER_ROW <= 32); - T total = Zeros::get(); - if(this->lane_ < CTAS_PER_ROW){ - total = workspace[this->lane_]; - } - total = Reducer::allreduce_(total, op); - - return total; + inter_cta_.sync(); + static_assert(CTAS_PER_ROW <= 32); + T total = Zeros::get(); + if (this->lane_ < CTAS_PER_ROW) { + total = workspace[this->lane_]; } + total = Reducer::allreduce_(total, op); + + return total; + } - InterCTASync inter_cta_; + InterCTASync inter_cta_; - T *w0_; - T *w1_; - int bidn_; + T *w0_; + T *w1_; + int bidn_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Reducer { - - using Type = T; - enum { SMEM_BYTES = 0 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : warp_n_(warp_n) - , lane_(lane) - { - } - - template - static inline __device__ T allreduce_(T data, Op &op) { - #pragma unroll - for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { - data = op(data, warp_shuffle_xor(data, it)); - } - return data; + using Type = T; + enum { SMEM_BYTES = 0 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, + uint32_t lane, void *smem) + : warp_n_(warp_n), lane_(lane) {} + + template + static inline __device__ T allreduce_(T data, Op &op) { +#pragma unroll + for (int it = 1; it < THREADS_PER_WARP; it *= 2) { + data = op(data, warp_shuffle_xor(data, it)); } - - template - inline __device__ T allreduce(T data, Op &op) { - return allreduce_(data, op); - } - - template - inline __device__ T reduce(T data, Op &op){ - // only lane 0 holds the result! - #pragma unroll - for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { - data = op(data, warp_shuffle_down(data, it)); - } - return data; + return data; + } + + template + inline __device__ T allreduce(T data, Op &op) { + return allreduce_(data, op); + } + + template + inline __device__ T reduce(T data, Op &op) { +// only lane 0 holds the result! 
+#pragma unroll + for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) { + data = op(data, warp_shuffle_down(data, it)); } - int warp_n_; - int lane_; + return data; + } + int warp_n_; + int lane_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Reducer : public Reducer { - - using Base = Reducer; - - using Type = T; - - enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) - { - smem0_ = &static_cast(smem)[warp_m * WARPS_N]; - smem1_ = smem0_ + WARPS_M * WARPS_N; + using Base = Reducer; + + using Type = T; + + enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, + uint32_t lane, void *smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { + smem0_ = &static_cast(smem)[warp_m * WARPS_N]; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ T allreduce(T data, Op &op) { + T *smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + data = Base::reduce(data, op); + if (this->lane_ == 0) { + smem[this->warp_n_] = data; } - - template - inline __device__ T allreduce(T data, Op & op) { - T * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - data = Base::reduce(data, op); - if( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - #pragma unroll - for( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - return out; + __syncthreads(); + T out = Zeros::get(); +#pragma unroll + for (int it = 0; it < WARPS_N; it++) { + out = op(out, smem[it]); } - - template - inline __device__ T reduce(T data, Op &op) { - T * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // only intra-CTA group leader holds the result! - data = Base::reduce(data, op); - if( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - if( this->warp_n_ == 0 && this->lane_ == 0 ) { - #pragma unroll - for( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - } - return out; + return out; + } + + template + inline __device__ T reduce(T data, Op &op) { + T *smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // only intra-CTA group leader holds the result! 
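For readers new to the warp primitives used here: the single-warp Reducer above relies on warp_shuffle_xor for allreduce (a butterfly exchange, so every lane ends up with the result) and warp_shuffle_down for reduce (a tree, so only lane 0 holds the result). A standalone CUDA sketch of both patterns; the kernel name and launch shape are illustrative, not part of this patch.

__global__ void warp_reduce_demo(const float* in, float* out_allreduce, float* out_reduce) {
  const int lane = threadIdx.x & 31;
  const float v = in[threadIdx.x];

  // Butterfly: after log2(32) xor steps every lane holds the warp-wide sum.
  float all = v;
  #pragma unroll
  for (int mask = 1; mask < 32; mask *= 2) {
    all += __shfl_xor_sync(0xffffffffu, all, mask);
  }

  // Tree: only lane 0 holds the warp-wide sum afterwards.
  float part = v;
  #pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
    part += __shfl_down_sync(0xffffffffu, part, offset);
  }

  out_allreduce[threadIdx.x] = all;
  if (lane == 0) out_reduce[threadIdx.x / 32] = part;
}

// e.g. warp_reduce_demo<<<1, 32>>>(d_in, d_all, d_red);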
+ data = Base::reduce(data, op); + if (this->lane_ == 0) { + smem[this->warp_n_] = data; } + __syncthreads(); + T out = Zeros::get(); + if (this->warp_n_ == 0 && this->lane_ == 0) { +#pragma unroll + for (int it = 0; it < WARPS_N; it++) { + out = op(out, smem[it]); + } + } + return out; + } - T * smem0_; - T * smem1_; - bool use0_; - + T *smem0_; + T *smem1_; + bool use0_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){ - //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) - int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); - - #pragma unroll - for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { - // Exchange - T n_b = warp_shuffle_down(n_a, step); - T m_b = warp_shuffle_down(m_a, step); - T m2_b = warp_shuffle_down(m2_a, step); - - // Update - const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. - const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( - const T delta = m_a - m_b; - const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; - const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; - - n_a = n_ab; - m_a = m_ab; - m2_a = m2_ab; - } - // Intra-warp broadcast (only lane 0 has valid stats). - m_a = __shfl_sync(uint32_t(-1), m_a, 0); - m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); + +template +inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active) { + // Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) + int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + +#pragma unroll + for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { + // Exchange + T n_b = warp_shuffle_down(n_a, step); + T m_b = warp_shuffle_down(m_a, step); + T m2_b = warp_shuffle_down(m2_a, step); + + // Update + const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. + const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( + const T delta = m_a - m_b; + const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + + n_a = n_ab; + m_a = m_ab; + m2_a = m2_ab; + } + // Intra-warp broadcast (only lane 0 has valid stats). + m_a = __shfl_sync(uint32_t(-1), m_a, 0); + m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Stats { - // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. - - using InterCTASync = InterCTASync; - using BlockStats = Stats; - using stats_t = typename BlockStats::stats_t; - - enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : inter_cta_(params, bidm, bidn) - , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - , warp_n_(warp_n) - , lane_(lane) - { + // This could be done generically with the Reducer. 
But then we would have to exchange 3 instead of 2 fields. + + using InterCTASync = InterCTASync; + using BlockStats = Stats; + using stats_t = typename BlockStats::stats_t; + + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; + + template + inline __device__ Stats(Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, + void *smem) + : inter_cta_(params, bidm, bidn), + block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), + bidn_(bidn) // CTA id within the group. + , + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW), + warp_n_(warp_n), + lane_(lane) {} + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; + // TODO rn is not really needed here.. + constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); + stats_t block_stats = block_stats_.compute(elts, block_rn); + + stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + if (warp_n_ == 0 && lane_ == 0) { + workspace[bidn_] = block_stats; } - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; - // TODO rn is not really needed here.. - constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); - stats_t block_stats = block_stats_.compute(elts, block_rn); - - stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - if( warp_n_ == 0 && lane_ == 0 ) { - workspace[bidn_] = block_stats; - } + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(); - // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. - inter_cta_.sync(); + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); + // Assume CTA group size in N less than 32, such that we can finalize with a single warp. + static_assert(CTAS_PER_ROW <= 32); - // Assume CTA group size in N less than 32, such that we can finalize with a single warp. - static_assert(CTAS_PER_ROW <= 32); - - // Every warp does the final reduction locally. - if( lane_ < CTAS_PER_ROW ) { - stats_t result = workspace[lane_]; - n = ELTS_PER_ROW_PER_CTA; - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } + // Every warp does the final reduction locally. 
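The update loop in warp_chan_upd_dynamic above is Chan et al.'s pairwise combination of (count, mean, m2) partials. A minimal host-side sketch of that single merge step, assuming float statistics; RunningStats and merge_stats are illustrative names, not part of this patch.

#include <cstdio>

struct RunningStats {
  float n;   // element count
  float m;   // mean
  float m2;  // sum of squared deviations from the mean
};

RunningStats merge_stats(RunningStats a, RunningStats b) {
  const float n_ab = a.n + b.n;  // one side may be empty, not both
  const float rn_ab = 1.f / n_ab;
  const float delta = a.m - b.m;
  RunningStats out;
  out.n = n_ab;
  out.m = (a.n * a.m + b.n * b.m) * rn_ab;
  out.m2 = a.m2 + b.m2 + delta * delta * a.n * b.n * rn_ab;
  return out;
}

int main() {
  // Two per-CTA partials covering the row {1, 2, 3, 4}: {1, 2} and {3, 4}.
  RunningStats a{2.f, 1.5f, 0.5f};
  RunningStats b{2.f, 3.5f, 0.5f};
  RunningStats r = merge_stats(a, b);
  // Expect mean 2.5 and m2 5.0, i.e. variance m2 / n = 1.25.
  std::printf("n=%g mean=%g m2=%g var=%g\n", r.n, r.m, r.m2, r.m2 / r.n);
  return 0;
}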
+ if (lane_ < CTAS_PER_ROW) { + stats_t result = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } - warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); - return { m, m2 }; - } + return {m, m2}; + } - InterCTASync inter_cta_; - BlockStats block_stats_; + InterCTASync inter_cta_; + BlockStats block_stats_; - stats_t *w0_; - stats_t *w1_; - int bidn_; - int warp_n_; - int lane_; + stats_t *w0_; + stats_t *w1_; + int bidn_; + int warp_n_; + int lane_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Stats { - - using WarpStats = Stats; - using stats_t = typename WarpStats::stats_t; - - enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) - { - smem0_ = static_cast(smem) + warp_m * WARPS_N; - smem1_ = smem0_ + WARPS_M * WARPS_N; + using WarpStats = Stats; + using stats_t = typename WarpStats::stats_t; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats(Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, + void *smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + stats_t *smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // Compute warp local for all WARPS_N + constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); + stats_t warp_stats = warp_stats_.compute(elts, warp_rn); + + // Each warp warp leader stores its stats + const auto warp_n = warp_stats_.reducer_.warp_n_; + const auto lane = warp_stats_.reducer_.lane_; + if (lane == 0) { + smem[warp_n] = warp_stats; } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - stats_t * smem = use0_ ? 
smem0_ : smem1_; - use0_ = !use0_; - // Compute warp local for all WARPS_N - constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); - stats_t warp_stats = warp_stats_.compute(elts, warp_rn); - - //Each warp warp leader stores its stats - const auto warp_n = warp_stats_.reducer_.warp_n_; - const auto lane = warp_stats_.reducer_.lane_; - if( lane == 0 ) { - smem[warp_n] = warp_stats; - } - __syncthreads(); - - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume that there are less than 32 warps, such that we can finalize with a single warp - static_assert(WARPS_N <= 32); - if(lane < WARPS_N){ - stats_t result = smem[lane]; - n = N * THREADS_PER_WARP; - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, WARPS_N); - - return { m, m2 }; + __syncthreads(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume that there are less than 32 warps, such that we can finalize with a single warp + static_assert(WARPS_N <= 32); + if (lane < WARPS_N) { + stats_t result = smem[lane]; + n = N * THREADS_PER_WARP; + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); } - WarpStats warp_stats_; - stats_t * smem0_; - stats_t * smem1_; - bool use0_; + + warp_chan_upd_dynamic(m, m2, n, WARPS_N); + + return {m, m2}; + } + WarpStats warp_stats_; + stats_t *smem0_; + stats_t *smem1_; + bool use0_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Stats { + using stats_t = typename TypeToVec2::Type; + // The simple Warp reducer. + using Reducer = Reducer; - using stats_t = typename TypeToVec2::Type; - // The simple Warp reducer. - using Reducer = Reducer; - - enum { SMEM_BYTES = 0 }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) - { - } + enum { SMEM_BYTES = 0 }; - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + template + inline __device__ Stats(Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, + void *smem) + : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {} - auto sum = Sum(); + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + auto sum = Sum(); - T m = Zeros::get(); - #pragma unroll - for( int it = 0; it < N; it++ ) { - m += elts[it]; - } - m = reducer_.allreduce(m, sum) * rn; - - T m2 = Zeros::get(); - #pragma unroll - for( int it = 0; it < N; it++ ) { - T diff = (elts[it] - m); - m2 += diff * diff; - } - m2 = reducer_.allreduce(m2, sum); + T m = Zeros::get(); +#pragma unroll + for (int it = 0; it < N; it++) { + m += elts[it]; + } + m = reducer_.allreduce(m, sum) * rn; - return {m, m2}; + T m2 = Zeros::get(); +#pragma unroll + for (int it = 0; it < N; it++) { + T diff = (elts[it] - m); + m2 += diff * diff; } + m2 = reducer_.allreduce(m2, sum); + + return {m, m2}; + } - Reducer reducer_; + Reducer reducer_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu index 373a217f4..68a93e148 100644 --- a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu +++ 
b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu @@ -1,16 +1,15 @@ -#include -#include -#include - +#include +#include #include #include #include #include - -#include -#include +#include #include +#include +#include + #include "dropout.cuh" #include "softmax.cuh" @@ -20,9 +19,8 @@ namespace multihead_attn { namespace fused_softmax { namespace additive_mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, - torch::Tensor const &input, - const half *pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const &input, const half *pad_mask, + float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; const int q_seq_len = input.size(1); @@ -40,12 +38,9 @@ std::vector fwd_cuda(bool is_training, int heads, auto act_options = input.options().requires_grad(false); auto mask_options = act_options.dtype(torch::kUInt8); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void *input_ptr = static_cast(input.data_ptr()); @@ -54,23 +49,19 @@ std::vector fwd_cuda(bool is_training, int heads, // Padded Softmax [[maybe_unused]] bool softmax_success = false; if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), k_seq_len, k_seq_len, - attn_batches * q_seq_len); + softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), k_seq_len, + k_seq_len, attn_batches * q_seq_len); } else { softmax_success = dispatch_additive_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), pad_mask, k_seq_len, - k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); + reinterpret_cast(softmax_results_ptr), reinterpret_cast(input_ptr), pad_mask, k_seq_len, + k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } if (is_training) { // use at:: function so that C++ version generates the same random mask as // python version - auto dropout_tuple = - at::_fused_dropout(softmax_results, 1.0f - dropout_prob); + auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob); dropout_results = std::get<0>(dropout_tuple); dropout_mask = std::get<1>(dropout_tuple); } @@ -80,8 +71,7 @@ std::vector fwd_cuda(bool is_training, int heads, return {dropout_results, dropout_mask, softmax_results}; } -torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, +torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &softmax_results, torch::Tensor const &dropout_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); @@ -99,15 +89,12 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, // Apply Dropout Mask and Scale by Dropout Probability 
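For reference, the fused dispatch_masked_scale_softmax_backward_stream call that follows combines two steps: scale the incoming gradient by the saved dropout mask and 1/(1-p), then apply the softmax Jacobian, dx_i = y_i * (g_i - sum_j g_j * y_j), against the saved softmax output y. A plain C++ row-wise sketch with illustrative names; it is not the kernel itself, just the math it implements.

#include <cstddef>
#include <cstdint>

void masked_scale_softmax_backward_row(float* dx, const float* dy, const float* y,
                                       const uint8_t* mask, float keep_prob, std::size_t n) {
  const float scale = 1.f / keep_prob;
  float dot = 0.f;
  for (std::size_t i = 0; i < n; ++i) {
    const float g = dy[i] * static_cast<float>(mask[i]) * scale;  // dropout backward
    dot += g * y[i];
    dx[i] = g;  // stash the rescaled gradient; finished on the second pass
  }
  for (std::size_t i = 0; i < n; ++i) {
    dx[i] = y[i] * (dx[i] - dot);  // softmax backward
  }
}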
// Softmax Grad dispatch_masked_scale_softmax_backward_stream( - static_cast(output_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, - attn_batches * q_seq_len, stream); + static_cast(output_grads.data_ptr()), static_cast(output_grads.data_ptr()), + reinterpret_cast(softmax_results.data_ptr()), static_cast(dropout_mask.data_ptr()), + 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream); // backward pass is completely in-place return output_grads; } -} // namespace additive_mask_softmax_dropout -} // namespace fused_softmax -} // namespace multihead_attn +} // namespace additive_mask_softmax_dropout +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/dropout.cuh b/apex/contrib/csrc/multihead_attn/dropout.cuh index 6f3922a6e..09e64e515 100644 --- a/apex/contrib/csrc/multihead_attn/dropout.cuh +++ b/apex/contrib/csrc/multihead_attn/dropout.cuh @@ -12,13 +12,11 @@ namespace { constexpr int UNROLL = 4; -} // namespace +} // namespace template -__global__ void -apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs, - uint8_t *mask, IndexType totalElements, accscalar_t p, - std::pair seeds) { +__global__ void apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs, uint8_t *mask, + IndexType totalElements, accscalar_t p, std::pair seeds) { accscalar_t pinv = accscalar_t(1) / p; IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -26,10 +24,8 @@ apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs, curand_init(seeds.first, idx, seeds.second, &state); IndexType rounded_size = - ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * - blockDim.x * gridDim.x * UNROLL; - for (IndexType linearIndex = idx; linearIndex < rounded_size; - linearIndex += gridDim.x * blockDim.x * UNROLL) { + ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL; + for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) { float4 rand = curand_uniform4(&state); scalar_t src[UNROLL]; rand.x = rand.x <= p; @@ -55,10 +51,8 @@ apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs, } template -__global__ void apex_dropout_add_kernel(scalar_t const *inputs, - scalar_t const *add_inputs, - scalar_t *outputs, uint8_t *mask, - IndexType totalElements, accscalar_t p, +__global__ void apex_dropout_add_kernel(scalar_t const *inputs, scalar_t const *add_inputs, scalar_t *outputs, + uint8_t *mask, IndexType totalElements, accscalar_t p, std::pair seeds) { accscalar_t pinv = accscalar_t(1) / p; IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -67,10 +61,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs, curand_init(seeds.first, idx, seeds.second, &state); IndexType rounded_size = - ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * - blockDim.x * gridDim.x * UNROLL; - for (IndexType linearIndex = idx; linearIndex < rounded_size; - linearIndex += gridDim.x * blockDim.x * UNROLL) { + ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL; + for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) { float4 rand = curand_uniform4(&state); scalar_t src[UNROLL]; scalar_t add_src[UNROLL]; @@ -89,8 +81,7 @@ __global__ void 
apex_dropout_add_kernel(scalar_t const *inputs, IndexType li = linearIndex + blockDim.x * gridDim.x * ii; if (li < totalElements) { accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv; - outputs[li] = - static_cast(static_cast(add_src[ii]) + int1); + outputs[li] = static_cast(static_cast(add_src[ii]) + int1); mask[li] = (uint8_t)(&rand.x)[ii]; } } @@ -99,15 +90,12 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs, } template -__global__ void apex_add_kernel(scalar_t const *inputs, - scalar_t const *add_inputs, scalar_t *outputs, +__global__ void apex_add_kernel(scalar_t const *inputs, scalar_t const *add_inputs, scalar_t *outputs, IndexType totalElements) { IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; IndexType rounded_size = - ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * - blockDim.x * gridDim.x * UNROLL; - for (IndexType linearIndex = idx; linearIndex < rounded_size; - linearIndex += gridDim.x * blockDim.x * UNROLL) { + ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL; + for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) { scalar_t src[UNROLL]; scalar_t add_src[UNROLL]; for (int ii = 0; ii < UNROLL; ii++) { @@ -128,16 +116,12 @@ __global__ void apex_add_kernel(scalar_t const *inputs, } template -__global__ void apex_masked_scale_kernel(scalar_t const *inputs, - scalar_t *outputs, uint8_t const *mask, - IndexType totalElements, - accscalar_t scale) { +__global__ void apex_masked_scale_kernel(scalar_t const *inputs, scalar_t *outputs, uint8_t const *mask, + IndexType totalElements, accscalar_t scale) { IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; IndexType rounded_size = - ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * - blockDim.x * gridDim.x * UNROLL; - for (IndexType linearIndex = idx; linearIndex < rounded_size; - linearIndex += gridDim.x * blockDim.x * UNROLL) { + ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL; + for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) { scalar_t src[UNROLL]; scalar_t msk[UNROLL]; for (int ii = 0; ii < UNROLL; ii++) { @@ -150,123 +134,87 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs, for (int ii = 0; ii < UNROLL; ii++) { IndexType li = linearIndex + blockDim.x * gridDim.x * ii; if (li < totalElements) { - outputs[li] = static_cast(src[ii]) * scale * - static_cast(msk[ii]); + outputs[li] = static_cast(src[ii]) * scale * static_cast(msk[ii]); } } } } template -void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs, - uint8_t *mask, IndexType totalElements, +void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs, uint8_t *mask, IndexType totalElements, accscalar_t p) { auto gen = at::cuda::detail::getDefaultCUDAGenerator(); int block_size = 256; dim3 dim_block(block_size); dim3 grid((totalElements + block_size - 1) / block_size); - unsigned int blocks_per_sm = - at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / - block_size; - grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties() - ->multiProcessorCount * - blocks_per_sm, - grid.x); + unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); // number of times random will 
be generated per thread, to offset philox // counter in the random state - int64_t counter_offset = - ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL; + int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL; std::pair rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); - rng_engine_inputs = - at::check_generator(gen)->philox_engine_inputs( - counter_offset); + rng_engine_inputs = at::check_generator(gen)->philox_engine_inputs(counter_offset); } - apex_fused_dropout_kernel - <<>>( - inputs, outputs, mask, totalElements, p, rng_engine_inputs); + apex_fused_dropout_kernel<<>>( + inputs, outputs, mask, totalElements, p, rng_engine_inputs); C10_CUDA_CHECK(cudaGetLastError()); } template -void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs, - scalar_t *outputs, uint8_t *mask, +void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs, scalar_t *outputs, uint8_t *mask, IndexType totalElements, accscalar_t p) { auto gen = at::cuda::detail::getDefaultCUDAGenerator(); int block_size = 256; dim3 dim_block(block_size); dim3 grid((totalElements + block_size - 1) / block_size); - unsigned int blocks_per_sm = - at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / - block_size; - grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties() - ->multiProcessorCount * - blocks_per_sm, - grid.x); + unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); // number of times random will be generated per thread, to offset philox // counter in the random state - int64_t counter_offset = - ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL; + int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL; std::pair rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); - rng_engine_inputs = - at::check_generator(gen)->philox_engine_inputs( - counter_offset); + rng_engine_inputs = at::check_generator(gen)->philox_engine_inputs(counter_offset); } - apex_dropout_add_kernel - <<>>( - inputs, add_inputs, outputs, mask, totalElements, p, - rng_engine_inputs); + apex_dropout_add_kernel<<>>( + inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs); C10_CUDA_CHECK(cudaGetLastError()); } template -void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs, - scalar_t *outputs, IndexType totalElements) { +void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs, scalar_t *outputs, IndexType totalElements) { int block_size = 256; dim3 dim_block(block_size); dim3 grid((totalElements + block_size - 1) / block_size); - unsigned int blocks_per_sm = - at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / - block_size; - grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties() - ->multiProcessorCount * - blocks_per_sm, - grid.x); + unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); apex_add_kernel - <<>>( - inputs, add_inputs, outputs, totalElements); + <<>>(inputs, add_inputs, outputs, totalElements); 
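The launchers above all follow the same bookkeeping: derive a grid from the element count, cap it at roughly one wave of resident blocks (multiProcessorCount * maxThreadsPerMultiProcessor / block_size), and, for the RNG variants, reserve UNROLL Philox increments per grid-stride iteration per thread. A host-side sketch of that calculation, assuming the device limits are passed in as plain ints; LaunchConfig and make_launch_config are illustrative names, not part of this patch.

#include <algorithm>
#include <cstdint>

struct LaunchConfig {
  unsigned int blocks;
  int64_t counter_offset;  // per-thread RNG increment to reserve
};

LaunchConfig make_launch_config(int64_t total_elements, int block_size, int sm_count,
                                int max_threads_per_sm, int unroll = 4) {
  unsigned int grid = static_cast<unsigned int>((total_elements + block_size - 1) / block_size);
  const unsigned int blocks_per_sm = max_threads_per_sm / block_size;
  grid = std::min(static_cast<unsigned int>(sm_count) * blocks_per_sm, grid);

  // Each grid-stride iteration draws one curand_uniform4 and handles UNROLL
  // elements, so the launcher reserves unroll counter increments per iteration.
  const int64_t iters =
      (total_elements - 1) / (static_cast<int64_t>(block_size) * grid * unroll) + 1;
  return {grid, iters * unroll};
}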
C10_CUDA_CHECK(cudaGetLastError()); } template -void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs, - uint8_t const *mask, IndexType totalElements, +void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs, uint8_t const *mask, IndexType totalElements, accscalar_t scale) { int block_size = 256; dim3 dim_block(block_size); dim3 grid((totalElements + block_size - 1) / block_size); - unsigned int blocks_per_sm = - at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / - block_size; - grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties() - ->multiProcessorCount * - blocks_per_sm, - grid.x); + unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); apex_masked_scale_kernel - <<>>( - inputs, outputs, mask, totalElements, scale); + <<>>(inputs, outputs, mask, totalElements, scale); C10_CUDA_CHECK(cudaGetLastError()); } diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu index d56b80768..d535500aa 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu @@ -1,16 +1,15 @@ -#include -#include -#include - +#include +#include #include #include #include #include - -#include -#include +#include #include +#include +#include + #include "dropout.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" @@ -19,14 +18,10 @@ namespace multihead_attn { namespace encdec { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, - torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs_q, + torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q, + torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, + const uint8_t *pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -57,26 +52,18 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, auto act_options = inputs_q.options().requires_grad(false); auto mask_options = act_options.dtype(torch::kUInt8); - torch::Tensor input_lin_q_results = - torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); - torch::Tensor input_lin_kv_results = - torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); + torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); + torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); + torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, 
k_seq_len}, act_options); + torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor outputs = torch::empty_like(inputs_q, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void *q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); - void *k_lin_results_ptr = - static_cast(input_lin_kv_results.data_ptr()); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_kv_results.data_ptr()) + head_dim); + void *k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); + void *v_lin_results_ptr = static_cast(static_cast(input_lin_kv_results.data_ptr()) + head_dim); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); @@ -88,105 +75,79 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Input Linear Q Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs_q.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), q_lin_results_ptr, - CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, static_cast(&alpha), + static_cast(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, - embed_dim, static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), k_lin_results_ptr, - CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, embed_dim, static_cast(&alpha), + static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim_kv, - batch_stride_kv, static_cast(q_lin_results_ptr), lead_dim_q, - batch_stride_q, beta, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, + static_cast(k_lin_results_ptr), lead_dim_kv, batch_stride_kv, + static_cast(q_lin_results_ptr), lead_dim_q, batch_stride_q, beta, + static_cast(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, 
attn_batches * q_seq_len); + softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), + reinterpret_cast(softmax_results_ptr), + k_seq_len, k_seq_len, attn_batches * q_seq_len); } else { if (use_time_mask) { softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); } else { softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, + k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } } assert(softmax_success); if (is_training) { apex_fused_dropout_cuda( - static_cast(softmax_results.data_ptr()), - static_cast(dropout_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0f - dropout_prob)); + static_cast(softmax_results.data_ptr()), static_cast(dropout_results.data_ptr()), + static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob)); } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim_kv, - batch_stride_kv, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, + static_cast(v_lin_results_ptr), lead_dim_kv, batch_stride_kv, + (is_training) ? 
static_cast(dropout_results.data_ptr()) + : static_cast(softmax_results.data_ptr()), + k_seq_len, k_seq_len * q_seq_len, beta, static_cast(matmul2_results.data_ptr()), + head_dim * attn_batches, head_dim, attn_batches); // Output Linear TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO1_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_lin_q_results, - input_lin_kv_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - outputs}; + return {input_lin_q_results, input_lin_kv_results, softmax_results, dropout_results, + dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_q_results, torch::Tensor const &input_lin_kv_results, + torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, + torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, + torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, + float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -222,20 +183,15 @@ std::vector bwd_cuda( at::Tensor output_lin_grads = torch::empty_like(matmul2_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); - at::Tensor input_lin_kv_output_grads = - torch::empty_like(input_lin_kv_results); + at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); - auto v_lin_results_ptr = - static_cast(input_lin_kv_results.data_ptr()) + head_dim; + auto v_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()) + head_dim; - auto q_lin_grads_ptr = - static_cast(input_lin_q_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_kv_output_grads.data_ptr()); - auto v_lin_grads_ptr = - static_cast(input_lin_kv_output_grads.data_ptr()) + head_dim; + auto q_lin_grads_ptr = static_cast(input_lin_q_output_grads.data_ptr()); + auto k_lin_grads_ptr = 
static_cast(input_lin_kv_output_grads.data_ptr()); + auto v_lin_grads_ptr = static_cast(input_lin_kv_output_grads.data_ptr()) + head_dim; char a_layout_n{'n'}; char a_layout_t{'t'}; @@ -245,124 +201,93 @@ std::vector bwd_cuda( TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_lin_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK( + cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q, static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, + embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim_kv, - batch_stride_kv, static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, + static_cast(v_lin_results_ptr), lead_dim_kv, batch_stride_kv, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta, + static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim_kv, batch_stride_kv, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, + static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, + v_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches); // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0 / (1.0 - dropout_prob))); + static_cast(matmul2_grads.data_ptr()), 
static_cast(matmul2_grads.data_ptr()), + static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; softmax_success = dispatch_softmax_backward( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), k_seq_len, - k_seq_len, attn_batches * q_seq_len); + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + reinterpret_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len); assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim_kv, - batch_stride_kv, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim_q, batch_stride_q, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim_kv, + batch_stride_kv, static_cast(matmul2_grads.data_ptr()), k_seq_len, + k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim_q, batch_stride_q, attn_batches); // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim_q, - batch_stride_q, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim_kv, batch_stride_kv, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim_q, + batch_stride_q, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, k_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches); // Input Linear Q Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_q_dim, static_cast(&beta), - static_cast(input_q_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, static_cast(&alpha), + static_cast(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, static_cast(&beta), + static_cast(input_q_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear Q Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q, - static_cast(&alpha), - static_cast(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, - static_cast(&beta), - static_cast(input_weight_q_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q, + static_cast(&alpha), static_cast(inputs_q.data_ptr()), + CUDA_R_16F, embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, + output_lin_q_dim, static_cast(&beta), + static_cast(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, - output_lin_kv_dim, static_cast(&alpha), - 
static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, - output_lin_kv_dim, static_cast(&beta), - static_cast(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, output_lin_kv_dim, static_cast(&alpha), + static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, static_cast(&beta), + static_cast(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, - batches_kv, static_cast(&alpha), - static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, - static_cast(&beta), - static_cast(input_weight_kv_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, batches_kv, + static_cast(&alpha), static_cast(inputs_kv.data_ptr()), + CUDA_R_16F, embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, + output_lin_kv_dim, static_cast(&beta), + static_cast(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_q_grads, input_kv_grads, input_weight_q_grads, - input_weight_kv_grads, output_weight_grads}; + return {input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads}; } -} // end namespace cublas_gemmex -} // end namespace encdec -} // end namespace multihead_attn +} // end namespace cublas_gemmex +} // end namespace encdec +} // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu index 10c9b8cef..c26aaab9d 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu @@ -1,16 +1,15 @@ -#include -#include -#include - +#include +#include #include #include #include #include - -#include -#include +#include #include +#include +#include + #include "dropout.cuh" #include "layer_norm.cuh" #include "softmax.cuh" @@ -20,16 +19,11 @@ namespace multihead_attn { namespace encdec_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs_q, + torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights_q, + torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, + const uint8_t *pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -66,28 +60,20 @@ 
std::vector fwd_cuda(bool use_time_mask, bool is_training, torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options); torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options); - torch::Tensor input_lin_q_results = - torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); - torch::Tensor input_lin_kv_results = - torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); + torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); + torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); + torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options); torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options); torch::Tensor outputs = torch::empty_like(inputs_q, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void *q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); - void *k_lin_results_ptr = - static_cast(input_lin_kv_results.data_ptr()); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_kv_results.data_ptr()) + head_dim); + void *k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); + void *v_lin_results_ptr = static_cast(static_cast(input_lin_kv_results.data_ptr()) + head_dim); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); @@ -99,96 +85,74 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Layer Norm HostApplyLayerNorm( - static_cast(lyr_nrm_results.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), - static_cast(inputs_q.data_ptr()), - static_cast(batches_q), // n1 - static_cast(embed_dim), // n2 + static_cast(lyr_nrm_results.data_ptr()), static_cast(lyr_nrm_mean.data_ptr()), + static_cast(lyr_nrm_invvar.data_ptr()), static_cast(inputs_q.data_ptr()), + static_cast(batches_q), // n1 + static_cast(embed_dim), // n2 1.0e-5, static_cast(lyr_nrm_gamma_weights.data_ptr()), static_cast(lyr_nrm_beta_weights.data_ptr())); // Input Linear Q Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), CUDA_R_16F, - embed_dim, + handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, static_cast(&alpha), + static_cast(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim, // static_cast(inputs_q.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), q_lin_results_ptr, - 
CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, - embed_dim, static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), k_lin_results_ptr, - CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, embed_dim, static_cast(&alpha), + static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim_kv, - batch_stride_kv, static_cast(q_lin_results_ptr), lead_dim_q, - batch_stride_q, beta, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, + static_cast(k_lin_results_ptr), lead_dim_kv, batch_stride_kv, + static_cast(q_lin_results_ptr), lead_dim_q, batch_stride_q, beta, + static_cast(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); + softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), + reinterpret_cast(softmax_results_ptr), + k_seq_len, k_seq_len, attn_batches * q_seq_len); } else { if (use_time_mask) { softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); } else { softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, + k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } } assert(softmax_success); if (is_training) { apex_fused_dropout_cuda( - static_cast(softmax_results.data_ptr()), - static_cast(dropout_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0f - dropout_prob)); + static_cast(softmax_results.data_ptr()), static_cast(dropout_results.data_ptr()), + static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob)); } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim_kv, - batch_stride_kv, - (is_training) ? 
static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - // static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, + static_cast(v_lin_results_ptr), lead_dim_kv, batch_stride_kv, + (is_training) ? static_cast(dropout_results.data_ptr()) + : static_cast(softmax_results.data_ptr()), + // static_cast(dropout_results.data_ptr()), + k_seq_len, k_seq_len * q_seq_len, beta, static_cast(matmul2_results.data_ptr()), + head_dim * attn_batches, head_dim, attn_batches); // Output Linear TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + static_cast(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO1_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); @@ -196,45 +160,31 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, if (is_training) { apex_dropout_add_cuda( static_cast(output_lin_results.data_ptr()), - static_cast(inputs_q.data_ptr()), - static_cast(outputs.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), total_tokens_q, - (1.0f - dropout_prob)); + static_cast(inputs_q.data_ptr()), static_cast(outputs.data_ptr()), + static_cast(dropout_add_mask.data_ptr()), total_tokens_q, (1.0f - dropout_prob)); } else { - apex_add_cuda( - static_cast(output_lin_results.data_ptr()), - static_cast(inputs_q.data_ptr()), - static_cast(outputs.data_ptr()), total_tokens_q); + apex_add_cuda(static_cast(output_lin_results.data_ptr()), + static_cast(inputs_q.data_ptr()), + static_cast(outputs.data_ptr()), total_tokens_q); } TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - input_lin_q_results, - input_lin_kv_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - dropout_add_mask, - outputs}; + return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_q_results, + input_lin_kv_results, softmax_results, dropout_results, dropout_mask, + matmul2_results, dropout_add_mask, outputs}; } -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, - torch::Tensor const &dropout_add_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, torch::Tensor 
const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_q_results, torch::Tensor const &input_lin_kv_results, + torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, + torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, + torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights_q, + torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, + torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, + float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -274,21 +224,16 @@ std::vector bwd_cuda( at::Tensor output_lin_grads = torch::empty_like(matmul2_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); - at::Tensor input_lin_kv_output_grads = - torch::empty_like(input_lin_kv_results); + at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); at::Tensor input_lin_q_grads = torch::empty_like(inputs_q); auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); - auto v_lin_results_ptr = - static_cast(input_lin_kv_results.data_ptr()) + head_dim; + auto v_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()) + head_dim; - auto q_lin_grads_ptr = - static_cast(input_lin_q_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_kv_output_grads.data_ptr()); - auto v_lin_grads_ptr = - static_cast(input_lin_kv_output_grads.data_ptr()) + head_dim; + auto q_lin_grads_ptr = static_cast(input_lin_q_output_grads.data_ptr()); + auto k_lin_grads_ptr = static_cast(input_lin_kv_output_grads.data_ptr()); + auto v_lin_grads_ptr = static_cast(input_lin_kv_output_grads.data_ptr()) + head_dim; char a_layout_n{'n'}; char a_layout_t{'t'}; @@ -299,147 +244,111 @@ std::vector bwd_cuda( // Dropout Add Backward apex_masked_scale_cuda( - static_cast(output_grads.data_ptr()), - static_cast(dropout_add_grads.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), total_tokens_q, - (1.0 / (1.0 - dropout_prob))); + static_cast(output_grads.data_ptr()), static_cast(dropout_add_grads.data_ptr()), + static_cast(dropout_add_mask.data_ptr()), total_tokens_q, (1.0 / (1.0 - dropout_prob))); // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(dropout_add_grads.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_lin_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q, - static_cast(&alpha), 
- static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(dropout_add_grads.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK( + cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q, static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, + embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim_kv, - batch_stride_kv, static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, + static_cast(v_lin_results_ptr), lead_dim_kv, batch_stride_kv, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta, + static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim_kv, batch_stride_kv, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, + static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, + v_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches); // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0 / (1.0 - dropout_prob))); + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; softmax_success = dispatch_softmax_backward( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), k_seq_len, - k_seq_len, attn_batches * q_seq_len); + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + reinterpret_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len); assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim_kv, - batch_stride_kv, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim_q, batch_stride_q, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim_kv, + batch_stride_kv, static_cast(matmul2_grads.data_ptr()), k_seq_len, + k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim_q, batch_stride_q, attn_batches); // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, 
k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim_q, - batch_stride_q, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim_kv, batch_stride_kv, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim_q, + batch_stride_q, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, k_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches); // Input Linear Q Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_q_dim, static_cast(&beta), + handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, static_cast(&alpha), + static_cast(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, static_cast(&beta), // static_cast(input_q_grads.data_ptr()), - static_cast(input_lin_q_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + static_cast(input_lin_q_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear Q Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q, - static_cast(&alpha), - static_cast(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, - static_cast(&beta), - static_cast(input_weight_q_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q, + static_cast(&alpha), static_cast(inputs_q.data_ptr()), + CUDA_R_16F, embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, + output_lin_q_dim, static_cast(&beta), + static_cast(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, - output_lin_kv_dim, static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, - output_lin_kv_dim, static_cast(&beta), - static_cast(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, output_lin_kv_dim, static_cast(&alpha), + static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, static_cast(&beta), + static_cast(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, - batches_kv, static_cast(&alpha), - static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, - static_cast(&beta), - static_cast(input_weight_kv_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, batches_kv, + static_cast(&alpha), static_cast(inputs_kv.data_ptr()), + CUDA_R_16F, embed_dim, 
static_cast(k_lin_grads_ptr), CUDA_R_16F, + output_lin_kv_dim, static_cast(&beta), + static_cast(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Fused Layer Norm Bwd with Residual Add HostLayerNormGradient( - static_cast(input_lin_q_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), inputs_q, - static_cast(batches_q), // n1 - static_cast(embed_dim), // n2 + static_cast(input_lin_q_grads.data_ptr()), static_cast(output_grads.data_ptr()), + static_cast(lyr_nrm_mean.data_ptr()), static_cast(lyr_nrm_invvar.data_ptr()), + inputs_q, + static_cast(batches_q), // n1 + static_cast(embed_dim), // n2 static_cast(lyr_nrm_gamma_weights.data_ptr()), - static_cast(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, - static_cast(input_q_grads.data_ptr()), - static_cast(lyr_nrm_gamma_grads.data_ptr()), - static_cast(lyr_nrm_beta_grads.data_ptr())); + static_cast(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, static_cast(input_q_grads.data_ptr()), + static_cast(lyr_nrm_gamma_grads.data_ptr()), static_cast(lyr_nrm_beta_grads.data_ptr())); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads, - lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads, - output_weight_grads}; + return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads, + input_weight_q_grads, input_weight_kv_grads, output_weight_grads}; } -} // end namespace cublas_gemmex -} // end namespace encdec_norm_add -} // end namespace multihead_attn +} // end namespace cublas_gemmex +} // end namespace encdec_norm_add +} // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/layer_norm.cuh b/apex/contrib/csrc/multihead_attn/layer_norm.cuh index 16c1eeef4..c41f04306 100644 --- a/apex/contrib/csrc/multihead_attn/layer_norm.cuh +++ b/apex/contrib/csrc/multihead_attn/layer_norm.cuh @@ -1,7 +1,8 @@ #pragma once +#include #include #include -#include + #include namespace { @@ -16,8 +17,7 @@ __device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) { } template -__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, - U &mu, U &sigma2, U &count) { +__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, U &mu, U &sigma2, U &count) { U delta = muB - mu; U nA = count; U nB = countB; @@ -35,9 +35,8 @@ __device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, } template -__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, - const int n2, const int i1, U &mu, U &sigma2, - U *buf) { +__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, const int n2, const int i1, U &mu, + U &sigma2, U *buf) { // Assumptions: // 1) blockDim.x == warpSize // 2) Tensor is contiguous @@ -80,8 +79,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, U *ibuf = (U *)(ubuf + blockDim.y); for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int wrt_y = threadIdx.y - offset; ubuf[2 * wrt_y] = mu; ubuf[2 * wrt_y + 1] = sigma2; @@ -114,8 +112,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, } template <> -__device__ void 
cuWelfordMuSigma2(const at::Half *__restrict__ vals, - const int n1, const int n2, const int i1, +__device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals, const int n1, const int n2, const int i1, float &mu, float &sigma2, float *buf) { // Assumptions: // 1) blockDim.x == warpSize @@ -171,8 +168,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals, float *ibuf = (float *)(ubuf + blockDim.y); for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int wrt_y = threadIdx.y - offset; ubuf[2 * wrt_y] = mu; ubuf[2 * wrt_y + 1] = sigma2; @@ -204,9 +200,18 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals, } } -template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } -template <> __device__ float rsqrt(float v) { return rsqrtf(v); } -template <> __device__ double rsqrt(double v) { return rsqrt(v); } +template +__device__ U rsqrt(U v) { + return U(1) / sqrt(v); +} +template <> +__device__ float rsqrt(float v) { + return rsqrtf(v); +} +template <> +__device__ double rsqrt(double v) { + return rsqrt(v); +} // This is the un-specialized struct. Note that we prevent instantiation of // this struct by putting an undefined symbol in the function body so it won't @@ -223,15 +228,18 @@ template <> __device__ double rsqrt(double v) { return rsqrt(v); } // } // }; // https://github.com/NVIDIA/apex/issues/246 -template struct SharedMemory; -template <> struct SharedMemory { +template +struct SharedMemory; +template <> +struct SharedMemory { __device__ float *getPointer() { extern __shared__ float s_float[]; return s_float; } }; -template <> struct SharedMemory { +template <> +struct SharedMemory { __device__ double *getPointer() { extern __shared__ double s_double[]; return s_double; @@ -239,11 +247,9 @@ template <> struct SharedMemory { }; template -__global__ void -cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean, - U *__restrict__ invvar, const T *__restrict__ vals, - const int n1, const int n2, const U epsilon, - const T *__restrict__ gamma, const T *__restrict__ beta) { +__global__ void cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean, U *__restrict__ invvar, + const T *__restrict__ vals, const int n1, const int n2, const U epsilon, + const T *__restrict__ gamma, const T *__restrict__ beta) { // Assumptions: // 1) blockDim.x == warpSize // 2) Tensors are contiguous @@ -277,11 +283,10 @@ cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean, } template -__device__ void cuLoadWriteStridedInputs( - const int i1_block, const int thr_load_row_off, const int thr_load_col_off, - const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, - const T *input, const T *dout, const int i1_end, const int n2, - const U *__restrict__ mean, const U *__restrict__ invvar) { +__device__ void cuLoadWriteStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, + const T *input, const T *dout, const int i1_end, const int n2, + const U *__restrict__ mean, const U *__restrict__ invvar) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { U curr_mean = mean[i1]; @@ -294,8 +299,7 @@ __device__ void cuLoadWriteStridedInputs( U curr_input = static_cast(input[load_idx]); U curr_dout = 
static_cast(dout[load_idx]); warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = - curr_dout * (curr_input - curr_mean) * curr_invvar; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; } else { warp_buf1[write_idx] = U(0); warp_buf2[write_idx] = U(0); @@ -311,11 +315,10 @@ __device__ void cuLoadWriteStridedInputs( } template -__device__ void cuLoadAddStridedInputs( - const int i1_block, const int thr_load_row_off, const int thr_load_col_off, - const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, - const T *input, const T *dout, const int i1_end, const int n2, - const U *__restrict__ mean, const U *__restrict__ invvar) { +__device__ void cuLoadAddStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, + const T *input, const T *dout, const int i1_end, const int n2, + const U *__restrict__ mean, const U *__restrict__ invvar) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { U curr_mean = mean[i1]; @@ -328,46 +331,38 @@ __device__ void cuLoadAddStridedInputs( U curr_input = static_cast(input[load_idx]); U curr_dout = static_cast(dout[load_idx]); warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += - curr_dout * (curr_input - curr_mean) * curr_invvar; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; } } } } template -__global__ void cuComputePartGradGammaBeta( - const T *__restrict__ dout, const T *__restrict__ input, const int n1, - const int n2, const U *__restrict__ mean, const U *__restrict__ invvar, - U epsilon, U *part_grad_gamma, U *part_grad_beta) { - const int numsegs_n1 = - (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); +__global__ void cuComputePartGradGammaBeta(const T *__restrict__ dout, const T *__restrict__ input, const int n1, + const int n2, const U *__restrict__ mean, const U *__restrict__ invvar, + U epsilon, U *part_grad_gamma, U *part_grad_beta) { + const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; - const int i1_beg_plus_one = - (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1; const int row_stride = blockDim.x + 1; const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); - const int thr_load_row_off = - (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; SharedMemory shared; - U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * - // blockDim.y + (blockDim.y - - // 1)*(blockDim.x/blockDim.y) elements + U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * + // blockDim.y + (blockDim.y - + // 1)*(blockDim.x/blockDim.y) elements U *warp_buf1 = (U *)buf; U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, - row_stride, warp_buf1, warp_buf2, input, dout, - i1_end, n2, mean, invvar); - for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; - i1_block += blockDim.y * blockDim.y) { - cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, - row_stride, warp_buf1, warp_buf2, input, dout, - i1_end, n2, mean, invvar); + cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, + dout, i1_end, n2, mean, invvar); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, + input, dout, i1_end, n2, mean, invvar); } __syncthreads(); // inter-warp reductions @@ -407,10 +402,8 @@ __global__ void cuComputePartGradGammaBeta( } template -__global__ void -cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta, - const int part_size, const int n1, const int n2, - T *grad_gamma, T *grad_beta) { +__global__ void cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta, const int part_size, + const int n1, const int n2, T *grad_gamma, T *grad_beta) { // sum partial gradients for gamma and beta SharedMemory shared; U *buf = shared.getPointer(); @@ -420,12 +413,9 @@ cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta, int num_warp_reductions = part_size / blockDim.y; U sum_gamma = U(0); U sum_beta = U(0); - const U *part_grad_gamma_ptr = - part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; - const U *part_grad_beta_ptr = - part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; - for (int warp_offset = 0; warp_offset < num_warp_reductions; - ++warp_offset) { + const U *part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U *part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; sum_beta += part_grad_beta_ptr[warp_offset * n2]; } @@ -455,13 +445,10 @@ cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta, } } - template -__global__ void -cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid, - const T *__restrict__ input, const int n1, const int n2, - const U *__restrict__ mean, const U *__restrict__ invvar, - U epsilon, const T *gamma, T *grad_input) { +__global__ void 
cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid, + const T *__restrict__ input, const int n1, const int n2, const U *__restrict__ mean, + const U *__restrict__ invvar, U epsilon, const T *gamma, T *grad_input) { for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); @@ -479,16 +466,14 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid, const U c_h = static_cast(k_input[l + k]); const U c_loss = static_cast(k_dout[l + k]); sum_loss1 += c_loss * static_cast(gamma[l + k]); - sum_loss2 += - c_loss * static_cast(gamma[l + k]) * (c_h - c_mean) * c_invvar; + sum_loss2 += c_loss * static_cast(gamma[l + k]) * (c_h - c_mean) * c_invvar; } } for (; l < n2; ++l) { const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); sum_loss1 += c_loss * static_cast(gamma[l]); - sum_loss2 += - c_loss * static_cast(gamma[l]) * (c_h - c_mean) * c_invvar; + sum_loss2 += c_loss * static_cast(gamma[l]) * (c_h - c_mean) * c_invvar; } } else { int l = 4 * thrx; @@ -573,24 +558,19 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid, } template -void HostApplyLayerNorm(T *output, U *mean, U *invvar, const T *input, int n1, - int n2, double epsilon, const T *gamma, const T *beta) { +void HostApplyLayerNorm(T *output, U *mean, U *invvar, const T *input, int n1, int n2, double epsilon, const T *gamma, + const T *beta) { auto stream = at::cuda::getCurrentCUDAStream().stream(); const dim3 threads(32, 4, 1); - const uint64_t maxGridY = - at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; - cuApplyLayerNorm<<>>( - output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); + int nshared = threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyLayerNorm<<>>(output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); } template -void HostLayerNormGradient(const T *dout, const T *dout_resid, const U *mean, - const U *invvar, const at::Tensor &input, int n1, - int n2, const T *gamma, const T *beta, - double epsilon, T *grad_input, T *grad_gamma, +void HostLayerNormGradient(const T *dout, const T *dout_resid, const U *mean, const U *invvar, const at::Tensor &input, + int n1, int n2, const T *gamma, const T *beta, double epsilon, T *grad_input, T *grad_gamma, T *grad_beta) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -599,38 +579,31 @@ void HostLayerNormGradient(const T *dout, const T *dout_resid, const U *mean, const int part_size = 16; const dim3 threads2(32, 4, 1); const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); - const int nshared2_a = - 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; at::Tensor part_grad_gamma = at::empty( - {part_size, n2}, - input.options().dtype(input.scalar_type() == at::ScalarType::Half - ? at::ScalarType::Float - : input.scalar_type())); + {part_size, n2}, input.options().dtype(input.scalar_type() == at::ScalarType::Half ? 
at::ScalarType::Float + : input.scalar_type())); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); cuComputePartGradGammaBeta<<>>( - dout, static_cast(input.data_ptr()), n1, n2, mean, invvar, - U(epsilon), static_cast(part_grad_gamma.data_ptr()), - static_cast(part_grad_beta.data_ptr())); + dout, static_cast(input.data_ptr()), n1, n2, mean, invvar, U(epsilon), + static_cast(part_grad_gamma.data_ptr()), static_cast(part_grad_beta.data_ptr())); const dim3 threads3(32, 8, 1); const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta<<>>( - static_cast(part_grad_gamma.data_ptr()), - static_cast(part_grad_beta.data_ptr()), part_size, n1, n2, - grad_gamma, grad_beta); + cuComputeGradGammaBeta<<>>(static_cast(part_grad_gamma.data_ptr()), + static_cast(part_grad_beta.data_ptr()), + part_size, n1, n2, grad_gamma, grad_beta); } // compute grad_input - const uint64_t maxGridY = - at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); const dim3 threads1(32, 4, 1); int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, dout_resid, static_cast(input.data_ptr()), n1, n2, mean, - invvar, U(epsilon), gamma, grad_input); + cuComputeGradInput<<>>(dout, dout_resid, static_cast(input.data_ptr()), n1, + n2, mean, invvar, U(epsilon), gamma, grad_input); } -} // namespace +} // namespace diff --git a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu index f9a031d53..508d82aaf 100644 --- a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu @@ -1,16 +1,15 @@ -#include -#include -#include - +#include +#include #include #include #include #include - -#include -#include +#include #include +#include +#include + #include "dropout.cuh" #include "softmax.cuh" @@ -18,9 +17,7 @@ namespace multihead_attn { namespace fused_softmax { namespace mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, - torch::Tensor const &input, - const uint8_t *pad_mask, +std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const &input, const uint8_t *pad_mask, float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; @@ -39,12 +36,9 @@ std::vector fwd_cuda(bool is_training, int heads, auto act_options = input.options().requires_grad(false); auto mask_options = act_options.dtype(torch::kUInt8); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void *input_ptr = static_cast(input.data_ptr()); @@ -53,23 +47,19 @@ std::vector fwd_cuda(bool is_training, int heads, // Padded Softmax bool softmax_success = false; if (pad_mask 
== nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), k_seq_len, k_seq_len, - attn_batches * q_seq_len); + softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), k_seq_len, + k_seq_len, attn_batches * q_seq_len); } else { softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), pad_mask, k_seq_len, - k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); + reinterpret_cast(softmax_results_ptr), reinterpret_cast(input_ptr), pad_mask, k_seq_len, + k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } if (is_training) { // use at:: function so that C++ version generates the same random mask as // python version - auto dropout_tuple = - at::_fused_dropout(softmax_results, 1.0f - dropout_prob); + auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob); dropout_results = std::get<0>(dropout_tuple); dropout_mask = std::get<1>(dropout_tuple); } @@ -79,10 +69,8 @@ std::vector fwd_cuda(bool is_training, int heads, return {dropout_results, dropout_mask, softmax_results}; } -torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, - const uint8_t *padding_mask, float dropout_prob) { +torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &softmax_results, + torch::Tensor const &dropout_mask, const uint8_t *padding_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); const int k_seq_len = q_seq_len; @@ -100,25 +88,20 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, // Softmax Grad if (padding_mask == nullptr) { dispatch_masked_scale_softmax_backward_stream( - static_cast(output_grads.data_ptr()), - static_cast(output_grads.data_ptr()), + static_cast(output_grads.data_ptr()), static_cast(output_grads.data_ptr()), reinterpret_cast(softmax_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, + static_cast(dropout_mask.data_ptr()), 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream); } else { - dispatch_masked_scale_softmax_backward_masked_out_stream( - static_cast(output_grads.data_ptr()), - static_cast(output_grads.data_ptr()), + dispatch_masked_scale_softmax_backward_masked_out_stream( + static_cast(output_grads.data_ptr()), static_cast(output_grads.data_ptr()), reinterpret_cast(softmax_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), - static_cast(padding_mask), 1.0 / (1.0 - dropout_prob), - k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream); + static_cast(dropout_mask.data_ptr()), static_cast(padding_mask), + 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream); } // backward pass is completely in-place return output_grads; } -} // namespace mask_softmax_dropout -} // namespace fused_softmax -} // namespace multihead_attn +} // namespace mask_softmax_dropout +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp index 98ef4b1f5..2e5d168b7 100644 --- a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp +++ 
b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp @@ -1,189 +1,138 @@ -#include - #include #include +#include -#define CHECK_CUDA(x) \ - TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) namespace multihead_attn { namespace fused_softmax { namespace additive_mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, - torch::Tensor const &input, - const half *pad_mask, float dropout_prob); +std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const &input, const half *pad_mask, + float dropout_prob); -torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, +torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &softmax_results, torch::Tensor const &dropout_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool is_training, int heads, - torch::Tensor const &input, - torch::Tensor const &pad_mask, - float dropout_prob) { +std::vector fwd(bool use_mask, bool is_training, int heads, torch::Tensor const &input, + torch::Tensor const &pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Half, - "Only BYTE is supported"); + TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Half, "Only BYTE is supported"); } - return fwd_cuda(is_training, heads, input, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, + return fwd_cuda(is_training, heads, input, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, +torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads, torch::Tensor const &softmax_results, torch::Tensor const &dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); // TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, // "Only BYTE is supported"); - return bwd_cuda(heads, output_grads, softmax_results, dropout_mask, - dropout_prob); + return bwd_cuda(heads, output_grads, softmax_results, dropout_mask, dropout_prob); } -} // namespace additive_mask_softmax_dropout +} // namespace additive_mask_softmax_dropout namespace mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, - torch::Tensor const &input, - const uint8_t *pad_mask, +std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const &input, const uint8_t *pad_mask, float dropout_prob); -torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, - const uint8_t *padding_mask, float dropout_prob); +torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &softmax_results, + torch::Tensor const &dropout_mask, const uint8_t *padding_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool is_training, int heads, - torch::Tensor const &input, - torch::Tensor const &pad_mask, - float dropout_prob) { +std::vector fwd(bool use_mask, bool is_training, int heads, torch::Tensor const &input, + torch::Tensor const &pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); + TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } - return fwd_cuda(is_training, heads, input, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, + return fwd_cuda(is_training, heads, input, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, - torch::Tensor const &padding_mask, float dropout_prob) { +torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads, torch::Tensor const &softmax_results, + torch::Tensor const &dropout_mask, torch::Tensor const &padding_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); // TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, // "Only BYTE is supported"); return bwd_cuda(heads, output_grads, softmax_results, dropout_mask, - use_mask - ? static_cast(padding_mask.data_ptr()) - : nullptr, - dropout_prob); + use_mask ? static_cast(padding_mask.data_ptr()) : nullptr, dropout_prob); } -} // end namespace mask_softmax_dropout -} // end namespace fused_softmax +} // end namespace mask_softmax_dropout +} // end namespace fused_softmax namespace encdec { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, - torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs_q, + torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q, + torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, + const uint8_t *pad_mask, float dropout_prob); +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_q_results, torch::Tensor const &input_lin_kv_results, + torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, + torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, + torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, float dropout_prob); -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &pad_mask, - float dropout_prob) { + +std::vector fwd(bool use_mask, bool 
use_time_mask, bool is_training, int heads, + torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, + torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, + torch::Tensor const &output_weights, torch::Tensor const &pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights_q.dim() == 2, "expected 2D tensor"); TORCH_CHECK(input_weights_kv.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); + TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } - return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, - input_weights_q, input_weights_kv, output_weights, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); + return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, + output_weights, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { +std::vector bwd(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_q_results, torch::Tensor const &input_lin_kv_results, + torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, + torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, + torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, + float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -197,76 +146,52 @@ bwd(int heads, torch::Tensor const &output_grads, TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_lin_q_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_lin_kv_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_q_results, input_lin_kv_results, - inputs_q, inputs_kv, input_weights_q, input_weights_kv, - output_weights, dropout_mask, dropout_prob); + TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_lin_q_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_lin_kv_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + 
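// Illustrative sketch, not from the APEX sources: the hunks in this file repeat one
// validation idiom before dispatching to CUDA -- activation tensors must be rank-3 FP16
// and dropout masks rank-3 uint8. Collapsed into hypothetical helpers it looks like this
// (`check_half_3d` / `check_byte_3d` are assumed names, not part of the patch):
#include <torch/extension.h>

static void check_half_3d(const torch::Tensor &t, const char *name) {
  TORCH_CHECK(t.dim() == 3, name, ": expected 3D tensor");
  TORCH_CHECK(t.scalar_type() == at::ScalarType::Half, name, ": only HALF is supported");
}

static void check_byte_3d(const torch::Tensor &t, const char *name) {
  TORCH_CHECK(t.dim() == 3, name, ": expected 3D tensor");
  TORCH_CHECK(t.scalar_type() == at::ScalarType::Byte, name, ": only BYTE is supported");
}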
TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); + + return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, + input_lin_kv_results, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, + dropout_mask, dropout_prob); } -} // end namespace cublas_gemmex -} // end namespace encdec +} // end namespace cublas_gemmex +} // end namespace encdec namespace encdec_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs_q, + torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights_q, + torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, + const uint8_t *pad_mask, float dropout_prob); + +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_q_results, torch::Tensor const &input_lin_kv_results, + torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, + torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, + torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights_q, + torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, + torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, float dropout_prob); -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, - torch::Tensor const &dropout_add_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const 
&pad_mask, - float dropout_prob) { +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, + torch::Tensor const &lyr_nrm_gamma_weights, torch::Tensor const &lyr_nrm_beta_weights, + torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, + torch::Tensor const &output_weights, torch::Tensor const &pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); @@ -275,48 +200,34 @@ fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, TORCH_CHECK(input_weights_kv.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); + TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } - return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, - lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, - input_weights_kv, output_weights, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); + return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, + input_weights_q, input_weights_kv, output_weights, + use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, - torch::Tensor const &dropout_add_mask, float dropout_prob) { +std::vector bwd(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_q_results, torch::Tensor const &input_lin_kv_results, + torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, + torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, + torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights_q, + torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, + torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, + float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -336,107 +247,72 @@ bwd(int heads, torch::Tensor const &output_grads, TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_add_mask.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_lin_q_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_lin_kv_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_mean.scalar_type() == at::ScalarType::Float, - "Only FLOAT is supported"); - TORCH_CHECK(lyr_nrm_invvar.scalar_type() == at::ScalarType::Float, - "Only FLOAT is supported"); - TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - 
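// Illustrative sketch, not from the APEX sources: each fwd_cuda above takes
// `const uint8_t *pad_mask`, so the optional-mask argument handed to it is presumably
// built from the torch::Tensor as below (`launch_forward` is an assumed name):
#include <cstdint>
#include <torch/extension.h>

void launch_forward(bool use_mask, const torch::Tensor &pad_mask /*, ... */) {
  // A typed pointer when a mask is supplied, nullptr otherwise.
  const uint8_t *mask_ptr =
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr;
  (void)mask_ptr;  // forwarded to the CUDA implementation in the real code
}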
TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); - TORCH_CHECK(dropout_add_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_q_results, input_lin_kv_results, - lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, - inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, - input_weights_q, input_weights_kv, output_weights, + TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_lin_q_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_lin_kv_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_mean.scalar_type() == at::ScalarType::Float, "Only FLOAT is supported"); + TORCH_CHECK(lyr_nrm_invvar.scalar_type() == at::ScalarType::Float, "Only FLOAT is supported"); + TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); + TORCH_CHECK(dropout_add_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); + + return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, + input_lin_kv_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, inputs_kv, + lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, dropout_mask, dropout_add_mask, dropout_prob); } -} // end namespace cublas_gemmex -} // end namespace encdec_norm_add +} // end namespace cublas_gemmex +} // end namespace encdec_norm_add namespace self { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob); +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + const uint8_t *pad_mask, float dropout_prob); + +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + 
torch::Tensor const &dropout_mask, float dropout_prob); -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &pad_mask, - float dropout_prob) { +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + torch::Tensor const &inputs, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); + TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } - return fwd_cuda( - use_time_mask, is_training, heads, inputs, input_weights, output_weights, - use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, - dropout_prob); + return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights, + use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { +std::vector bwd(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + torch::Tensor const &dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -447,89 +323,64 @@ bwd(int heads, torch::Tensor const &output_grads, TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_results, inputs, input_weights, - output_weights, dropout_mask, dropout_prob); + TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); + + return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, inputs, + input_weights, output_weights, dropout_mask, dropout_prob); } -} // end namespace cublas_gemmex -} // end namespace self +} // end namespace cublas_gemmex +} // end namespace self namespace self_bias { namespace cublas_gemmex { -std::vector -fwd_cuda(bool use_time_mask, bool is_training, int heads, - torch::Tensor const 
&inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &input_biases, - torch::Tensor const &output_biases, const uint8_t *pad_mask, - float dropout_prob); - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - // torch::Tensor const& input_biases, - // torch::Tensor const& output_biases, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &input_biases, - torch::Tensor const &output_biases, torch::Tensor const &pad_mask, - float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + torch::Tensor const &input_biases, torch::Tensor const &output_biases, + const uint8_t *pad_mask, float dropout_prob); + +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + // torch::Tensor const& input_biases, + // torch::Tensor const& output_biases, + torch::Tensor const &dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + torch::Tensor const &inputs, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &input_biases, + torch::Tensor const &output_biases, torch::Tensor const &pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); + TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } - return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, - output_weights, input_biases, output_biases, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); + return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, + use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { +std::vector bwd(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + torch::Tensor const &dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -540,92 +391,68 @@ bwd(int heads, torch::Tensor const &output_grads, TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_results, inputs, input_weights, - output_weights, dropout_mask, dropout_prob); + TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); + + return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, inputs, + input_weights, output_weights, dropout_mask, dropout_prob); } -} // end namespace cublas_gemmex -} // namespace self_bias +} // end namespace cublas_gemmex +} // namespace self_bias namespace self_bias_additive_mask { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, 
torch::Tensor const &inputs, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - torch::Tensor const &input_biases, - torch::Tensor const &output_biases, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + torch::Tensor const &input_biases, torch::Tensor const &output_biases, const half *pad_mask, float dropout_prob); -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - // torch::Tensor const& softmax_results, - torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - // torch::Tensor const& input_biases, - // torch::Tensor const& output_biases, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &input_biases, - torch::Tensor const &output_biases, torch::Tensor const &pad_mask, - float dropout_prob) { +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, + // torch::Tensor const& softmax_results, + torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask, + torch::Tensor const &input_lin_results, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + // torch::Tensor const& input_biases, + // torch::Tensor const& output_biases, + torch::Tensor const &dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + torch::Tensor const &inputs, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &input_biases, + torch::Tensor const &output_biases, torch::Tensor const &pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(use_mask, "no mask is not supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Half, - "Only Half is supported"); + TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Half, "Only Half is supported"); } - return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, - output_weights, input_biases, output_biases, - use_mask ? 
static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); + return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, + use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { +std::vector bwd(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &bmm1_results, + torch::Tensor const &pad_mask, torch::Tensor const &input_lin_results, + torch::Tensor const &inputs, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, + float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -635,104 +462,72 @@ bwd(int heads, torch::Tensor const &output_grads, TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - bmm1_results, pad_mask, input_lin_results, inputs, - input_weights, output_weights, dropout_mask, dropout_prob); + TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); + + return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, bmm1_results, pad_mask, input_lin_results, + inputs, input_weights, output_weights, dropout_mask, dropout_prob); } -} // end namespace cublas_gemmex -} // namespace self_bias_additive_mask +} // end namespace cublas_gemmex +} // namespace self_bias_additive_mask namespace self_norm_add { namespace 
cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob); - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, - float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &pad_mask, float dropout_prob) { + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, const uint8_t *pad_mask, float dropout_prob); + +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &lyr_nrm_results, + torch::Tensor const &lyr_nrm_mean, torch::Tensor const &lyr_nrm_invvar, + torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, + torch::Tensor const &dropout_add_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); + TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + 
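// Illustrative sketch, not from the APEX sources: the *_norm_add backward declared above
// carries the saved layer-norm statistics (lyr_nrm_mean, lyr_nrm_invvar), which the type
// checks require in FP32 even though activations are FP16. A host-side reference of those
// per-row statistics, assuming the conventional invvar = 1/sqrt(var + eps):
#include <cmath>
#include <vector>

void layer_norm_stats(const std::vector<float> &row, float eps, float *mean, float *invvar) {
  float sum = 0.0f, sq_sum = 0.0f;
  for (float v : row) {
    sum += v;
    sq_sum += v * v;
  }
  const float n = static_cast<float>(row.size());
  *mean = sum / n;
  const float var = sq_sum / n - (*mean) * (*mean);
  *invvar = 1.0f / std::sqrt(var + eps);
}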
TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); - TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); + TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } - return fwd_cuda( - use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, input_weights, output_weights, - use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, - dropout_prob); + return fwd_cuda(use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, + output_weights, use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, - float dropout_prob) { +std::vector bwd(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &lyr_nrm_results, + torch::Tensor const &lyr_nrm_mean, torch::Tensor const &lyr_nrm_invvar, + torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, + torch::Tensor const &dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -749,63 +544,40 @@ bwd(int heads, torch::Tensor const &output_grads, TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_add_mask.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_results.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_mean.scalar_type() == at::ScalarType::Float, - "Only FLOAT is supported"); - TORCH_CHECK(lyr_nrm_invvar.scalar_type() == at::ScalarType::Float, - "Only FLOAT is supported"); - TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == 
at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, - "Only HALF is supported"); - TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); - TORCH_CHECK(dropout_add_mask.scalar_type() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_results, lyr_nrm_results, - lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, input_weights, output_weights, - dropout_mask, dropout_add_mask, dropout_prob); + TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_mean.scalar_type() == at::ScalarType::Float, "Only FLOAT is supported"); + TORCH_CHECK(lyr_nrm_invvar.scalar_type() == at::ScalarType::Float, "Only FLOAT is supported"); + TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); + TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); + TORCH_CHECK(dropout_add_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); + + return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, + lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, + input_weights, output_weights, dropout_mask, dropout_add_mask, dropout_prob); } -} // end namespace cublas_gemmex -} // end namespace self_norm_add -} // end namespace multihead_attn +} // end namespace cublas_gemmex +} // end namespace self_norm_add +} // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("additive_mask_softmax_dropout_forward", - &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, - "Self Multihead Attention masked softmax dropout -- Forward.", - py::call_guard()); - m.def("additive_mask_softmax_dropout_backward", - &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, - "Self Multihead Attention masked softmax dropout -- Backward.", - py::call_guard()); + m.def("additive_mask_softmax_dropout_forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, + "Self Multihead Attention masked softmax dropout -- Forward.", py::call_guard()); + m.def("additive_mask_softmax_dropout_backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, + 
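// Illustrative sketch, not from the APEX sources: every binding in this module is
// registered with a py::call_guard, which in pybind11 is parameterized -- typically
// py::call_guard<py::gil_scoped_release>() so the GIL is released while the CUDA work
// runs. The registration pattern, with an assumed `example_fwd` function:
#include <torch/extension.h>

static torch::Tensor example_fwd(torch::Tensor x) { return x; }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("example_forward", &example_fwd, "Illustrative forward binding.",
        py::call_guard<py::gil_scoped_release>());
}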
"Self Multihead Attention masked softmax dropout -- Backward.", py::call_guard()); m.def("mask_softmax_dropout_forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, - "Self Multihead Attention masked softmax dropout -- Forward.", - py::call_guard()); + "Self Multihead Attention masked softmax dropout -- Forward.", py::call_guard()); m.def("mask_softmax_dropout_backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, - "Self Multihead Attention masked softmax dropout -- Backward.", - py::call_guard()); + "Self Multihead Attention masked softmax dropout -- Backward.", py::call_guard()); m.def("encdec_multihead_attn_forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.", py::call_guard()); m.def("encdec_multihead_attn_backward", &multihead_attn::encdec::cublas_gemmex::bwd, @@ -813,27 +585,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("encdec_multihead_attn_norm_add_forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.", py::call_guard()); - m.def( - "encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, - "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.", - py::call_guard()); - m.def("self_attn_forward", &multihead_attn::self::cublas_gemmex::fwd, - "Self Multihead Attention Forward.", py::call_guard()); - m.def("self_attn_backward", &multihead_attn::self::cublas_gemmex::bwd, - "Self Multihead Attention Backward.", py::call_guard()); + m.def("encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, + "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.", + py::call_guard()); + m.def("self_attn_forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.", + py::call_guard()); + m.def("self_attn_backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.", + py::call_guard()); m.def("self_attn_bias_forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.", py::call_guard()); m.def("self_attn_bias_backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.", py::call_guard()); m.def("self_attn_bias_additive_mask_forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.", py::call_guard()); - m.def("self_attn_bias_additive_mask_backward", - &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, + m.def("self_attn_bias_additive_mask_backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.", py::call_guard()); m.def("self_attn_norm_add_forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.", py::call_guard()); m.def("self_attn_norm_add_backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, - "Self Multihead Attention Plus Layer Norm and Residual Add Backward.", py::call_guard()); + "Self Multihead Attention Plus Layer Norm and Residual Add Backward.", + py::call_guard()); } #undef CHECK_CUDA diff --git a/apex/contrib/csrc/multihead_attn/philox.cuh b/apex/contrib/csrc/multihead_attn/philox.cuh index d2076ab5a..93529f91b 100644 --- a/apex/contrib/csrc/multihead_attn/philox.cuh +++ b/apex/contrib/csrc/multihead_attn/philox.cuh @@ -4,20 +4,19 @@ namespace { class 
Philox { -public: - __device__ inline Philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) : STATE(0) { - //key.x = (unsigned int)seed; - //key.y = (unsigned int)(seed >> 32); - //counter = make_uint4(0, 0, 0, 0); - //counter.z = (unsigned int)(subsequence); - //counter.w = (unsigned int)(subsequence >> 32); - //STATE = 0; - //incr_n(offset / 4); + public: + __device__ inline Philox(unsigned long long seed, unsigned long long subsequence, unsigned long long offset) + : STATE(0) { + // key.x = (unsigned int)seed; + // key.y = (unsigned int)(seed >> 32); + // counter = make_uint4(0, 0, 0, 0); + // counter.z = (unsigned int)(subsequence); + // counter.w = (unsigned int)(subsequence >> 32); + // STATE = 0; + // incr_n(offset / 4); - key = reinterpret_cast(seed); - ull2 * tmp = reinterpret_cast(&counter); + key = reinterpret_cast(seed); + ull2 *tmp = reinterpret_cast(&counter); tmp->x = offset / 4; tmp->y = subsequence; } @@ -46,10 +45,10 @@ public: return output; } -private: + private: struct ull2 { - uint64_t x; - uint64_t y; + uint64_t x; + uint64_t y; }; uint4 counter; uint4 output; @@ -59,56 +58,45 @@ private: unsigned int nlo = (unsigned int)(n); unsigned int nhi = (unsigned int)(n >> 32); counter.x += nlo; - if (counter.x < nlo) - nhi++; + if (counter.x < nlo) nhi++; counter.y += nhi; - if (nhi <= counter.y) - return; - if (++counter.z) - return; + if (nhi <= counter.y) return; + if (++counter.z) return; ++counter.w; } - __device__ uint4 incr128 (uint4 ctr) - { + __device__ uint4 incr128(uint4 ctr) { uint4 res; - asm ("add.cc.u32 %0, %4, %8;\n\t" - "addc.cc.u32 %1, %5, %9;\n\t" - "addc.cc.u32 %2, %6, %10;\n\t" - "addc.u32 %3, %7, %11;\n\t" - : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) - : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), - "n"(1), "n"(0), "n"(0), "n"(0)); + asm("add.cc.u32 %0, %4, %8;\n\t" + "addc.cc.u32 %1, %5, %9;\n\t" + "addc.cc.u32 %2, %6, %10;\n\t" + "addc.u32 %3, %7, %11;\n\t" + : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) + : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), "n"(1), "n"(0), "n"(0), "n"(0)); return res; } - __device__ inline void incr() { - counter = incr128(counter); - } - __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, - unsigned int *result_high) { + __device__ inline void incr() { counter = incr128(counter); } + __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, unsigned int *result_high) { *result_high = __umulhi(a, b); return a * b; } - __device__ uint2 mulhilo32_v2 (unsigned int a, unsigned int b) - { + __device__ uint2 mulhilo32_v2(unsigned int a, unsigned int b) { uint2 *res; unsigned long long tmp; - asm ("mul.wide.u32 %0, %1, %2;\n\t" - : "=l"(tmp) - : "r"(a), "r"(b)); - res = (uint2*)(&tmp); + asm("mul.wide.u32 %0, %1, %2;\n\t" : "=l"(tmp) : "r"(a), "r"(b)); + res = (uint2 *)(&tmp); return *res; } __device__ inline uint4 single_round(uint4 ctr, uint2 key) { - //unsigned int hi0; - //unsigned int hi1; - //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); - //unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); - //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + // unsigned int hi0; + // unsigned int hi1; + // unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + // unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + // uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x); - uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z); - uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, 
res0.y ^ ctr.w ^ key.y, res0.x}; + uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; return ret; } static const unsigned long kPhilox10A = 0x9E3779B9; @@ -119,8 +107,7 @@ private: // Inverse of 2^32. constexpr float M_RAN_INVM32 = 2.3283064e-10f; __device__ __inline__ float4 uniform4(uint4 x) { - return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, - x.w * M_RAN_INVM32); + return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, x.w * M_RAN_INVM32); } -} // namespace +} // namespace diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu index c3a090c25..33fa024b8 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu @@ -1,16 +1,15 @@ -#include -#include -#include - +#include +#include #include #include #include #include - -#include -#include +#include #include +#include +#include + #include "dropout.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" @@ -19,12 +18,9 @@ namespace multihead_attn { namespace self_bias_additive_mask { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - torch::Tensor const &input_biases, - torch::Tensor const &output_biases, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + torch::Tensor const &input_biases, torch::Tensor const &output_biases, const half *pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); @@ -53,24 +49,17 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, auto act_options = inputs.options().requires_grad(false); auto mask_options = act_options.dtype(torch::kUInt8); - torch::Tensor input_lin_results = - torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor bmm1_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); + torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); + torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor outputs = torch::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void *q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - void *k_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + head_dim); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + 2 
* head_dim); + void *k_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + head_dim); + void *v_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + 2 * head_dim); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void *bmm1_results_ptr = static_cast(bmm1_results.data_ptr()); @@ -84,79 +73,60 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, // Input Linear Fwd input_lin_results.copy_(input_biases); TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta_one), q_lin_results_ptr, - CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, static_cast(&alpha), + static_cast(input_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta_one), + q_lin_results_ptr, CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim, batch_stride, - static_cast(q_lin_results_ptr), lead_dim, batch_stride, - beta_zero, static_cast(bmm1_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, + static_cast(k_lin_results_ptr), lead_dim, batch_stride, + static_cast(q_lin_results_ptr), lead_dim, batch_stride, beta_zero, + static_cast(bmm1_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Padded Softmax [[maybe_unused]] bool softmax_success = false; if (is_training) { - softmax_success = - dispatch_additive_masked_softmax_dropout( - reinterpret_cast(dropout_results_ptr), - (is_training) - ? reinterpret_cast(dropout_mask.data_ptr()) - : nullptr, - reinterpret_cast(bmm1_results_ptr), pad_mask, - attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len, - attn_batches * q_seq_len, attn_batches * q_seq_len / sequences, - 1.0f - dropout_prob, stream); + softmax_success = dispatch_additive_masked_softmax_dropout( + reinterpret_cast(dropout_results_ptr), + (is_training) ? 
reinterpret_cast(dropout_mask.data_ptr()) : nullptr, + reinterpret_cast(bmm1_results_ptr), pad_mask, attn_batches * q_seq_len * q_seq_len, k_seq_len, + k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences, 1.0f - dropout_prob, stream); } else { softmax_success = dispatch_additive_masked_softmax( - reinterpret_cast( - dropout_results_ptr), // this is actually softmax results, but - // making it consistent for the next function - reinterpret_cast(bmm1_results_ptr), pad_mask, k_seq_len, - k_seq_len, attn_batches * q_seq_len, + reinterpret_cast(dropout_results_ptr), // this is actually softmax results, but + // making it consistent for the next function + reinterpret_cast(bmm1_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } // Matmul2 gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(dropout_results.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, beta_zero, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, static_cast(v_lin_results_ptr), + lead_dim, batch_stride, static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta_zero, static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim, attn_batches); outputs.copy_(output_biases); // Output Linear TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta_one), - static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta_one), static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO1_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_lin_results, bmm1_results, dropout_results, - dropout_mask, matmul2_results, outputs}; + return {input_lin_results, bmm1_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &bmm1_results, + torch::Tensor const &pad_mask, torch::Tensor const &input_lin_results, + torch::Tensor const &inputs, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, + float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -188,16 +158,12 @@ std::vector bwd_cuda( at::Tensor input_lin_output_grads = 
torch::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - auto k_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + head_dim; - auto v_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + 2 * head_dim; + auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; + auto v_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + 2 * head_dim; auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + head_dim; - auto v_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; + auto k_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + head_dim; + auto v_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; char a_layout_n{'n'}; char a_layout_t{'t'}; @@ -207,100 +173,74 @@ std::vector bwd_cuda( TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_lin_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK( + cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, + embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, + static_cast(v_lin_results_ptr), lead_dim, batch_stride, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta, + static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, 
head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, + static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, + v_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad dispatch_masked_scale_softmax_backward_recompute( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(bmm1_results.data_ptr()), - reinterpret_cast(pad_mask.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + reinterpret_cast(bmm1_results.data_ptr()), reinterpret_cast(pad_mask.data_ptr()), + static_cast(dropout_mask.data_ptr()), 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len / sequences, attn_batches * q_seq_len, stream); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(input_lin_output_grads.data_ptr()), - // static_cast(q_lin_grads_ptr), - CUDA_R_16F, output_lin_dim, static_cast(&beta), - static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(input_lin_output_grads.data_ptr()), + // static_cast(q_lin_grads_ptr), + CUDA_R_16F, output_lin_dim, static_cast(&beta), + static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, + // CUBLAS_GEMM_ALGO10_TENSOR_OP)); + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, - 
static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - auto input_bias_grads = - input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); + handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, static_cast(&alpha), + static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast(q_lin_grads_ptr), + CUDA_R_16F, output_lin_dim, static_cast(&beta), static_cast(input_weight_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_grads, input_weight_grads, output_weight_grads, - input_bias_grads, output_bias_grads}; + return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads}; } -} // end namespace cublas_gemmex -} // namespace self_bias_additive_mask -} // end namespace multihead_attn +} // end namespace cublas_gemmex +} // namespace self_bias_additive_mask +} // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu index e9b3e9e9f..f5bcc8e78 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu @@ -1,16 +1,15 @@ -#include -#include -#include - +#include +#include #include #include #include #include - -#include -#include +#include #include +#include +#include + #include "dropout.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" @@ -19,12 +18,10 @@ namespace multihead_attn { namespace self_bias { namespace cublas_gemmex { -std::vector -fwd_cuda(bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &input_biases, - torch::Tensor const &output_biases, const uint8_t *pad_mask, - float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + torch::Tensor const &input_biases, torch::Tensor const &output_biases, + const uint8_t *pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -52,24 +49,17 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, auto act_options = inputs.options().requires_grad(false); auto mask_options = act_options.dtype(torch::kUInt8); - torch::Tensor input_lin_results = - torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); + torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); + torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + 
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor outputs = torch::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void *q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - void *k_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + head_dim); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + 2 * head_dim); + void *k_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + head_dim); + void *v_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + 2 * head_dim); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); @@ -82,88 +72,71 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, // Input Linear Fwd input_lin_results.copy_(input_biases); TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta_one), q_lin_results_ptr, - CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, static_cast(&alpha), + static_cast(input_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta_one), + q_lin_results_ptr, CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim, batch_stride, - static_cast(q_lin_results_ptr), lead_dim, batch_stride, - beta_zero, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, + static_cast(k_lin_results_ptr), lead_dim, batch_stride, + static_cast(q_lin_results_ptr), lead_dim, batch_stride, beta_zero, + static_cast(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Padded Softmax [[maybe_unused]] bool softmax_success = false; if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); + softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), + reinterpret_cast(softmax_results_ptr), + k_seq_len, k_seq_len, attn_batches * q_seq_len); } else { if (use_time_mask) { softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); } else { softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, + k_seq_len, k_seq_len, attn_batches 
* q_seq_len, attn_batches * q_seq_len / sequences); } } if (is_training) { // use at:: function so that C++ version generates the same random mask as // python version - auto dropout_tuple = - at::_fused_dropout(softmax_results, 1.0f - dropout_prob); + auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob); dropout_results = std::get<0>(dropout_tuple); dropout_mask = std::get<1>(dropout_tuple); } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta_zero, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, + static_cast(v_lin_results_ptr), lead_dim, batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) + : static_cast(softmax_results.data_ptr()), + k_seq_len, k_seq_len * q_seq_len, beta_zero, static_cast(matmul2_results.data_ptr()), + head_dim * attn_batches, head_dim, attn_batches); outputs.copy_(output_biases); // Output Linear TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta_one), - static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta_one), static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO1_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_lin_results, softmax_results, dropout_results, - dropout_mask, matmul2_results, outputs}; + return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + torch::Tensor const &dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -195,16 +168,12 @@ std::vector bwd_cuda( at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - auto k_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + head_dim; - auto v_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + 2 * head_dim; + auto 
k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; + auto v_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + 2 * head_dim; auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + head_dim; - auto v_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; + auto k_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + head_dim; + auto v_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; char a_layout_n{'n'}; char a_layout_t{'t'}; @@ -214,99 +183,73 @@ std::vector bwd_cuda( TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_lin_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK( + cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, + embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, + static_cast(v_lin_results_ptr), lead_dim, batch_stride, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta, + static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, + static_cast(output_lin_grads.data_ptr()), 
head_dim * attn_batches, head_dim, + static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, + v_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad dispatch_masked_scale_softmax_backward_stream( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, - attn_batches * q_seq_len, stream); + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + reinterpret_cast(softmax_results.data_ptr()), static_cast(dropout_mask.data_ptr()), + 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(input_lin_output_grads.data_ptr()), - // static_cast(q_lin_grads_ptr), - CUDA_R_16F, output_lin_dim, static_cast(&beta), - static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(input_lin_output_grads.data_ptr()), + // static_cast(q_lin_grads_ptr), + CUDA_R_16F, output_lin_dim, static_cast(&beta), + static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, + // CUBLAS_GEMM_ALGO10_TENSOR_OP)); + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - auto input_bias_grads = - input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); + handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, static_cast(&alpha), + static_cast(inputs.data_ptr()), CUDA_R_16F, 
embed_dim, static_cast(q_lin_grads_ptr), + CUDA_R_16F, output_lin_dim, static_cast(&beta), static_cast(input_weight_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_grads, input_weight_grads, output_weight_grads, - input_bias_grads, output_bias_grads}; + return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads}; } -} // end namespace cublas_gemmex -} // namespace self_bias -} // end namespace multihead_attn +} // end namespace cublas_gemmex +} // namespace self_bias +} // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu index 9701da7cc..12a423390 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu @@ -1,16 +1,15 @@ -#include -#include -#include - +#include +#include #include #include #include #include - -#include -#include +#include #include +#include +#include + #include "dropout.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" @@ -19,12 +18,9 @@ namespace multihead_attn { namespace self { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + const uint8_t *pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -51,24 +47,17 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, auto act_options = inputs.options().requires_grad(false); auto mask_options = act_options.dtype(torch::kUInt8); - torch::Tensor input_lin_results = - torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); + torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); + torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor outputs = torch::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void *q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - void *k_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + head_dim); - void *v_lin_results_ptr = static_cast( - 
static_cast(input_lin_results.data_ptr()) + 2 * head_dim); + void *k_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + head_dim); + void *v_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + 2 * head_dim); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); @@ -80,85 +69,67 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Input Linear Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), q_lin_results_ptr, - CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, static_cast(&alpha), + static_cast(input_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + q_lin_results_ptr, CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim, batch_stride, - static_cast(q_lin_results_ptr), lead_dim, batch_stride, - beta, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, + static_cast(k_lin_results_ptr), lead_dim, batch_stride, + static_cast(q_lin_results_ptr), lead_dim, batch_stride, beta, + static_cast(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); + softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), + reinterpret_cast(softmax_results_ptr), + k_seq_len, k_seq_len, attn_batches * q_seq_len); } else { if (use_time_mask) { softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); } else { softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, + k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } } assert(softmax_success); if (is_training) { apex_fused_dropout_cuda( - static_cast(softmax_results.data_ptr()), - static_cast(dropout_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0f - dropout_prob)); + static_cast(softmax_results.data_ptr()), static_cast(dropout_results.data_ptr()), + static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob)); } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, 
k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, + static_cast(v_lin_results_ptr), lead_dim, batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) + : static_cast(softmax_results.data_ptr()), + k_seq_len, k_seq_len * q_seq_len, beta, static_cast(matmul2_results.data_ptr()), + head_dim * attn_batches, head_dim, attn_batches); // Output Linear TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_lin_results, softmax_results, dropout_results, - dropout_mask, matmul2_results, outputs}; + return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const &matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &inputs, + torch::Tensor const &input_weights, torch::Tensor const &output_weights, + torch::Tensor const &dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -190,16 +161,12 @@ std::vector bwd_cuda( at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - auto k_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + head_dim; - auto v_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + 2 * head_dim; + auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; + auto v_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + 2 * head_dim; auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + head_dim; - auto v_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; + auto k_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + head_dim; + auto v_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + 2 * 
head_dim; char a_layout_n{'n'}; char a_layout_t{'t'}; @@ -209,99 +176,73 @@ std::vector bwd_cuda( TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_lin_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK( + cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, + embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, + static_cast(v_lin_results_ptr), lead_dim, batch_stride, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta, + static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, + static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, + v_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0 / (1.0 - dropout_prob))); + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; 
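// Softmax backward (in place on matmul2_grads, which was already rescaled by the dropout mask
// above): for each softmax row y with upstream gradient g this is assumed to compute
// dL/dx_i = y_i * (g_i - sum_j g_j * y_j), reading y from softmax_results.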
softmax_success = dispatch_softmax_backward( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), k_seq_len, - k_seq_len, attn_batches * q_seq_len); + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + reinterpret_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len); assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Input Linear Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_dim, static_cast(&beta), - static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, static_cast(&alpha), + static_cast(input_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, static_cast(&beta), + static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, static_cast(&alpha), + static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast(q_lin_grads_ptr), + CUDA_R_16F, output_lin_dim, static_cast(&beta), static_cast(input_weight_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_grads, input_weight_grads, output_weight_grads}; } -} // end namespace cublas_gemmex -} // end namespace self -} // end namespace multihead_attn +} // end namespace cublas_gemmex +} // end namespace self +} // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu index 
7a6ec15cd..e7052d8d1 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu @@ -1,16 +1,15 @@ -#include -#include -#include - +#include +#include #include #include #include #include - -#include -#include +#include #include +#include +#include + #include "dropout.cuh" #include "layer_norm.cuh" #include "softmax.cuh" @@ -20,14 +19,10 @@ namespace multihead_attn { namespace self_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob) { + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, const uint8_t *pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -60,26 +55,19 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options); torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options); - torch::Tensor input_lin_results = - torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); + torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); + torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor output_lin_results = torch::empty_like(inputs, act_options); torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options); torch::Tensor outputs = torch::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void *q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - void *k_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + head_dim); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + 2 * head_dim); + void *k_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + head_dim); + void *v_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + 2 * head_dim); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); @@ -91,120 +79,95 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Layer Norm HostApplyLayerNorm( - 
static_cast(lyr_nrm_results.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), - static_cast(inputs.data_ptr()), - static_cast(batches), // n1 - static_cast(embed_dim), // n2 + static_cast(lyr_nrm_results.data_ptr()), static_cast(lyr_nrm_mean.data_ptr()), + static_cast(lyr_nrm_invvar.data_ptr()), static_cast(inputs.data_ptr()), + static_cast(batches), // n1 + static_cast(embed_dim), // n2 1.0e-5, static_cast(lyr_nrm_gamma_weights.data_ptr()), static_cast(lyr_nrm_beta_weights.data_ptr())); // Input Linear Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, + handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, static_cast(&alpha), + static_cast(input_weights.data_ptr()), CUDA_R_16F, embed_dim, // static_cast(inputs.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), q_lin_results_ptr, - CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), + q_lin_results_ptr, CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim, batch_stride, - static_cast(q_lin_results_ptr), lead_dim, batch_stride, - beta, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, + static_cast(k_lin_results_ptr), lead_dim, batch_stride, + static_cast(q_lin_results_ptr), lead_dim, batch_stride, beta, + static_cast(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); + softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), + reinterpret_cast(softmax_results_ptr), + k_seq_len, k_seq_len, attn_batches * q_seq_len); } else { if (use_time_mask) { softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); } else { softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); + reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, + k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } } assert(softmax_success); if (is_training) { apex_fused_dropout_cuda( - static_cast(softmax_results.data_ptr()), - static_cast(dropout_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0f - dropout_prob)); + static_cast(softmax_results.data_ptr()), static_cast(dropout_results.data_ptr()), + static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob)); } // Matmul2 - 
gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - // static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, + static_cast(v_lin_results_ptr), lead_dim, batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) + : static_cast(softmax_results.data_ptr()), + // static_cast(dropout_results.data_ptr()), + k_seq_len, k_seq_len * q_seq_len, beta, static_cast(matmul2_results.data_ptr()), + head_dim * attn_batches, head_dim, attn_batches); // Output Linear - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK( + cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_lin_results.data_ptr()), CUDA_R_16F, + embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // End-of-block Dropout-Add if (is_training) { apex_dropout_add_cuda( - static_cast(output_lin_results.data_ptr()), - static_cast(inputs.data_ptr()), - static_cast(outputs.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), total_tokens, + static_cast(output_lin_results.data_ptr()), static_cast(inputs.data_ptr()), + static_cast(outputs.data_ptr()), static_cast(dropout_add_mask.data_ptr()), total_tokens, (1.0f - dropout_prob)); } else { - apex_add_cuda( - static_cast(output_lin_results.data_ptr()), - static_cast(inputs.data_ptr()), - static_cast(outputs.data_ptr()), total_tokens); + apex_add_cuda(static_cast(output_lin_results.data_ptr()), + static_cast(inputs.data_ptr()), + static_cast(outputs.data_ptr()), total_tokens); } TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results, - softmax_results, dropout_results, dropout_mask, matmul2_results, - dropout_add_mask, outputs}; + return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results, softmax_results, + dropout_results, dropout_mask, matmul2_results, dropout_add_mask, outputs}; } -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, - float dropout_prob) { +std::vector bwd_cuda(int heads, torch::Tensor const &output_grads, torch::Tensor const 
&matmul2_results, + torch::Tensor const &dropout_results, torch::Tensor const &softmax_results, + torch::Tensor const &input_lin_results, torch::Tensor const &lyr_nrm_results, + torch::Tensor const &lyr_nrm_mean, torch::Tensor const &lyr_nrm_invvar, + torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights, + torch::Tensor const &lyr_nrm_beta_weights, torch::Tensor const &input_weights, + torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, + torch::Tensor const &dropout_add_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -241,16 +204,12 @@ std::vector bwd_cuda( torch::Tensor input_lin_grads = torch::empty_like(inputs); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - auto k_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + head_dim; - auto v_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + 2 * head_dim; + auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; + auto v_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + 2 * head_dim; auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + head_dim; - auto v_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; + auto k_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + head_dim; + auto v_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; char a_layout_n{'n'}; char a_layout_t{'t'}; @@ -261,125 +220,94 @@ std::vector bwd_cuda( // Dropout Add Backward apex_masked_scale_cuda( - static_cast(output_grads.data_ptr()), - static_cast(dropout_add_grads.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), total_tokens, - (1.0 / (1.0 - dropout_prob))); + static_cast(output_grads.data_ptr()), static_cast(dropout_add_grads.data_ptr()), + static_cast(dropout_add_mask.data_ptr()), total_tokens, (1.0 / (1.0 - dropout_prob))); // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(dropout_add_grads.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(&beta), static_cast(output_lin_grads.data_ptr()), + CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(dropout_add_grads.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK( + cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(dropout_add_grads.data_ptr()), CUDA_R_16F, 
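// ------------------------------------------------------------------------------------------------------------------
// [Editorial sketch, not from this patch] The "Output Linear Dgrad"/"Wgrad" pair here (and the input-linear pair
// further down) follow directly from the forward Y = X * W^T: dX = dY * W and dW = dY^T * X. Read through the same
// row-major/column-major trick, dgrad becomes an (N, N) GEMM and wgrad an (N, T) GEMM with the leading dimensions
// used in these calls. Hypothetical helper (linear_bwd_sketch), FP16 storage, FP32 accumulation:
#include <cublas_v2.h>
#include <cuda_fp16.h>

inline void linear_bwd_sketch(cublasHandle_t handle, const __half* W, const __half* X, const __half* dY, __half* dX,
                              __half* dW, int out_dim, int batches, int in_dim) {
  const float alpha = 1.0f, beta = 0.0f;
  // dgrad: dX^c[in_dim x batches] = W^c[in_dim x out_dim] * dY^c[out_dim x batches]
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, in_dim, batches, out_dim, &alpha, W, CUDA_R_16F, in_dim, dY,
               CUDA_R_16F, out_dim, &beta, dX, CUDA_R_16F, in_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  // wgrad: dW^c[in_dim x out_dim] = X^c[in_dim x batches] * (dY^c)^T[batches x out_dim]
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, in_dim, out_dim, batches, &alpha, X, CUDA_R_16F, in_dim, dY,
               CUDA_R_16F, out_dim, &beta, dW, CUDA_R_16F, in_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// ------------------------------------------------------------------------------------------------------------------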
embed_dim, + static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, + embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, + static_cast(v_lin_results_ptr), lead_dim, batch_stride, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta, + static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, + static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, + static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, + v_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0 / (1.0 - dropout_prob))); + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; softmax_success = dispatch_softmax_backward( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), k_seq_len, - k_seq_len, attn_batches * q_seq_len); + static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), + reinterpret_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len); assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, + beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches); // Input Linear Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, 
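// ------------------------------------------------------------------------------------------------------------------
// [Editorial sketch, not from this patch] dispatch_softmax_backward above (and the masked variant later in
// softmax.cuh) implements the standard softmax Jacobian-vector product. With y the saved softmax output and dy the
// incoming gradient (which here already carries the dropout mask and the 1/(1 - p) scale applied just before), each
// row is dx_i = y_i * (dy_i - sum_j dy_j * y_j); the kernels compute the algebraically equal form
// grad_reg_i - y_i * sum_j grad_reg_j with grad_reg_i = dy_i * y_i. Host-side reference (hypothetical name):
#include <vector>

inline std::vector<float> softmax_backward_ref(const std::vector<float>& y, const std::vector<float>& dy) {
  float dot = 0.0f;
  for (size_t j = 0; j < y.size(); ++j) dot += dy[j] * y[j];
  std::vector<float> dx(y.size());
  for (size_t i = 0; i < y.size(); ++i) dx[i] = y[i] * (dy[i] - dot);
  return dx;
}
// ------------------------------------------------------------------------------------------------------------------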
CUBLAS_OP_N, embed_dim, batches, output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_dim, static_cast(&beta), + handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, static_cast(&alpha), + static_cast(input_weights.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, static_cast(&beta), // static_cast(input_grads.data_ptr()), - static_cast(input_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, + static_cast(input_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, - static_cast(&alpha), + handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, static_cast(&alpha), // static_cast(inputs.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_dim, static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim, + static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, static_cast(&beta), + static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Fused Layer Norm Bwd with Residual Add HostLayerNormGradient( - static_cast(input_lin_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), inputs, - static_cast(batches), // n1 - static_cast(embed_dim), // n2 + static_cast(input_lin_grads.data_ptr()), static_cast(output_grads.data_ptr()), + static_cast(lyr_nrm_mean.data_ptr()), static_cast(lyr_nrm_invvar.data_ptr()), + inputs, + static_cast(batches), // n1 + static_cast(embed_dim), // n2 static_cast(lyr_nrm_gamma_weights.data_ptr()), - static_cast(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, - static_cast(input_grads.data_ptr()), - static_cast(lyr_nrm_gamma_grads.data_ptr()), - static_cast(lyr_nrm_beta_grads.data_ptr())); + static_cast(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, static_cast(input_grads.data_ptr()), + static_cast(lyr_nrm_gamma_grads.data_ptr()), static_cast(lyr_nrm_beta_grads.data_ptr())); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads, - input_weight_grads, output_weight_grads}; + return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads, input_weight_grads, output_weight_grads}; } -} // end namespace cublas_gemmex -} // end namespace self_norm_add -} // end namespace multihead_attn +} // end namespace cublas_gemmex +} // end namespace self_norm_add +} // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/softmax.cuh b/apex/contrib/csrc/multihead_attn/softmax.cuh index 254e84210..87ddd35b5 100644 --- a/apex/contrib/csrc/multihead_attn/softmax.cuh +++ b/apex/contrib/csrc/multihead_attn/softmax.cuh @@ -1,8 +1,10 @@ #pragma once -#include "philox.cuh" -#include #include +#include + +#include "philox.cuh" + #ifdef OLD_GENERATOR_PATH #include #else @@ -10,27 +12,25 @@ #endif #include +#include +#include + #include #include -#include #include -#include namespace { template __device__ __inline__ void copy_vector(Datatype *dst, const 
Datatype *src); template -__device__ __inline__ void apply_mask(Datatype *dst, Datatype value, - const uint8_t *src); +__device__ __inline__ void apply_mask(Datatype *dst, Datatype value, const uint8_t *src); template -__device__ __inline__ void apply_additive_mask(Datatype *dst, - const Datatype *additive_mask); +__device__ __inline__ void apply_additive_mask(Datatype *dst, const Datatype *additive_mask); template <> -__device__ __inline__ void copy_vector<__half, 1>(__half *dst, - const __half *src) { +__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } @@ -40,39 +40,33 @@ __device__ __inline__ void copy_vector(float *dst, const float *src) { } template <> -__device__ __inline__ void copy_vector<__half, 4>(__half *dst, - const __half *src) { +__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2 *)dst) = *((float2 *)src); } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *((half2 *)dst) = *((half2 *)src); } template <> -__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, - const uint8_t *src) { +__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, const uint8_t *src) { if (*src == 1) { *dst = value; } } template <> -__device__ __inline__ void -apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) { +__device__ __inline__ void apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) { *dst += *additive_mask; } template <> -__device__ __inline__ void -apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) { +__device__ __inline__ void apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) { *dst += *additive_mask; *(dst + 1) += *(additive_mask + 1); *(dst + 2) += *(additive_mask + 2); @@ -87,11 +81,9 @@ apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) { // WARP_ITERATOINS The number of iterations required for one warp to iterate // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -__global__ void softmax_warp_forward(input_t *dst, const output_t *src, - int batch_size, int stride, - int element_count) { +template +__global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batch_size, int stride, int element_count) { assert(ELEMENTS_PER_LDG_STG == 1); int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; @@ -99,8 +91,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. 
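// ------------------------------------------------------------------------------------------------------------------
// [Editorial worked example, not from this patch] The WARP_BATCH / WARP_ITERATIONS / WARP_SIZE comment above is
// easiest to read with a concrete row length. For element_count = 128 the dispatcher below picks:
//   next_power_of_two = 128
//   WARP_SIZE         = min(128, 32)          -> 32 lanes cooperate on one row
//   WARP_ITERATIONS   = 128 / 32              -> 4 elements handled per lane
//   WARP_BATCH        = (128 <= 128) ? 2 : 1  -> 2 rows processed per warp
static_assert((1 << 7) == 128, "log2_elements = 7 selects the 128-element instantiation");
constexpr int kExampleWarpSize = 32;
constexpr int kExampleWarpIterations = 128 / kExampleWarpSize;  // 4
constexpr int kExampleWarpBatch = (128 <= 128) ? 2 : 1;         // 2
// ------------------------------------------------------------------------------------------------------------------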
compute the index within the // batch @@ -117,13 +108,11 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements_input[i][it + element] = - -std::numeric_limits::infinity(); + elements_input[i][it + element] = -std::numeric_limits::infinity(); } if (element_index < batch_element_count) { - copy_vector( - &elements_input[i][it], src + i * element_count + it * WARP_SIZE); + copy_vector(&elements_input[i][it], src + i * element_count + it * WARP_SIZE); } } } @@ -150,8 +139,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } @@ -193,8 +181,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -204,8 +191,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; } @@ -218,13 +204,10 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. template -using softmax_forward_func = void (*)(input_t *dst, const output_t *src, - int batch_size, int stride, - int element_count); +using softmax_forward_func = void (*)(input_t *dst, const output_t *src, int batch_size, int stride, int element_count); template -bool warp_softmax_kernel(int log2_elements, int &warp_size, - int &batches_per_warp, +bool warp_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, softmax_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; @@ -234,61 +217,59 @@ bool warp_softmax_kernel(int log2_elements, int &warp_size, batches_per_warp = (next_power_of_two <= 128) ? 
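// ------------------------------------------------------------------------------------------------------------------
// [Editorial sketch, not from this patch] After each lane folds its WARP_ITERATIONS elements into a local max (and
// later a local sum), the per-row result is completed across the 32 lanes with shuffle-based butterfly reductions
// (see WARP_SHFL_XOR_NATIVE / warp_reduce_sum further down in this header). Minimal free-standing versions of those
// two reductions (hypothetical names):
__device__ __forceinline__ float warp_allreduce_max_sketch(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, offset));
  return v;  // every lane ends up holding the row maximum
}

__device__ __forceinline__ float warp_allreduce_sum_sketch(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) v += __shfl_xor_sync(0xffffffffu, v, offset);
  return v;  // every lane ends up holding the row sum
}
// ------------------------------------------------------------------------------------------------------------------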
2 : 1; switch (log2_elements) { - case 0: // 1 - kernel = &softmax_warp_forward; - break; - case 1: // 2 - kernel = &softmax_warp_forward; - break; - case 2: // 4 - kernel = &softmax_warp_forward; - break; - case 3: // 8 - kernel = &softmax_warp_forward; - break; - case 4: // 16 - kernel = &softmax_warp_forward; - break; - case 5: // 32 - kernel = &softmax_warp_forward; - break; - case 6: // 64 - kernel = &softmax_warp_forward; - break; - case 7: // 128 - kernel = &softmax_warp_forward; - break; - case 8: // 256 - kernel = &softmax_warp_forward; - break; - case 9: // 512 - kernel = &softmax_warp_forward; - break; - case 10: // 1024 - kernel = &softmax_warp_forward; - break; - default: - return false; + case 0: // 1 + kernel = &softmax_warp_forward; + break; + case 1: // 2 + kernel = &softmax_warp_forward; + break; + case 2: // 4 + kernel = &softmax_warp_forward; + break; + case 3: // 8 + kernel = &softmax_warp_forward; + break; + case 4: // 16 + kernel = &softmax_warp_forward; + break; + case 5: // 32 + kernel = &softmax_warp_forward; + break; + case 6: // 64 + kernel = &softmax_warp_forward; + break; + case 7: // 128 + kernel = &softmax_warp_forward; + break; + case 8: // 256 + kernel = &softmax_warp_forward; + break; + case 9: // 512 + kernel = &softmax_warp_forward; + break; + case 10: // 1024 + kernel = &softmax_warp_forward; + break; + default: + return false; } return true; } template -bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, - int softmax_elements_stride, int batch_count) { +bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, + int batch_count) { if (softmax_elements == 0) { return true; } else if (softmax_elements <= 1024) { // compute function index. there's a function for each power of two size up // to 1024. int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; softmax_forward_func kernel; int warp_size, batches_per_warp; - if (!warp_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { + if (!warp_softmax_kernel(log2_elements, warp_size, batches_per_warp, kernel)) { return false; } @@ -304,30 +285,28 @@ bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, dim3 threads(warp_size, warps_per_block, 1); // launch - kernel<<>>( - dst, src, batch_count, softmax_elements_stride, softmax_elements); + kernel<<>>(dst, src, batch_count, softmax_elements_stride, + softmax_elements); return true; } return false; } -template -__global__ void additive_masked_softmax_dropout_warp_forward_vec4( - output_t *dst, uint8_t *dropout_mask, const input_t *src, - const input_t *pad_mask, int batch_size, int stride, int element_count, - int pad_batch_stride, at::PhiloxCudaState philox_args, float p) { - +template +__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst, uint8_t *dropout_mask, + const input_t *src, const input_t *pad_mask, + int batch_size, int stride, int element_count, + int pad_batch_stride, at::PhiloxCudaState philox_args, + float p) { assert(ELEMENTS_PER_LDG_STG == 4); int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x; + int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; acc_t pinv = acc_t(1) / p; // batch_size might not be a multiple of WARP_BATCH. 
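// ------------------------------------------------------------------------------------------------------------------
// [Editorial sketch, not from this patch] Every dispatch_* helper in this header uses the same selection pattern:
// round the row length up to a power of two, take its exponent, and switch on that exponent to store one fully
// specialized kernel into a function-pointer out-parameter, returning false for unsupported sizes. The same idea
// reduced to a host-side toy (all names hypothetical):
#include <cstdio>

template <int N>
void toy_kernel_sketch() { std::printf("specialized for %d elements\n", N); }

using toy_fn_sketch = void (*)();

inline bool pick_toy_kernel_sketch(int log2_elements, toy_fn_sketch& fn) {
  switch (log2_elements) {
    case 0: fn = &toy_kernel_sketch<1>; return true;
    case 1: fn = &toy_kernel_sketch<2>; return true;
    case 2: fn = &toy_kernel_sketch<4>; return true;
    default: return false;  // caller reports the size as unsupported
  }
}
// ------------------------------------------------------------------------------------------------------------------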
Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch @@ -343,8 +322,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( // load data from global memory for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; + int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx; const half *curr_mask = pad_mask + pad_thread_offset; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { @@ -358,12 +336,10 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( if (element_index < batch_element_count) { int itr_jmp = it * WARP_SIZE; int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - src + itr_idx); + copy_vector(&elements_input[i][it], src + itr_idx); apply_additive_mask( &elements_input[i][it], - curr_mask + - itr_jmp); //(__half)-std::numeric_limits::infinity() + curr_mask + itr_jmp); //(__half)-std::numeric_limits::infinity() } } } @@ -391,8 +367,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } @@ -435,8 +410,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( float4 rand_num; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -446,8 +420,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( rands[i][it + 1] = (rand_num.y <= p) > 0.5; rands[i][it + 2] = (rand_num.z <= p) > 0.5; rands[i][it + 3] = (rand_num.w <= p) > 0.5; - copy_vector( - dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]); + copy_vector(dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]); } } } @@ -455,8 +428,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -464,11 +436,9 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = rands[i][it + element] * - (pinv * (elements[i][it + element] / sum[i])); + out[element] = rands[i][it + element] * (pinv * (elements[i][it + element] / sum[i])); } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; @@ -477,22 +447,20 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( } } -template 
-__global__ void additive_masked_softmax_dropout_warp_forward( - output_t *dst, uint8_t *dropout_mask, const input_t *src, - const input_t *pad_mask, int batch_size, int stride, int element_count, - int pad_batch_stride, at::PhiloxCudaState philox_args, float p) { +template +__global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint8_t *dropout_mask, const input_t *src, + const input_t *pad_mask, int batch_size, int stride, + int element_count, int pad_batch_stride, + at::PhiloxCudaState philox_args, float p) { assert(ELEMENTS_PER_LDG_STG == 1); int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x; + int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; acc_t pinv = acc_t(1) / p; // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch @@ -508,8 +476,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward( // load data from global memory for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = - ((first_batch + i) / pad_batch_stride) * stride + local_idx; + int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + local_idx; const half *curr_mask = pad_mask + pad_thread_offset; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += 1) { @@ -524,8 +491,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward( int itr_jmp = it * WARP_SIZE; int itr_idx = i * element_count + itr_jmp; copy_vector(&elements_input[i][it], src + itr_idx); - apply_additive_mask(&elements_input[i][it], - curr_mask + itr_jmp); + apply_additive_mask(&elements_input[i][it], curr_mask + itr_jmp); } } } @@ -553,8 +519,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward( #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; } } @@ -598,8 +563,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward( // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += 1) { int element_index = local_idx + it * WARP_SIZE; @@ -615,13 +579,10 @@ __global__ void additive_masked_softmax_dropout_warp_forward( softmax_out[element] = (elements[i][it + element] / sum[i]); rand_ptr[element] = rand_ptr[element] <= p; out[element] = rand_ptr[element] * pinv * softmax_out[element]; - dropout_mask_temp[element] = - rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f + dropout_mask_temp[element] = rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f } copy_vector(dst + i * element_count + it * WARP_SIZE, out); - copy_vector(dropout_mask + i * element_count + - it * WARP_SIZE, - dropout_mask_temp); + copy_vector(dropout_mask + i * element_count + it * WARP_SIZE, dropout_mask_temp); } else { break; @@ -635,16 +596,15 @@ __global__ void additive_masked_softmax_dropout_warp_forward( // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. template -using additive_masked_softmax_dropout_forward_func = void (*)( - output_t *dst, uint8_t *dropout_mask, const input_t *src, - const input_t *pad_mask, int batch_size, int stride, int element_count, - int pad_batch_stride, at::PhiloxCudaState philox_args, float p); +using additive_masked_softmax_dropout_forward_func = void (*)(output_t *dst, uint8_t *dropout_mask, const input_t *src, + const input_t *pad_mask, int batch_size, int stride, + int element_count, int pad_batch_stride, + at::PhiloxCudaState philox_args, float p); template bool warp_additive_masked_softmax_dropout_kernel( int element_count, int log2_elements, int &warp_size, int &batches_per_warp, - additive_masked_softmax_dropout_forward_func - &kernel) { + additive_masked_softmax_dropout_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; @@ -653,108 +613,82 @@ bool warp_additive_masked_softmax_dropout_kernel( batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; bool flag_vec4 = (element_count % 4 == 0); switch (log2_elements) { - case 0: // 1 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 1: // 2 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 2: // 4 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 3: // 8 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 4: // 16 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 5: // 32 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 6: // 64 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 7: // 128 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 2, 4, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - case 8: // 256 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 1, 8, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - case 9: // 512 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 1, 16, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - case 10: // 1024 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 1, 32, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - case 11: // 2048 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 1, 64, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - default: - return false; + case 0: // 1 + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 1: // 2 + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 2: // 4 + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 3: // 8 + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 4: // 16 + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 5: // 32 + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 6: // 64 + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 7: // 128 + if (flag_vec4) + kernel = &additive_masked_softmax_dropout_warp_forward_vec4; + else + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 8: // 256 + if (flag_vec4) + kernel = &additive_masked_softmax_dropout_warp_forward_vec4; + else + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 9: // 512 + if (flag_vec4) + kernel = &additive_masked_softmax_dropout_warp_forward_vec4; + else + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 10: // 1024 + if (flag_vec4) + kernel = &additive_masked_softmax_dropout_warp_forward_vec4; + else + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + case 11: // 2048 + if (flag_vec4) + kernel = &additive_masked_softmax_dropout_warp_forward_vec4; + else + kernel = &additive_masked_softmax_dropout_warp_forward; + break; + default: + return false; } return true; } template -bool dispatch_additive_masked_softmax_dropout( - output_t *dst, uint8_t *dropout_mask, const input_t *src, - const input_t *pad_mask, int totalElements, int softmax_elements, - int softmax_elements_stride, int batch_count, int pad_batch_stride, float p, - cudaStream_t streamid) // p is the probability 
to keep, not drop +bool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_mask, const input_t *src, + const input_t *pad_mask, int totalElements, int softmax_elements, + int softmax_elements_stride, int batch_count, int pad_batch_stride, + float p, + cudaStream_t streamid) // p is the probability to keep, not drop { - if (softmax_elements == 0) { return true; } else if (softmax_elements <= 2048) { // compute function index. there's a function for each power of two size up // to 1024. int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; - additive_masked_softmax_dropout_forward_func - kernel; + additive_masked_softmax_dropout_forward_func kernel; int warp_size, batches_per_warp; - if (!warp_additive_masked_softmax_dropout_kernel( - softmax_elements, log2_elements, warp_size, batches_per_warp, - kernel)) { + if (!warp_additive_masked_softmax_dropout_kernel(softmax_elements, log2_elements, + warp_size, batches_per_warp, kernel)) { return false; } @@ -765,8 +699,7 @@ bool dispatch_additive_masked_softmax_dropout( int batches_per_block = warps_per_block * batches_per_warp; int blocks = (batch_count + batches_per_block - 1) / batches_per_block; c10::optional gen_; - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); + auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); int64_t counter_offset = (totalElements / (blocks * threads_per_block) + 1); at::PhiloxCudaState rng_engine_inputs; { @@ -778,9 +711,8 @@ bool dispatch_additive_masked_softmax_dropout( dim3 threads(warp_size, warps_per_block, 1); // launch - kernel<<>>( - dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, pad_batch_stride, rng_engine_inputs, p); + kernel<<>>(dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, + softmax_elements, pad_batch_stride, rng_engine_inputs, p); return true; } return false; @@ -790,11 +722,11 @@ bool dispatch_additive_masked_softmax_dropout( // WARP_ITERATOINS The number of iterations required for one warp to iterate // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -__global__ void additive_masked_softmax_warp_forward( - input_t *dst, const output_t *src, const input_t *pad_mask, int batch_size, - int stride, int element_count, int pad_batch_stride) { +template +__global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_t *src, const input_t *pad_mask, + int batch_size, int stride, int element_count, + int pad_batch_stride) { assert(ELEMENTS_PER_LDG_STG == 1); int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; @@ -802,8 +734,7 @@ __global__ void additive_masked_softmax_warp_forward( // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch @@ -817,8 +748,7 @@ __global__ void additive_masked_softmax_warp_forward( input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 
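// ------------------------------------------------------------------------------------------------------------------
// [Editorial sketch, not from this patch] The dropout fused into the softmax epilogue above draws Philox uniforms,
// keeps an element when rand <= p (p is the KEEP probability, as the comment notes), scales survivors by 1/p, and
// stores the keep mask for the backward masked-scale. A simplified standalone kernel using the public cuRAND Philox
// device API in place of the at::PhiloxCudaState plumbing (all names hypothetical; the real kernels seed the state
// once per thread and draw float4s):
#include <cstdint>
#include <curand_kernel.h>

__global__ void dropout_store_sketch(float* dst, uint8_t* mask, const float* src, int n, float p_keep,
                                     unsigned long long seed, unsigned long long offset) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/i, offset, &state);
  float r = curand_uniform(&state);  // uniform in (0, 1]
  uint8_t keep = r <= p_keep;
  mask[i] = keep;                    // consumed later by the masked-scale backward
  dst[i] = keep ? src[i] / p_keep : 0.0f;
}
// ------------------------------------------------------------------------------------------------------------------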
0 : element_count; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; + int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx; const half *curr_mask = pad_mask + pad_thread_offset; for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -831,8 +761,7 @@ __global__ void additive_masked_softmax_warp_forward( if (element_index < batch_element_count) { int itr_jmp = it * WARP_SIZE; int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - src + itr_idx); + copy_vector(&elements_input[i][it], src + itr_idx); // apply_mask(&elements_input[i][it], // (__half)-std::numeric_limits::infinity(), // curr_mask + itr_jmp); @@ -863,8 +792,7 @@ __global__ void additive_masked_softmax_warp_forward( #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } @@ -906,8 +834,7 @@ __global__ void additive_masked_softmax_warp_forward( // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -917,8 +844,7 @@ __global__ void additive_masked_softmax_warp_forward( for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; } @@ -931,14 +857,13 @@ __global__ void additive_masked_softmax_warp_forward( // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. template -using additive_masked_softmax_forward_func = void (*)( - input_t *dst, const output_t *src, const half *pad_mask, int batch_size, - int stride, int element_count, int pad_batch_stride); +using additive_masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src, const half *pad_mask, + int batch_size, int stride, int element_count, + int pad_batch_stride); template -bool warp_additive_masked_softmax_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - additive_masked_softmax_forward_func &kernel) { +bool warp_additive_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, + additive_masked_softmax_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; @@ -947,75 +872,60 @@ bool warp_additive_masked_softmax_kernel( batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; switch (log2_elements) { - case 0: // 1 - kernel = &additive_masked_softmax_warp_forward; - break; - case 1: // 2 - kernel = &additive_masked_softmax_warp_forward; - break; - case 2: // 4 - kernel = &additive_masked_softmax_warp_forward; - break; - case 3: // 8 - kernel = &additive_masked_softmax_warp_forward; - break; - case 4: // 16 - kernel = &additive_masked_softmax_warp_forward; - break; - case 5: // 32 - kernel = &additive_masked_softmax_warp_forward; - break; - case 6: // 64 - kernel = &additive_masked_softmax_warp_forward; - break; - case 7: // 128 - kernel = &additive_masked_softmax_warp_forward; - break; - case 8: // 256 - kernel = &additive_masked_softmax_warp_forward; - break; - case 9: // 512 - kernel = &additive_masked_softmax_warp_forward; - break; - case 10: // 1024 - kernel = &additive_masked_softmax_warp_forward; - break; - default: - return false; + case 0: // 1 + kernel = &additive_masked_softmax_warp_forward; + break; + case 1: // 2 + kernel = &additive_masked_softmax_warp_forward; + break; + case 2: // 4 + kernel = &additive_masked_softmax_warp_forward; + break; + case 3: // 8 + kernel = &additive_masked_softmax_warp_forward; + break; + case 4: // 16 + kernel = &additive_masked_softmax_warp_forward; + break; + case 5: // 32 + kernel = &additive_masked_softmax_warp_forward; + break; + case 6: // 64 + kernel = &additive_masked_softmax_warp_forward; + break; + case 7: // 128 + kernel = &additive_masked_softmax_warp_forward; + break; + case 8: // 256 + kernel = &additive_masked_softmax_warp_forward; + break; + case 9: // 512 + kernel = &additive_masked_softmax_warp_forward; + break; + case 10: // 1024 + kernel = &additive_masked_softmax_warp_forward; + break; + default: + return false; } return true; } template -bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, - const input_t *pad_mask, - int softmax_elements, - int softmax_elements_stride, - int batch_count, int pad_batch_stride) { +bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, + int softmax_elements_stride, int batch_count, int pad_batch_stride) { if (softmax_elements == 0) { return true; } else if (softmax_elements <= 1024) { // compute function index. there's a function for each power of two size up // to 1024. 
int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; additive_masked_softmax_forward_func kernel; int warp_size, batches_per_warp; - if (!warp_additive_masked_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { + if (!warp_additive_masked_softmax_kernel(log2_elements, warp_size, batches_per_warp, + kernel)) { return false; } @@ -1032,30 +942,27 @@ bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, // launch kernel<<>>( - dst, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, pad_batch_stride); + dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride); return true; } return false; } template -bool dispatch_additive_masked_softmax_stream( - output_t *dst, const input_t *src, const input_t *pad_mask, - int softmax_elements, int softmax_elements_stride, int batch_count, - int pad_batch_stride, cudaStream_t streamid) { +bool dispatch_additive_masked_softmax_stream(output_t *dst, const input_t *src, const input_t *pad_mask, + int softmax_elements, int softmax_elements_stride, int batch_count, + int pad_batch_stride, cudaStream_t streamid) { if (softmax_elements == 0) { return true; } else if (softmax_elements <= 1024) { // compute function index. there's a function for each power of two size up // to 1024. int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; additive_masked_softmax_forward_func kernel; int warp_size, batches_per_warp; - if (!warp_additive_masked_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { + if (!warp_additive_masked_softmax_kernel(log2_elements, warp_size, batches_per_warp, + kernel)) { return false; } // use 128 threads per block to maximize gpu utilization @@ -1067,9 +974,8 @@ bool dispatch_additive_masked_softmax_stream( int blocks = (batch_count + batches_per_block - 1) / batches_per_block; dim3 threads(warp_size, warps_per_block, 1); // launch - kernel<<>>( - dst, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, pad_batch_stride); + kernel<<>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, + pad_batch_stride); return true; } return false; @@ -1079,12 +985,10 @@ bool dispatch_additive_masked_softmax_stream( // WARP_ITERATOINS The number of iterations required for one warp to iterate // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -__global__ void -masked_softmax_warp_forward(input_t *dst, const output_t *src, - const uint8_t *pad_mask, int batch_size, int stride, - int element_count, int pad_batch_stride) { +template +__global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, + int stride, int element_count, int pad_batch_stride) { assert(ELEMENTS_PER_LDG_STG == 1); int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; @@ -1092,8 +996,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src, // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. 
int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch @@ -1107,26 +1010,21 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src, input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; + int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx; const uint8_t *curr_mask = pad_mask + pad_thread_offset; for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements_input[i][it + element] = - -std::numeric_limits::infinity(); + elements_input[i][it + element] = -std::numeric_limits::infinity(); } if (element_index < batch_element_count) { int itr_jmp = it * WARP_SIZE; int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - src + itr_idx); + copy_vector(&elements_input[i][it], src + itr_idx); apply_mask( - &elements_input[i][it], - __float2half(-std::numeric_limits::infinity()), - curr_mask + itr_jmp); + &elements_input[i][it], __float2half(-std::numeric_limits::infinity()), curr_mask + itr_jmp); } } } @@ -1153,8 +1051,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src, #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } @@ -1196,8 +1093,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src, // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -1207,8 +1103,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src, for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; } @@ -1221,16 +1116,12 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src, // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. 
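// ------------------------------------------------------------------------------------------------------------------
// [Editorial note, not from this patch] Two masking conventions coexist in this header: apply_additive_mask adds a
// half-precision bias (0 for visible positions, a large negative value or -inf for padded ones) to the logits,
// while apply_mask overwrites positions whose uint8 flag is set with -inf before the max/sum. Element-wise
// reference of the two styles (hypothetical names):
#include <cstdint>
#include <limits>

inline float additive_mask_ref(float logit, float additive_bias) { return logit + additive_bias; }

inline float byte_mask_ref(float logit, uint8_t is_masked) {
  return is_masked ? -std::numeric_limits<float>::infinity() : logit;
}
// ------------------------------------------------------------------------------------------------------------------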
template -using masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src, - const uint8_t *pad_mask, - int batch_size, int stride, - int element_count, - int pad_batch_stride); +using masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, + int stride, int element_count, int pad_batch_stride); template -bool warp_masked_softmax_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - masked_softmax_forward_func &kernel) { +bool warp_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, + masked_softmax_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; @@ -1239,70 +1130,59 @@ bool warp_masked_softmax_kernel( batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; switch (log2_elements) { - case 0: // 1 - kernel = &masked_softmax_warp_forward; - break; - case 1: // 2 - kernel = &masked_softmax_warp_forward; - break; - case 2: // 4 - kernel = &masked_softmax_warp_forward; - break; - case 3: // 8 - kernel = &masked_softmax_warp_forward; - break; - case 4: // 16 - kernel = - &masked_softmax_warp_forward; - break; - case 5: // 32 - kernel = - &masked_softmax_warp_forward; - break; - case 6: // 64 - kernel = - &masked_softmax_warp_forward; - break; - case 7: // 128 - kernel = - &masked_softmax_warp_forward; - break; - case 8: // 256 - kernel = - &masked_softmax_warp_forward; - break; - case 9: // 512 - kernel = - &masked_softmax_warp_forward; - break; - case 10: // 1024 - kernel = - &masked_softmax_warp_forward; - break; - default: - return false; + case 0: // 1 + kernel = &masked_softmax_warp_forward; + break; + case 1: // 2 + kernel = &masked_softmax_warp_forward; + break; + case 2: // 4 + kernel = &masked_softmax_warp_forward; + break; + case 3: // 8 + kernel = &masked_softmax_warp_forward; + break; + case 4: // 16 + kernel = &masked_softmax_warp_forward; + break; + case 5: // 32 + kernel = &masked_softmax_warp_forward; + break; + case 6: // 64 + kernel = &masked_softmax_warp_forward; + break; + case 7: // 128 + kernel = &masked_softmax_warp_forward; + break; + case 8: // 256 + kernel = &masked_softmax_warp_forward; + break; + case 9: // 512 + kernel = &masked_softmax_warp_forward; + break; + case 10: // 1024 + kernel = &masked_softmax_warp_forward; + break; + default: + return false; } return true; } template -bool dispatch_masked_softmax(output_t *dst, const input_t *src, - const uint8_t *pad_mask, int softmax_elements, - int softmax_elements_stride, int batch_count, - int pad_batch_stride) { +bool dispatch_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask, int softmax_elements, + int softmax_elements_stride, int batch_count, int pad_batch_stride) { if (softmax_elements == 0) { return true; } else if (softmax_elements <= 1024) { // compute function index. there's a function for each power of two size up // to 1024. 
int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; masked_softmax_forward_func kernel; int warp_size, batches_per_warp; - if (!warp_masked_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { + if (!warp_masked_softmax_kernel(log2_elements, warp_size, batches_per_warp, kernel)) { return false; } @@ -1319,8 +1199,7 @@ bool dispatch_masked_softmax(output_t *dst, const input_t *src, // launch kernel<<>>( - dst, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, pad_batch_stride); + dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride); return true; } return false; @@ -1330,11 +1209,10 @@ bool dispatch_masked_softmax(output_t *dst, const input_t *src, // WARP_ITERATOINS The number of iterations required for one warp to iterate // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -__global__ void time_masked_softmax_warp_forward( - input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, - int stride, int element_count, int mod_seq_len) { +template +__global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *src, const uint8_t *pad_mask, + int batch_size, int stride, int element_count, int mod_seq_len) { assert(ELEMENTS_PER_LDG_STG == 1); int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; @@ -1342,8 +1220,7 @@ __global__ void time_masked_softmax_warp_forward( // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch @@ -1357,26 +1234,21 @@ __global__ void time_masked_softmax_warp_forward( input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = ((first_batch + i) % mod_seq_len) * stride + - ELEMENTS_PER_LDG_STG * local_idx; + int pad_thread_offset = ((first_batch + i) % mod_seq_len) * stride + ELEMENTS_PER_LDG_STG * local_idx; const uint8_t *curr_mask = pad_mask + pad_thread_offset; for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements_input[i][it + element] = - -std::numeric_limits::infinity(); + elements_input[i][it + element] = -std::numeric_limits::infinity(); } if (element_index < batch_element_count) { int itr_jmp = it * WARP_SIZE; int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - src + itr_idx); + copy_vector(&elements_input[i][it], src + itr_idx); apply_mask( - &elements_input[i][it], - __float2half(-std::numeric_limits::infinity()), - curr_mask + itr_jmp); + &elements_input[i][it], __float2half(-std::numeric_limits::infinity()), curr_mask + itr_jmp); } } } @@ -1403,8 +1275,7 @@ __global__ void time_masked_softmax_warp_forward( #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } @@ -1446,8 +1317,7 @@ __global__ void time_masked_softmax_warp_forward( // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -1457,8 +1327,7 @@ __global__ void time_masked_softmax_warp_forward( for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; } @@ -1471,14 +1340,12 @@ __global__ void time_masked_softmax_warp_forward( // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. template -using time_masked_softmax_forward_func = - void (*)(input_t *dst, const output_t *src, const uint8_t *pad_mask, - int batch_size, int stride, int element_count, int mod_seq_len); +using time_masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src, const uint8_t *pad_mask, + int batch_size, int stride, int element_count, int mod_seq_len); template -bool warp_time_masked_softmax_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - time_masked_softmax_forward_func &kernel) { +bool warp_time_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, + time_masked_softmax_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; @@ -1487,74 +1354,60 @@ bool warp_time_masked_softmax_kernel( batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; switch (log2_elements) { - case 0: // 1 - kernel = - &time_masked_softmax_warp_forward; - break; - case 1: // 2 - kernel = - &time_masked_softmax_warp_forward; - break; - case 2: // 4 - kernel = - &time_masked_softmax_warp_forward; - break; - case 3: // 8 - kernel = - &time_masked_softmax_warp_forward; - break; - case 4: // 16 - kernel = &time_masked_softmax_warp_forward; - break; - case 5: // 32 - kernel = &time_masked_softmax_warp_forward; - break; - case 6: // 64 - kernel = &time_masked_softmax_warp_forward; - break; - case 7: // 128 - kernel = &time_masked_softmax_warp_forward; - break; - case 8: // 256 - kernel = &time_masked_softmax_warp_forward; - break; - case 9: // 512 - kernel = &time_masked_softmax_warp_forward; - break; - case 10: // 1024 - kernel = &time_masked_softmax_warp_forward; - break; - default: - return false; + case 0: // 1 + kernel = &time_masked_softmax_warp_forward; + break; + case 1: // 2 + kernel = &time_masked_softmax_warp_forward; + break; + case 2: // 4 + kernel = &time_masked_softmax_warp_forward; + break; + case 3: // 8 + kernel = &time_masked_softmax_warp_forward; + break; + case 4: // 16 + kernel = &time_masked_softmax_warp_forward; + break; + case 5: // 32 + kernel = &time_masked_softmax_warp_forward; + break; + case 6: // 64 + kernel = &time_masked_softmax_warp_forward; + break; + case 7: // 128 + kernel = &time_masked_softmax_warp_forward; + break; + case 8: // 256 + kernel = &time_masked_softmax_warp_forward; + break; + case 9: // 512 + kernel = &time_masked_softmax_warp_forward; + break; + case 10: // 1024 + kernel = &time_masked_softmax_warp_forward; + break; + default: + return false; } return true; } template -bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, - const uint8_t *pad_mask, int softmax_elements, - int softmax_elements_stride, int batch_count, - int mod_seq_len) { +bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask, int softmax_elements, + int softmax_elements_stride, int batch_count, int mod_seq_len) { if (softmax_elements == 0) { return true; } else if (softmax_elements <= 1024) { // compute function index. there's a function for each power of two size up // to 1024. 
int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; time_masked_softmax_forward_func kernel; int warp_size, batches_per_warp; - if (!warp_time_masked_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { + if (!warp_time_masked_softmax_kernel(log2_elements, warp_size, batches_per_warp, + kernel)) { return false; } @@ -1571,8 +1424,7 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, // launch kernel<<>>( - dst, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, mod_seq_len); + dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, mod_seq_len); return true; } return false; @@ -1580,15 +1432,13 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, int log2_ceil_native(int value) { int log2_value = 0; - while ((1 << log2_value) < value) - ++log2_value; + while ((1 << log2_value) < value) ++log2_value; return log2_value; } template -__device__ __forceinline__ T -WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, - unsigned int mask = 0xffffffff) { +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 return __shfl_xor_sync(mask, value, laneMask, width); #else @@ -1616,17 +1466,15 @@ __device__ __forceinline__ void warp_reduce_sum(acc_t *sum) { // softmax backward data function is taken from native pytorch, elementwise mul // is fused in the epolog, as well as masking and scaling for fusing dropout -template -__global__ void masked_scale_softmax_warp_backward_masked_dgrad( - output_t *gradInput, const input_t *grad, const input_t *output, - const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int batch_size, - int stride, int element_count, int heads) { +template +__global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t *gradInput, const input_t *grad, + const input_t *output, const uint8_t *mask, + const uint8_t *pad_mask, acc_t scale, int batch_size, + int stride, int element_count, int heads) { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; @@ -1635,8 +1483,7 @@ __global__ void masked_scale_softmax_warp_backward_masked_dgrad( // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. 
compute the index within the // batch @@ -1662,11 +1509,9 @@ __global__ void masked_scale_softmax_warp_backward_masked_dgrad( for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - grad_reg[i][it] = - (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] * - (acc_t)grad[i * element_count + it * WARP_SIZE] * - (acc_t)scale) * - output[i * element_count + it * WARP_SIZE]; + grad_reg[i][it] = (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] * + (acc_t)grad[i * element_count + it * WARP_SIZE] * (acc_t)scale) * + output[i * element_count + it * WARP_SIZE]; output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; } else { grad_reg[i][it] = acc_t(0); @@ -1689,8 +1534,7 @@ __global__ void masked_scale_softmax_warp_backward_masked_dgrad( // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; @@ -1698,32 +1542,26 @@ __global__ void masked_scale_softmax_warp_backward_masked_dgrad( // compute gradients int total_ind = thread_offset + i * element_count + it * WARP_SIZE; int pad_mask_ind = - element_count * - (total_ind / (heads * element_count * element_count)) + - total_ind % element_count; + element_count * (total_ind / (heads * element_count * element_count)) + total_ind % element_count; uint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind]; if (pad_mask_element == 0) gradInput[i * element_count + it * WARP_SIZE] = 0; else { if (is_log_softmax) { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); } else { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - output_reg[i][it] * sum[i]); + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]); } } } } } } -template -void dispatch_masked_scale_softmax_backward_masked_out( - output_t *grad_input, const input_t *grad, const input_t *output, - const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, - int softmax_elements, int softmax_elements_stride, int batch_count, - int heads) { +template +void dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, const input_t *grad, const input_t *output, + const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, + int softmax_elements, int softmax_elements_stride, + int batch_count, int heads) { TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024); if (softmax_elements == 0) { return; @@ -1733,8 +1571,7 @@ void dispatch_masked_scale_softmax_backward_masked_out( // This value must match the WARP_SIZE constexpr value computed inside // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. 
@@ -1749,96 +1586,84 @@ void dispatch_masked_scale_softmax_backward_masked_out( dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 1: // 2 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 2: // 4 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 3: // 8 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 4: // 16 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 5: // 32 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 6: // 64 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 7: // 128 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 8: // 256 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 9: // 512 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 10: // 1024 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - default: - break; + case 0: // 1 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 1: // 2 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 2: // 4 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 3: // 8 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 4: // 16 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 5: // 32 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + 
batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 6: // 64 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 7: // 128 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 8: // 256 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 9: // 512 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + case 10: // 1024 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, + batch_count, softmax_elements_stride, + softmax_elements, heads); + break; + default: + break; } } } -template -void dispatch_masked_scale_softmax_backward_masked_out_stream( - output_t *grad_input, const input_t *grad, const input_t *output, - const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, - int softmax_elements, int softmax_elements_stride, int batch_count, - int heads, cudaStream_t streamid) { +template +void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_input, const input_t *grad, + const input_t *output, const uint8_t *mask, + const uint8_t *pad_mask, acc_t scale, + int softmax_elements, int softmax_elements_stride, + int batch_count, int heads, cudaStream_t streamid) { TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024); if (softmax_elements == 0) { return; @@ -1847,8 +1672,7 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream( const int next_power_of_two = 1 << log2_elements; // This value must match the WARP_SIZE constexpr value computed inside // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; @@ -1860,101 +1684,75 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream( dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 1: // 2 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 2: // 4 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 3: // 8 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 4: // 16 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 5: // 32 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 6: // 64 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 7: // 128 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 8: // 256 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 9: // 512 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 10: // 1024 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - default: - break; + case 0: // 1 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 1: // 2 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 2: // 4 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 3: // 8 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 4: // 16 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 5: // 32 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, 
scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 6: // 64 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 7: // 128 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 8: // 256 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 9: // 512 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + case 10: // 1024 + masked_scale_softmax_warp_backward_masked_dgrad + <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, + softmax_elements_stride, softmax_elements, heads); + break; + default: + break; } } } -template -__global__ void -masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, - const input_t *output, const uint8_t *mask, - acc_t scale, int batch_size, int stride, - int element_count) { +template +__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, + const uint8_t *mask, acc_t scale, int batch_size, int stride, + int element_count) { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; @@ -1963,8 +1761,7 @@ masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. 
compute the index within the // batch @@ -1990,11 +1787,9 @@ masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - grad_reg[i][it] = - (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] * - (acc_t)grad[i * element_count + it * WARP_SIZE] * - (acc_t)scale) * - output[i * element_count + it * WARP_SIZE]; + grad_reg[i][it] = (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] * + (acc_t)grad[i * element_count + it * WARP_SIZE] * (acc_t)scale) * + output[i * element_count + it * WARP_SIZE]; output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; } else { grad_reg[i][it] = acc_t(0); @@ -2017,39 +1812,34 @@ masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients if (is_log_softmax) { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); } else { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - output_reg[i][it] * sum[i]); + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]); } } } } } -template -__global__ void masked_scale_softmax_warp_backward_recompute( - output_t *gradInput, const input_t *grad, const input_t *softmax_input, - const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, - int stride, int pad_batch_stride, int element_count) { +template +__global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput, const input_t *grad, + const input_t *softmax_input, const input_t *pad_mask, + const uint8_t *mask, acc_t scale, int batch_size, + int stride, int pad_batch_stride, int element_count) { int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch @@ -2075,8 +1865,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute( // load data from global memory for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 
0 : element_count; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; + int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx; const input_t *curr_mask = pad_mask + pad_thread_offset; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { @@ -2092,23 +1881,17 @@ __global__ void masked_scale_softmax_warp_backward_recompute( if (element_index < batch_element_count) { int itr_jmp = it * WARP_SIZE; int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - softmax_input + itr_idx); + copy_vector(&elements_input[i][it], softmax_input + itr_idx); apply_additive_mask( &elements_input[i][it], - curr_mask + - itr_jmp); //(__half)-std::numeric_limits::infinity() + curr_mask + itr_jmp); //(__half)-std::numeric_limits::infinity() uint8_t mask_temp[ELEMENTS_PER_LDG_STG]; input_t grad_temp[ELEMENTS_PER_LDG_STG]; - copy_vector(&mask_temp[0], - mask + itr_idx); - copy_vector(&grad_temp[0], - grad + itr_idx); + copy_vector(&mask_temp[0], mask + itr_idx); + copy_vector(&grad_temp[0], grad + itr_idx); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = - ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * - (acc_t)scale); + grad_reg[i][it + element] = ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale); } } } @@ -2140,8 +1923,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute( #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; } } @@ -2183,8 +1965,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute( // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it++) { elements[i][it] = elements[i][it] / sum[i]; @@ -2206,8 +1987,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute( // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -2217,34 +1997,28 @@ __global__ void masked_scale_softmax_warp_backward_recompute( #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; element++) { if (is_log_softmax) { - grad_input_reg[element] = - (grad_reg[i][it + element] - - std::exp(elements[i][it + element]) * grad_sum[i]); + grad_input_reg[element] = (grad_reg[i][it + element] - std::exp(elements[i][it + element]) * grad_sum[i]); } else { - grad_input_reg[element] = (grad_reg[i][it + element] - - elements[i][it + element] * grad_sum[i]); + grad_input_reg[element] = (grad_reg[i][it + element] - elements[i][it + element] * grad_sum[i]); } } - copy_vector( - gradInput + i * element_count + it * WARP_SIZE, grad_input_reg); + copy_vector(gradInput + i * element_count + it * WARP_SIZE, grad_input_reg); } } } } -template -using masked_scale_softmax_warp_backward_recompute_func = void (*)( - output_t *gradInput, const input_t *grad, const input_t *softmax_input, - const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, - int stride, int pad_batch_stride, int element_count); +template +using masked_scale_softmax_warp_backward_recompute_func = void (*)(output_t *gradInput, const input_t *grad, + const input_t *softmax_input, + const input_t *pad_mask, const uint8_t *mask, + acc_t scale, int batch_size, int stride, + int pad_batch_stride, int element_count); -template +template bool masked_scale_softmax_warp_backward_recompute_kernel( int element_count, int log2_elements, int &warp_size, int &batches_per_warp, - masked_scale_softmax_warp_backward_recompute_func &kernel) { + masked_scale_softmax_warp_backward_recompute_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; @@ -2253,101 +2027,78 @@ bool masked_scale_softmax_warp_backward_recompute_kernel( batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; bool flag_vec4 = (element_count % 4 == 0); switch (log2_elements) { - case 0: // 1 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 1, 1, is_log_softmax>; - break; - case 1: // 2 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 2, 1, is_log_softmax>; - break; - case 2: // 4 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 4, 1, is_log_softmax>; - break; - case 3: // 8 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 8, 1, is_log_softmax>; - break; - case 4: // 16 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 16, 1, is_log_softmax>; - break; - case 5: // 32 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 32, 1, is_log_softmax>; - break; - case 6: // 64 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 2, 32, 1, is_log_softmax>; - break; - case 7: // 128 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 4, 32, 1, is_log_softmax>; - break; - case 8: // 256 - if (flag_vec4) - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 8, 32, 4, is_log_softmax>; - else - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 8, 32, 1, is_log_softmax>; - break; - case 9: // 512 - if (flag_vec4) - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 16, 32, 4, is_log_softmax>; - else - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 16, 32, 1, is_log_softmax>; - break; - case 10: // 1024 - if (flag_vec4) - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 32, 32, 4, is_log_softmax>; - else - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 32, 32, 1, is_log_softmax>; - break; - case 11: // 2048 - if (flag_vec4) - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 64, 32, 4, is_log_softmax>; - else - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 64, 32, 1, is_log_softmax>; - break; - default: - return false; + case 0: // 1 + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 1: // 2 + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 2: // 4 + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 3: // 8 + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 4: // 16 + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 5: // 32 + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 6: // 64 + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 7: // 128 + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 8: // 256 + if (flag_vec4) + kernel = &masked_scale_softmax_warp_backward_recompute; + else + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 9: // 512 + if (flag_vec4) + kernel = &masked_scale_softmax_warp_backward_recompute; + else + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 10: // 1024 + if (flag_vec4) + kernel = &masked_scale_softmax_warp_backward_recompute; + else + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + case 11: // 2048 + if 
(flag_vec4) + kernel = &masked_scale_softmax_warp_backward_recompute; + else + kernel = &masked_scale_softmax_warp_backward_recompute; + break; + default: + return false; } return true; } -template -bool dispatch_masked_scale_softmax_backward_recompute( - output_t *grad_input, const input_t *grad, const input_t *softmax_input, - const input_t *pad_mask, const uint8_t *mask, acc_t scale, - int softmax_elements, int softmax_elements_stride, int pad_batch_stride, - int batch_count, cudaStream_t streamid) { - +template +bool dispatch_masked_scale_softmax_backward_recompute(output_t *grad_input, const input_t *grad, + const input_t *softmax_input, const input_t *pad_mask, + const uint8_t *mask, acc_t scale, int softmax_elements, + int softmax_elements_stride, int pad_batch_stride, + int batch_count, cudaStream_t streamid) { if (softmax_elements == 0) { return true; } else if (softmax_elements <= 2048) { // compute function index. there's a function for each power of two size up // to 1024. int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; - masked_scale_softmax_warp_backward_recompute_func - kernel; + masked_scale_softmax_warp_backward_recompute_func kernel; int warp_size, batches_per_warp; - if (!masked_scale_softmax_warp_backward_recompute_kernel< - input_t, output_t, acc_t, is_log_softmax>( - softmax_elements, log2_elements, warp_size, batches_per_warp, - kernel)) { + if (!masked_scale_softmax_warp_backward_recompute_kernel( + softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) { return false; } @@ -2362,20 +2113,18 @@ bool dispatch_masked_scale_softmax_backward_recompute( dim3 threads(warp_size, warps_per_block, 1); // launch - kernel<<>>( - grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, - softmax_elements_stride, pad_batch_stride, softmax_elements); + kernel<<>>(grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, + softmax_elements_stride, pad_batch_stride, softmax_elements); return true; } return false; } -template -void dispatch_masked_scale_softmax_backward_stream( - output_t *grad_input, const input_t *grad, const input_t *output, - const uint8_t *mask, acc_t scale, int softmax_elements, - int softmax_elements_stride, int batch_count, cudaStream_t streamid) { +template +void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, + const uint8_t *mask, acc_t scale, int softmax_elements, + int softmax_elements_stride, int batch_count, + cudaStream_t streamid) { TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024); if (softmax_elements == 0) { return; @@ -2384,8 +2133,7 @@ void dispatch_masked_scale_softmax_backward_stream( const int next_power_of_two = 1 << log2_elements; // This value must match the WARP_SIZE constexpr value computed inside // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; @@ -2397,85 +2145,63 @@ void dispatch_masked_scale_softmax_backward_stream( dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - default: - break; + case 0: // 1 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, 
batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + masked_scale_softmax_warp_backward + <<>>(grad_input, grad, output, mask, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; } } } @@ -2483,17 +2209,13 @@ void dispatch_masked_scale_softmax_backward_stream( // elementwise multiplication called in at::softmax_backward_data is fused // inside softmax dgrad kernel as a result of fusion, intermediate // multiplication result is stored in fp32 in registers, instead of fp16 -template -__global__ void -softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, - const input_t *output, int batch_size, - int stride, int element_count) { +template +__global__ void softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, const input_t *output, + int batch_size, int stride, int element_count) { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; @@ -2502,8 +2224,7 @@ softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. 
compute the index within the // batch @@ -2528,8 +2249,7 @@ softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE] * - output[i * element_count + it * WARP_SIZE]; + grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE] * output[i * element_count + it * WARP_SIZE]; output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; } else { grad_reg[i][it] = acc_t(0); @@ -2541,10 +2261,10 @@ softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, acc_t sum[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; //* output_reg[i][0]; + sum[i] = grad_reg[i][0]; //* output_reg[i][0]; #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; // * output_reg[i][it]; + sum[i] += grad_reg[i][it]; // * output_reg[i][it]; } } warp_reduce_sum(sum); @@ -2552,30 +2272,25 @@ softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients if (is_log_softmax) { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); } else { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - output_reg[i][it] * sum[i]); + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]); } } } } } -template -void dispatch_softmax_backward_fused_native( - output_t *grad_input, const input_t *grad, const input_t *output, - int softmax_elements, int softmax_elements_stride, int batch_count) { +template +void dispatch_softmax_backward_fused_native(output_t *grad_input, const input_t *grad, const input_t *output, + int softmax_elements, int softmax_elements_stride, int batch_count) { TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024); if (softmax_elements == 0) { return; @@ -2585,8 +2300,7 @@ void dispatch_softmax_backward_fused_native( // This value must match the WARP_SIZE constexpr value computed inside // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. 
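For context on what the fused backward kernels above compute per softmax row: with y the softmax output and dy the upstream gradient, the standard identities are dx_j = y_j * (dy_j - sum_k dy_k * y_k) for softmax and dx_j = dy_j - exp(y_j) * sum_k dy_k for log-softmax. A scalar reference for one row follows; it is illustrative only, and its name and signature are not part of the patch.

// Scalar reference for one softmax row, useful when checking the fused
// warp kernels. Illustrative only.
#include <cmath>
#include <vector>

void softmax_backward_row_ref(const std::vector<float> &y,   // softmax output
                              const std::vector<float> &dy,  // upstream gradient
                              std::vector<float> &dx, bool is_log_softmax) {
  float sum = 0.f;
  for (size_t j = 0; j < y.size(); ++j)
    sum += is_log_softmax ? dy[j] : dy[j] * y[j];
  dx.resize(y.size());
  for (size_t j = 0; j < y.size(); ++j)
    dx[j] = is_log_softmax ? dy[j] - std::exp(y[j]) * sum
                           : y[j] * (dy[j] - sum);
}

In the non-log path of softmax_warp_backward_fused_native, grad_reg already holds dy * y (the masked_scale variants additionally fold in the dropout mask and scale), so the row sum reduced across the warp is exactly sum_k dy_k * y_k.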
@@ -2601,85 +2315,63 @@ void dispatch_softmax_backward_fused_native( dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 1: // 2 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 2: // 4 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 3: // 8 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 4: // 16 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 5: // 32 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 6: // 64 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 7: // 128 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 8: // 256 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 9: // 512 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 10: // 1024 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - default: - break; + case 0: // 1 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + softmax_warp_backward_fused_native + <<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + softmax_warp_backward_fused_native + 
<<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; } } } @@ -2688,18 +2380,16 @@ void dispatch_softmax_backward_fused_native( // Warp softmax backward //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -template -__global__ void softmax_warp_backward(__half *gradInput, const __half *grad, - const __half *output, int batch_size, +template +__global__ void softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output, int batch_size, int stride, int element_count) { int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch @@ -2721,11 +2411,9 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - copy_vector( - &grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE); + copy_vector(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE); copy_vector(&output_reg_input[i][it], - output + i * element_count + - it * WARP_SIZE); + output + i * element_count + it * WARP_SIZE); } } } @@ -2763,8 +2451,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; @@ -2772,12 +2459,10 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, // compute gradients output_t out[ELEMENTS_PER_LDG_STG]; for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_reg[i][it + element] * - (grad_reg[i][it + element] - sum[i])); + out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i])); } // store them in global memory - copy_vector( - gradInput + i * element_count + it * WARP_SIZE, out); + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); } } } @@ -2788,14 +2473,12 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. template -using softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad, - const input_t *output, int batch_size, +using softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count); template -bool warp_softmax_backward_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - softmax_backward_func &kernel) { +bool warp_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp, + softmax_backward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; warp_size = (next_power_of_two < 32) ? 
next_power_of_two : 32; @@ -2804,48 +2487,47 @@ bool warp_softmax_backward_kernel( batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; switch (log2_elements) { - case 0: // 1 - kernel = &softmax_warp_backward; - break; - case 1: // 2 - kernel = &softmax_warp_backward; - break; - case 2: // 4 - kernel = &softmax_warp_backward; - break; - case 3: // 8 - kernel = &softmax_warp_backward; - break; - case 4: // 16 - kernel = &softmax_warp_backward; - break; - case 5: // 32 - kernel = &softmax_warp_backward; - break; - case 6: // 64 - kernel = &softmax_warp_backward; - break; - case 7: // 128 - kernel = &softmax_warp_backward; - break; - case 8: // 256 - kernel = &softmax_warp_backward; - break; - case 9: // 512 - kernel = &softmax_warp_backward; - break; - case 10: // 1024 - kernel = &softmax_warp_backward; - break; - default: - return false; + case 0: // 1 + kernel = &softmax_warp_backward; + break; + case 1: // 2 + kernel = &softmax_warp_backward; + break; + case 2: // 4 + kernel = &softmax_warp_backward; + break; + case 3: // 8 + kernel = &softmax_warp_backward; + break; + case 4: // 16 + kernel = &softmax_warp_backward; + break; + case 5: // 32 + kernel = &softmax_warp_backward; + break; + case 6: // 64 + kernel = &softmax_warp_backward; + break; + case 7: // 128 + kernel = &softmax_warp_backward; + break; + case 8: // 256 + kernel = &softmax_warp_backward; + break; + case 9: // 512 + kernel = &softmax_warp_backward; + break; + case 10: // 1024 + kernel = &softmax_warp_backward; + break; + default: + return false; } return true; } template -bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, - const input_t *output, int softmax_elements, +bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count) { if (softmax_elements == 0) { return true; @@ -2853,13 +2535,11 @@ bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, // compute function index. there's a function for each power of two size up // to 1024. int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; softmax_backward_func kernel; int warp_size, batches_per_warp; - if (!warp_softmax_backward_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { + if (!warp_softmax_backward_kernel(log2_elements, warp_size, batches_per_warp, kernel)) { return false; } @@ -2875,32 +2555,27 @@ bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, dim3 threads(warp_size, warps_per_block, 1); // launch - kernel<<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); + kernel<<>>(grad_input, grad, output, batch_count, + softmax_elements_stride, softmax_elements); return true; } return false; } template -bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, - const input_t *output, - int softmax_elements, - int softmax_elements_stride, - int batch_count, cudaStream_t streamid) { +bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, + int softmax_elements, int softmax_elements_stride, int batch_count, + cudaStream_t streamid) { if (softmax_elements == 0) { return true; } else if (softmax_elements <= 1024) { // compute function index. there's a function for each power of two size up // to 1024. 
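The per-row sums in the kernels selected here are reduced across lanes with warp_reduce_sum, built on the WARP_SHFL_XOR_NATIVE wrapper defined earlier. The pattern is the usual XOR butterfly; the sketch below is illustrative (the in-tree version operates in place on a small per-warp array of partial sums rather than returning a scalar).

// XOR-butterfly warp reduction, shown for a single value. Illustrative only.
template <typename acc_t, int WARP_SIZE>
__device__ __forceinline__ acc_t warp_sum(acc_t v) {
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    v += __shfl_xor_sync(0xffffffffu, v, offset, WARP_SIZE);
  return v;  // every participating lane ends up holding the full row sum
}

The dispatcher then proceeds as in the forward path: compute log2_elements, pick the matching instantiation, and launch with the 128-thread block shape shown above.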
int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; softmax_backward_func kernel; int warp_size, batches_per_warp; - if (!warp_softmax_backward_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { + if (!warp_softmax_backward_kernel(log2_elements, warp_size, batches_per_warp, kernel)) { return false; } // use 128 threads per block to maximize gpu utilization @@ -2912,28 +2587,24 @@ bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, int blocks = (batch_count + batches_per_block - 1) / batches_per_block; dim3 threads(warp_size, warps_per_block, 1); // launch - kernel<<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); + kernel<<>>(grad_input, grad, output, batch_count, softmax_elements_stride, + softmax_elements); return true; } return false; } -template -__global__ void -masked_softmax_warp_backward(__half *gradInput, const __half *grad, - const __half *output, const uint8_t *pad_mask, - int batch_size, int stride, int element_count, - int pad_batch_stride) { +template +__global__ void masked_softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output, + const uint8_t *pad_mask, int batch_size, int stride, int element_count, + int pad_batch_stride) { int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. 
compute the index within the // batch @@ -2955,11 +2626,9 @@ masked_softmax_warp_backward(__half *gradInput, const __half *grad, for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - copy_vector( - &grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE); + copy_vector(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE); copy_vector(&output_reg_input[i][it], - output + i * element_count + - it * WARP_SIZE); + output + i * element_count + it * WARP_SIZE); } } } @@ -2997,10 +2666,8 @@ masked_softmax_warp_backward(__half *gradInput, const __half *grad, // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; + if (i >= local_batches) break; + int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx; const uint8_t *curr_mask = pad_mask + pad_thread_offset; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { @@ -3009,16 +2676,14 @@ masked_softmax_warp_backward(__half *gradInput, const __half *grad, // compute gradients output_t out[ELEMENTS_PER_LDG_STG]; for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_reg[i][it + element] * - (grad_reg[i][it + element] - sum[i])); + out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i])); } // store them in global memory int itr_jmp = it * WARP_SIZE; int itr_idx = i * element_count + itr_jmp; // It is kind of unfortunate this has to be here to zero something out // that is close to zero in the first place - apply_mask(&out[0], 0.0, - curr_mask + itr_jmp); + apply_mask(&out[0], 0.0, curr_mask + itr_jmp); copy_vector(gradInput + itr_idx, out); } } @@ -3030,15 +2695,13 @@ masked_softmax_warp_backward(__half *gradInput, const __half *grad, // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. template -using masked_softmax_backward_func = - void (*)(output_t *gradInput, const input_t *grad, const input_t *output, - const uint8_t *pad_mask, int batch_size, int stride, - int element_count, int pad_batch_stride); +using masked_softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad, const input_t *output, + const uint8_t *pad_mask, int batch_size, int stride, int element_count, + int pad_batch_stride); template -bool warp_masked_softmax_backward_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - masked_softmax_backward_func &kernel) { +bool warp_masked_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp, + masked_softmax_backward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; @@ -3047,62 +2710,48 @@ bool warp_masked_softmax_backward_kernel( batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; switch (log2_elements) { - case 0: // 1 - kernel = - &masked_softmax_warp_backward; - break; - case 1: // 2 - kernel = - &masked_softmax_warp_backward; - break; - case 2: // 4 - kernel = - &masked_softmax_warp_backward; - break; - case 3: // 8 - kernel = - &masked_softmax_warp_backward; - break; - case 4: // 16 - kernel = - &masked_softmax_warp_backward; - break; - case 5: // 32 - kernel = - &masked_softmax_warp_backward; - break; - case 6: // 64 - kernel = - &masked_softmax_warp_backward; - break; - case 7: // 128 - kernel = - &masked_softmax_warp_backward; - break; - case 8: // 256 - kernel = - &masked_softmax_warp_backward; - break; - case 9: // 512 - kernel = - &masked_softmax_warp_backward; - break; - case 10: // 1024 - kernel = - &masked_softmax_warp_backward; - break; - default: - return false; + case 0: // 1 + kernel = &masked_softmax_warp_backward; + break; + case 1: // 2 + kernel = &masked_softmax_warp_backward; + break; + case 2: // 4 + kernel = &masked_softmax_warp_backward; + break; + case 3: // 8 + kernel = &masked_softmax_warp_backward; + break; + case 4: // 16 + kernel = &masked_softmax_warp_backward; + break; + case 5: // 32 + kernel = &masked_softmax_warp_backward; + break; + case 6: // 64 + kernel = &masked_softmax_warp_backward; + break; + case 7: // 128 + kernel = &masked_softmax_warp_backward; + break; + case 8: // 256 + kernel = &masked_softmax_warp_backward; + break; + case 9: // 512 + kernel = &masked_softmax_warp_backward; + break; + case 10: // 1024 + kernel = &masked_softmax_warp_backward; + break; + default: + return false; } return true; } template -bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, - const input_t *output, - const uint8_t *pad_mask, - int softmax_elements, - int softmax_elements_stride, +bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, + const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride) { if (softmax_elements == 0) { return true; @@ -3110,13 +2759,12 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, // compute function index. there's a function for each power of two size up // to 1024. 
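One detail worth noting for the masked backward path selected here: each softmax row does not carry its own padding-mask row. The kernels index pad_mask as ((first_batch + i) / pad_batch_stride) * stride + element_index, so pad_batch_stride is the number of consecutive softmax rows that share a single mask row of length stride. A host-side restatement of that indexing, for clarity only (the helper is not part of the patch):

// Offset of the padding-mask element consulted for (row, col), mirroring the
// indexing inside masked_softmax_warp_backward. Illustrative only.
inline int pad_mask_offset(int row, int pad_batch_stride, int stride, int col) {
  return (row / pad_batch_stride) * stride + col;
}

In a typical self-attention call one would expect pad_batch_stride to cover all heads and query positions of one sequence, so they share a single [seq_len] key-padding row, but that is an assumption about callers, not something this patch changes. With that in mind, the dispatch below computes the usual power-of-two function index.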
int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; + while ((1 << log2_elements) < softmax_elements) ++log2_elements; masked_softmax_backward_func kernel; int warp_size, batches_per_warp; - if (!warp_masked_softmax_backward_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { + if (!warp_masked_softmax_backward_kernel(log2_elements, warp_size, batches_per_warp, + kernel)) { return false; } @@ -3133,10 +2781,9 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, // launch kernel<<>>( - grad_input, grad, output, pad_mask, batch_count, - softmax_elements_stride, softmax_elements, pad_batch_stride); + grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride); return true; } return false; } -} // namespace +} // namespace diff --git a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh index b207b17fb..8db945945 100644 --- a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh +++ b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh @@ -1,22 +1,21 @@ #pragma once -#include -#include - #include #include #include #include -//#include +#include +#include + +// #include #include #include - #include -#include +#include #include +#include #include #include -#include #include namespace { @@ -33,11 +32,9 @@ cublasOperation_t convertTransToCublasOperation(char trans) { } } -void CublasStridedBatchedGemm( - char transa, char transb, long m, long n, long k, - float alpha, const half *a, long lda, long strideA, const half *b, long ldb, - long strideB, float beta, half *c, long ldc, long strideC, long batchCount, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) { +void CublasStridedBatchedGemm(char transa, char transb, long m, long n, long k, float alpha, const half *a, long lda, + long strideA, const half *b, long ldb, long strideB, float beta, half *c, long ldc, + long strideC, long batchCount, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) { cublasOperation_t opa = convertTransToCublasOperation(transa); cublasOperation_t opb = convertTransToCublasOperation(transb); @@ -47,466 +44,298 @@ void CublasStridedBatchedGemm( float fAlpha = alpha; float fBeta = beta; TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx( - handle, opa, opb, (int)m, (int)n, (int)k, (void *)&fAlpha, a, CUDA_R_16F, - (int)lda, strideA, b, CUDA_R_16F, (int)ldb, strideB, (void *)&fBeta, c, - CUDA_R_16F, (int)ldc, strideC, (int)batchCount, CUDA_R_32F, algo)); + handle, opa, opb, (int)m, (int)n, (int)k, (void *)&fAlpha, a, CUDA_R_16F, (int)lda, strideA, b, CUDA_R_16F, + (int)ldb, strideB, (void *)&fBeta, c, CUDA_R_16F, (int)ldc, strideC, (int)batchCount, CUDA_R_32F, algo)); } -} // namespace +} // namespace // TODO(mkozuki): Make use of the int template parameters or discard them. 
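// A rough sketch of the alignment test that CutlassGemm_FP32Accum's caller,
// gemm_switch_fp32accum (further below), repeats for lda, ldb and ldc. The helper
// here only illustrates the idiom and is not a function defined in this file:
// leading dimensions are counted in half elements, so !(ld & 0x7) means the
// dimension is a multiple of 8 halves (16 bytes), !(ld & 0x3) a multiple of 4,
// and !(ld & 0x1) a multiple of 2.
inline int leading_dim_alignment_in_halves(long ld) {
  if (!(ld & 0x7)) return 8;  // 16-byte aligned
  if (!(ld & 0x3)) return 4;  // 8-byte aligned
  if (!(ld & 0x1)) return 2;  // 4-byte aligned
  return 1;                   // odd leading dimension
}
// When all three leading dimensions are multiples of 8, the switch prefers the
// cuBLAS tensor-op path; mixed even alignments select a CUTLASS GemmBatched
// instantiation; an odd leading dimension falls back to plain
// CublasStridedBatchedGemm.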
template -void CutlassGemm_FP32Accum( - cudaStream_t stream, - long m, long n, long k, float alpha, - const half* a, long lda, long long int batch_stride_A, - const half* b, long ldb, long long int batch_stride_B, - float beta, - half* c, long ldc, long long int batch_stride_C, long batch_count -) { +void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k, float alpha, const half *a, long lda, + long long int batch_stride_A, const half *b, long ldb, long long int batch_stride_B, + float beta, half *c, long ldc, long long int batch_stride_C, long batch_count) { using Gemm = cutlass::gemm::device::GemmBatched< - /* Element type of A matrix */half, /* Layout of A matrix */LayoutA, - /* Element type of B matrix */half, /* Layout of B matrix */LayoutB, - /* Element type of C matrix */half, /* Layout of C matrix */cutlass::layout::ColumnMajor, - /* Element Accumulator*/float - >; + /* Element type of A matrix */ half, /* Layout of A matrix */ LayoutA, + /* Element type of B matrix */ half, /* Layout of B matrix */ LayoutB, + /* Element type of C matrix */ half, /* Layout of C matrix */ cutlass::layout::ColumnMajor, + /* Element Accumulator*/ float>; Gemm gemm_op; - cutlass::Status status = gemm_op({ - {static_cast(m), static_cast(n), static_cast(k)}, - {a, lda}, batch_stride_A, - {b, ldb}, batch_stride_B, - {c, ldc}, batch_stride_C, - {c, ldc}, batch_stride_C, - {alpha, beta}, static_cast(batch_count) - }, nullptr, stream); + cutlass::Status status = gemm_op({{static_cast(m), static_cast(n), static_cast(k)}, + {a, lda}, + batch_stride_A, + {b, ldb}, + batch_stride_B, + {c, ldc}, + batch_stride_C, + {c, ldc}, + batch_stride_C, + {alpha, beta}, + static_cast(batch_count)}, + nullptr, stream); C10_CUDA_CHECK(status != cutlass::Status::kSuccess ? cudaErrorUnknown : cudaSuccess); } namespace { -void gemm_switch_fp32accum(char transa, char transb, long m, - long n, long k, float alpha, const half *a, long lda, - long strideA, const half *b, long ldb, long strideB, - float beta, half *c, long ldc, long strideC, - long batchCount) { +void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, float alpha, const half *a, long lda, + long strideA, const half *b, long ldb, long strideB, float beta, half *c, long ldc, + long strideC, long batchCount) { auto stream = c10::cuda::getCurrentCUDAStream(); // printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == // 't' ? 'T' : 'N'), (transb =='t' ? 
'T' : 'N'), m, n, k, alpha, beta); if ((transa == 't') && (transb == 'n')) { if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, + CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); - } - else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 
0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, 
c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, + CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } } else if ((transa == 'n') && (transb == 'n')) { if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, + CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); - } - else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, 
n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if 
(!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, + CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } } else if ((transa == 'n') && (transb == 't')) { if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, + CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); - } - else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + } else if 
(!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 
0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, 
c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); + CutlassGemm_FP32Accum( + stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, + CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } } else { @@ -514,52 +343,46 @@ void gemm_switch_fp32accum(char transa, char transb, long m, } } -void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, - int64_t *lda, int64_t *ldb, int64_t *ldc) { +void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, + int64_t *ldc) { int transa_ = ((transa == 't') || (transa == 'T')); int transb_ = ((transb == 't') || (transb == 'T')); // Note: leading dimensions generally are checked that they are > 0 and at // least as big the result requires (even if the value won't be used). - if (n <= 1) - *ldc = std::max(m, 1); + if (n <= 1) *ldc = std::max(m, 1); if (transa_) { - if (m <= 1) - *lda = std::max(k, 1); + if (m <= 1) *lda = std::max(k, 1); } else { - if (k <= 1) - *lda = std::max(m, 1); + if (k <= 1) *lda = std::max(m, 1); } if (transb_) { - if (k <= 1) - *ldb = std::max(n, 1); + if (k <= 1) *ldb = std::max(n, 1); } else { - if (n <= 1) - *ldb = std::max(k, 1); + if (n <= 1) *ldb = std::max(k, 1); } } -void HgemmStridedBatched(char transa, char transb, long m, - long n, long k, float alpha, const half *a, long lda, - long strideA, const half *b, long ldb, long strideB, - float beta, half *c, long ldc, long strideC, - long batchCount) { - if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || - (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX)) +void HgemmStridedBatched(char transa, char transb, long m, long n, long k, float alpha, const half *a, long lda, + long strideA, const half *b, long ldb, long strideB, float beta, half *c, long ldc, + long strideC, long batchCount) { + if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || + (batchCount >= INT_MAX)) { - TORCH_CHECK(false, "Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, " - "batchCount" - "with the bound [val] <= %d", - INT_MAX); + TORCH_CHECK(false, + "Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, " + "batchCount" + "with the bound [val] <= %d", + INT_MAX); } adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, - b, ldb, strideB, beta, c, ldc, strideC, batchCount); + gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, + batchCount); } -} // namespace +} // namespace diff --git a/apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp b/apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp index 2cc1da921..3f4b482f7 100644 --- a/apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp +++ b/apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp @@ -1,19 +1,17 @@ #include #include +#include #include #include -#include - -#define NCCL_CHECK(cmd) \ - do { \ - ncclResult_t result = cmd; \ - if (result != ncclSuccess) { \ - std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ 
- std::to_string(__LINE__) + ", " + \ - std::string(ncclGetErrorString(result)); \ - TORCH_CHECK(false, err); \ - } \ +#define NCCL_CHECK(cmd) \ + do { \ + ncclResult_t result = cmd; \ + if (result != ncclSuccess) { \ + std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + std::to_string(__LINE__) + ", " + \ + std::string(ncclGetErrorString(result)); \ + TORCH_CHECK(false, err); \ + } \ } while (0) void *nccl_alloc_plug(size_t size, int device, void *stream) { @@ -22,22 +20,18 @@ void *nccl_alloc_plug(size_t size, int device, void *stream) { return ptr; } -void nccl_free_plug(void *ptr, std::size_t size, int device, void *stream) { - NCCL_CHECK(ncclMemFree(ptr)); -} +void nccl_free_plug(void *ptr, std::size_t size, int device, void *stream) { NCCL_CHECK(ncclMemFree(ptr)); } std::shared_ptr nccl_allocator; void maybe_init() { if (!nccl_allocator) { - nccl_allocator = std::make_shared< - torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator>( - nccl_alloc_plug, nccl_free_plug); + nccl_allocator = + std::make_shared(nccl_alloc_plug, nccl_free_plug); } } -std::shared_ptr -get_nccl_allocator() { +std::shared_ptr get_nccl_allocator() { maybe_init(); return nccl_allocator; } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp index 6bf789c50..3db57c4e9 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp @@ -17,9 +17,13 @@ #include "nccl_p2p_cuda.cuh" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id", py::call_guard()); - m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm", py::call_guard()); - m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace", py::call_guard()); - m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange", py::call_guard()); + m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id", + py::call_guard()); + m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm", + py::call_guard()); + m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, + "left_right_halo_exchange_inplace", py::call_guard()); + m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange", + py::call_guard()); m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay", py::call_guard()); } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu index c386dcfb7..21091bd1d 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu @@ -1,10 +1,12 @@ -#include -#include #include -#include +#include +#include + +#include #include #include -#include +#include + #include "nccl.h" /* @@ -15,197 +17,193 @@ namespace { __global__ void AddDelay_kernel(const int delay, int* counter) { - if (blockIdx.x == 0 && threadIdx.x == 0) { - // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code. 
- int new_counter = 0; - double elapsed = 0; - clock_t start = clock(); - do { - clock_t now = clock(); - elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC; - ++new_counter; - } while (elapsed < (double)delay); - *counter = new_counter; - } + if (blockIdx.x == 0 && threadIdx.x == 0) { + // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code. + int new_counter = 0; + double elapsed = 0; + clock_t start = clock(); + do { + clock_t now = clock(); + elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC; + ++new_counter; + } while (elapsed < (double)delay); + *counter = new_counter; + } } -class NcclCommWrapper -{ - private: - ncclComm_t comm; - int rank, world_size; - - ncclDataType_t get_nccl_type(at::Tensor input) - { - switch (input.scalar_type()) - { - case at::ScalarType::Half: - return ncclFloat16; - case at::ScalarType::Float: - return ncclFloat32; - case at::ScalarType::Double: - return ncclFloat64; - case at::ScalarType::Byte: - return ncclUint8; - case at::ScalarType::Char: - return ncclInt8; - case at::ScalarType::Int: - return ncclInt32; - case at::ScalarType::Long: - return ncclInt64; - case at::ScalarType::BFloat16: - return ncclBfloat16; - default: - assert(false); - } - } - - public: - NcclCommWrapper() - { - memset(&comm, 0, sizeof(ncclComm_t)); - rank = 0; - world_size = 0; - } - NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks) - { - ncclCommInitRank(&comm, num_ranks, id, my_rank); - rank = my_rank; - world_size = num_ranks; - } - - ~NcclCommWrapper() - { - printf("ncclCommDestroy()\n"); - ncclCommDestroy(comm); - } - - void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo) - { - auto stream = at::cuda::getCurrentCUDAStream(); - ncclGroupStart(); - ncclDataType_t ncclType = get_nccl_type(left_output_halo); - bool left_zero = (left_rank < 0); - bool right_zero = (right_rank < 0); - size_t left_n = torch::numel(left_output_halo); - size_t right_n = torch::numel(right_output_halo); - assert(left_n > 0 && left_n == right_n); - if (left_zero) { - left_input_halo.zero_(); - } else { - AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() { - // send left (to my_rank - 1) - ncclSend(left_output_halo.data_ptr(), left_n, ncclType, left_rank, comm, stream); - // receive left (from my_rank - 1) - ncclRecv(left_input_halo.data_ptr(), right_n, ncclType, left_rank, comm, stream); - }); - } - if (right_zero) { - right_input_halo.zero_(); - } else { - AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() { - // send right (to my_rank + 1 ) - ncclSend(right_output_halo.data_ptr(), right_n, ncclType, right_rank, comm, stream); - // receive right (from my_rank + 1) - ncclRecv(right_input_halo.data_ptr(), left_n, ncclType, right_rank, comm, stream); - }); - } - ncclGroupEnd(); - } - - std::vector left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo) - { - // after halo exchange: - // left_output_halo of rank+1 ends up in right_input_halo of rank - // right_output_halo of rank-1 ends up in left_input_halo of rank - auto right_input_halo = torch::empty_like(left_output_halo); - auto left_input_halo = torch::empty_like(right_output_halo); - 
left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo); - return {left_input_halo, right_input_halo}; - } +class NcclCommWrapper { + private: + ncclComm_t comm; + int rank, world_size; + + ncclDataType_t get_nccl_type(at::Tensor input) { + switch (input.scalar_type()) { + case at::ScalarType::Half: + return ncclFloat16; + case at::ScalarType::Float: + return ncclFloat32; + case at::ScalarType::Double: + return ncclFloat64; + case at::ScalarType::Byte: + return ncclUint8; + case at::ScalarType::Char: + return ncclInt8; + case at::ScalarType::Int: + return ncclInt32; + case at::ScalarType::Long: + return ncclInt64; + case at::ScalarType::BFloat16: + return ncclBfloat16; + default: + assert(false); + } + } + + public: + NcclCommWrapper() { + memset(&comm, 0, sizeof(ncclComm_t)); + rank = 0; + world_size = 0; + } + NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks) { + ncclCommInitRank(&comm, num_ranks, id, my_rank); + rank = my_rank; + world_size = num_ranks; + } + + ~NcclCommWrapper() { + printf("ncclCommDestroy()\n"); + ncclCommDestroy(comm); + } + + void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, + at::Tensor right_output_halo, at::Tensor left_input_halo, + at::Tensor right_input_halo) { + auto stream = at::cuda::getCurrentCUDAStream(); + ncclGroupStart(); + ncclDataType_t ncclType = get_nccl_type(left_output_halo); + bool left_zero = (left_rank < 0); + bool right_zero = (right_rank < 0); + size_t left_n = torch::numel(left_output_halo); + size_t right_n = torch::numel(right_output_halo); + assert(left_n > 0 && left_n == right_n); + if (left_zero) { + left_input_halo.zero_(); + } else { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), + "left_halo_exch", [&]() { + // send left (to my_rank - 1) + ncclSend(left_output_halo.data_ptr(), left_n, ncclType, left_rank, comm, stream); + // receive left (from my_rank - 1) + ncclRecv(left_input_halo.data_ptr(), right_n, ncclType, left_rank, comm, stream); + }); + } + if (right_zero) { + right_input_halo.zero_(); + } else { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), + "right_halo_exch", [&]() { + // send right (to my_rank + 1 ) + ncclSend(right_output_halo.data_ptr(), right_n, ncclType, right_rank, comm, stream); + // receive right (from my_rank + 1) + ncclRecv(right_input_halo.data_ptr(), left_n, ncclType, right_rank, comm, stream); + }); + } + ncclGroupEnd(); + } + + std::vector left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, + at::Tensor right_output_halo) { + // after halo exchange: + // left_output_halo of rank+1 ends up in right_input_halo of rank + // right_output_halo of rank-1 ends up in left_input_halo of rank + auto right_input_halo = torch::empty_like(left_output_halo); + auto left_input_halo = torch::empty_like(right_output_halo); + left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, + right_input_halo); + return {left_input_halo, right_input_halo}; + } }; -class ManagedObjects -{ - public: - ManagedObjects() - { - } - ~ManagedObjects() - { - for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it) - { - delete *it; - } - } - - int add_comm(NcclCommWrapper* comm) - { - int handle = _nccl_comms.size(); - _nccl_comms.push_back(comm); 
- return handle; - } - - NcclCommWrapper& get_comm(int handle) - { - assert(handle >= 0 && handle < _nccl_comms.size()); - return *_nccl_comms[handle]; - } - - private: - std::vector _nccl_comms; +class ManagedObjects { + public: + ManagedObjects() {} + ~ManagedObjects() { + for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it) { + delete *it; + } + } + + int add_comm(NcclCommWrapper* comm) { + int handle = _nccl_comms.size(); + _nccl_comms.push_back(comm); + return handle; + } + + NcclCommWrapper& get_comm(int handle) { + assert(handle >= 0 && handle < _nccl_comms.size()); + return *_nccl_comms[handle]; + } + + private: + std::vector _nccl_comms; }; class ManagedObjects mo; -} // end anonymous namespace +} // end anonymous namespace -namespace apex { namespace contrib { namespace nccl_p2p { +namespace apex { +namespace contrib { +namespace nccl_p2p { -at::Tensor get_unique_nccl_id(int n) -{ +at::Tensor get_unique_nccl_id(int n) { + ncclUniqueId id; + ncclGetUniqueId(&id); + auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)}, + torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false)); + auto id_ptr = id_tensor.data_ptr(); + size_t offset = 0; + for (int i = 0; i < n; ++i) { ncclUniqueId id; ncclGetUniqueId(&id); - auto id_tensor = torch::empty({n,(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false)); - auto id_ptr = id_tensor.data_ptr(); - size_t offset = 0; - for (int i = 0; i < n; ++i) - { - ncclUniqueId id; - ncclGetUniqueId(&id); - memcpy(id_ptr+offset, &id, sizeof(ncclUniqueId)); - offset += sizeof(ncclUniqueId); - } - return id_tensor; + memcpy(id_ptr + offset, &id, sizeof(ncclUniqueId)); + offset += sizeof(ncclUniqueId); + } + return id_tensor; } -int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks) -{ - ncclUniqueId id; - auto unique_nccl_id_ptr = unique_nccl_id.data_ptr(); - memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId)); - NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks); - int handle = mo.add_comm(comm); - comm = 0L; - return handle; +int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks) { + ncclUniqueId id; + auto unique_nccl_id_ptr = unique_nccl_id.data_ptr(); + memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId)); + NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks); + int handle = mo.add_comm(comm); + comm = 0L; + return handle; } -void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo) -{ - class NcclCommWrapper& communicator = mo.get_comm(handle); - return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo); +void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, + at::Tensor right_output_halo, at::Tensor left_input_halo, + at::Tensor right_input_halo) { + class NcclCommWrapper& communicator = mo.get_comm(handle); + return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, + left_input_halo, right_input_halo); } -std::vector left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo) -{ - class NcclCommWrapper& communicator = mo.get_comm(handle); - return communicator.left_right_halo_exchange(left_rank, 
right_rank, left_output_halo, right_output_halo); +std::vector left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, + at::Tensor right_output_halo) { + class NcclCommWrapper& communicator = mo.get_comm(handle); + return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo); } -void add_delay(int delay) -{ - auto stream = at::cuda::getCurrentCUDAStream(); - auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); - AddDelay_kernel<<<1,1,0,stream>>>(delay, t.data_ptr()); +void add_delay(int delay) { + auto stream = at::cuda::getCurrentCUDAStream(); + auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + AddDelay_kernel<<<1, 1, 0, stream>>>(delay, t.data_ptr()); } -}}} +} // namespace nccl_p2p +} // namespace contrib +} // namespace apex diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh index 6d29420b2..a047bedb6 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh @@ -19,27 +19,18 @@ #ifndef _nccl_p2p_h_ #define _nccl_p2p_h_ -namespace apex { namespace contrib { namespace nccl_p2p { +namespace apex { +namespace contrib { +namespace nccl_p2p { at::Tensor get_unique_nccl_id(int n); -int init_nccl_comm( - at::Tensor unique_nccl_id, - int my_rank, - int num_ranks - ); -void left_right_halo_exchange_inplace( - int handle, - int left_rank, - int right_rank, - at::Tensor left_output_halo, - at::Tensor right_output_halo, - at::Tensor left_input_halo, - at::Tensor right_input_halo); -std::vector left_right_halo_exchange( - int handle, - int left_rank, - int right_rank, - at::Tensor left_output_halo, - at::Tensor right_output_halo); +int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks); +void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, + at::Tensor right_output_halo, at::Tensor left_input_halo, + at::Tensor right_input_halo); +std::vector left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, + at::Tensor right_output_halo); void add_delay(int delay); -}}} +} // namespace nccl_p2p +} // namespace contrib +} // namespace apex #endif diff --git a/apex/contrib/csrc/nccl_p2p/nccl_version.cpp b/apex/contrib/csrc/nccl_p2p/nccl_version.cpp index ad03666e9..bb388630b 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_version.cpp +++ b/apex/contrib/csrc/nccl_p2p/nccl_version.cpp @@ -1,11 +1,9 @@ // Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. // This file is used to check the version of NCCL detected. -#include - #include +#include + std::tuple get_nccl_version(); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("get_nccl_version", &get_nccl_version); -} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_nccl_version", &get_nccl_version); } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_version_check.cu b/apex/contrib/csrc/nccl_p2p/nccl_version_check.cu index 37f9d4590..ed01d4126 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_version_check.cu +++ b/apex/contrib/csrc/nccl_p2p/nccl_version_check.cu @@ -1,10 +1,8 @@ // Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. // This file is used to check the version of NCCL detected. 
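// get_nccl_version() below packs the detected NCCL major and minor versions into
// a two-int tuple. A hypothetical caller-side helper (illustration only, not
// defined anywhere in this repository) could gate features on it like this:
#include <tuple>

inline bool nccl_at_least(const std::tuple<int, int>& version, int required_major, int required_minor) {
  const int major = std::get<0>(version);
  const int minor = std::get<1>(version);
  return major > required_major || (major == required_major && minor >= required_minor);
}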
-#include #include +#include -std::tuple get_nccl_version() { - return { int(NCCL_MAJOR), int(NCCL_MINOR) }; -} +std::tuple get_nccl_version() { return {int(NCCL_MAJOR), int(NCCL_MINOR)}; } diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp index 001bbe6db..9d69ba01b 100644 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp @@ -1,86 +1,105 @@ #include // CUDA forward declaration -void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first); +void fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first); -void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); -void fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); -void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); +void fused_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, + float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, + float decay); +void fused_reversible_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, + float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, + int bias_correction, float decay); +void fused_maybe_adam_undo_cuda(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g, + float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, + int bias_correction, float decay); -void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); +void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector> tensor_lists, + float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, + int bias_correction, float decay); -void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out); +void maybe_cast_cuda(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out); void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector> tensor_lists); #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) // C++ interface -void strided_check_finite( - at::Tensor& overflow_flag, - at::Tensor& p_copy, - int stride, - int clear_overflow_first - ) { - CHECK_INPUT(p_copy); - fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first); +void strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first) { + CHECK_INPUT(p_copy); + 
fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first); } -void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { - CHECK_INPUT(p); - if (p_copy.numel() > 0) CHECK_INPUT(p_copy); - CHECK_INPUT(m); - CHECK_INPUT(v); - CHECK_INPUT(g); - int64_t num_elem = p.numel(); - TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); - TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); - TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); - TORCH_CHECK(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty"); +void adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, + float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { + CHECK_INPUT(p); + if (p_copy.numel() > 0) CHECK_INPUT(p_copy); + CHECK_INPUT(m); + CHECK_INPUT(v); + CHECK_INPUT(g); + int64_t num_elem = p.numel(); + TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); + TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); + TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); + TORCH_CHECK(p_copy.numel() == num_elem || p_copy.numel() == 0, + "number of elements in p_copy and p tensors should be equal, or p_copy should be empty"); - fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); + fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); } -void reversible_adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { - CHECK_INPUT(p); - if (p_copy.numel() > 0) CHECK_INPUT(p_copy); - CHECK_INPUT(m); - CHECK_INPUT(v); - CHECK_INPUT(g); - int64_t num_elem = p.numel(); - TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); - TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); - TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); - TORCH_CHECK(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty"); +void reversible_adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, + float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, + float decay) { + CHECK_INPUT(p); + if (p_copy.numel() > 0) CHECK_INPUT(p_copy); + CHECK_INPUT(m); + CHECK_INPUT(v); + CHECK_INPUT(g); + int64_t num_elem = p.numel(); + TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); + TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); + TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); + TORCH_CHECK(p_copy.numel() == num_elem || p_copy.numel() == 0, + "number of elements in p_copy and p tensors should be equal, or p_copy should be 
empty"); - fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); + fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); } -void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { - CHECK_INPUT(p); - CHECK_INPUT(m); - CHECK_INPUT(v); - CHECK_INPUT(g); - int64_t num_elem = p.numel(); - TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); - TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); - TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); +void maybe_adam_undo(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, + float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, + float decay) { + CHECK_INPUT(p); + CHECK_INPUT(m); + CHECK_INPUT(v); + CHECK_INPUT(g); + int64_t num_elem = p.numel(); + TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); + TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); + TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); - fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); + fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, + decay); } -void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) { - CHECK_INPUT(p_in); - CHECK_INPUT(p_out); - int64_t num_elem = p_in.numel(); - TORCH_CHECK(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal"); +void maybe_cast(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) { + CHECK_INPUT(p_in); + CHECK_INPUT(p_out); + int64_t num_elem = p_in.numel(); + TORCH_CHECK(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal"); - maybe_cast_cuda(overflow_flag, p_in, p_out); + maybe_cast_cuda(overflow_flag, p_in, p_out); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("strided_check_finite", &strided_check_finite, "Strided finite check.", py::call_guard()); - m.def("adam", &adam, "Adam optimized CUDA implementation.", py::call_guard()); - m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.", py::call_guard()); - m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.", py::call_guard()); - m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.", py::call_guard()); - m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.", py::call_guard()); - m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.", py::call_guard()); + m.def("strided_check_finite", &strided_check_finite, "Strided finite check.", + py::call_guard()); + m.def("adam", &adam, "Adam optimized CUDA implementation.", py::call_guard()); + m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.", + py::call_guard()); + m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.", + 
py::call_guard()); + m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.", + py::call_guard()); + m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.", + py::call_guard()); + m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.", + py::call_guard()); } diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu index 60ac5b4f7..252750c38 100644 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu +++ b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu @@ -1,1037 +1,764 @@ #include #include #include + #include #include "ATen/ATen.h" +#include "ATen/TensorUtils.h" #include "ATen/cuda/CUDAContext.h" #include "ATen/cuda/detail/IndexUtils.cuh" -#include "ATen/TensorUtils.h" // #include "ATen/Type.h" #include "ATen/AccumulateType.h" - #include "multi_tensor_apply.cuh" #define BLOCK_SIZE 512 #define ILP 4 -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; +template +__device__ __forceinline__ bool is_aligned(T* p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; +template +__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; } #include "type_shim.h" -typedef enum{ - ADAM_MODE_0 =0, // eps under square root - ADAM_MODE_1 =1 // eps outside square root +typedef enum { + ADAM_MODE_0 = 0, // eps under square root + ADAM_MODE_1 = 1 // eps outside square root } adamMode_t; template -__global__ void adam_cuda_kernel( - T* __restrict__ p, - GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed - T* __restrict__ m, - T* __restrict__ v, - const GRAD_T * __restrict__ g, - const float b1, - const float b2, - const float eps, - const float grad_scale, - const float step_size, - const size_t tsize, - adamMode_t mode, - const float decay) -{ - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock); - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; - - for (int j = i; j < tsize; j+=totThreads) { - T scaled_grad = g[j]/grad_scale; - m[j] = b1*m[j] + (1-b1)*scaled_grad; - v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(v[j] + eps); - else // Mode 1 - denom = sqrtf(v[j]) + eps; - float update = (m[j]/denom) + (decay*p[j]); - p[j] = p[j] - (step_size*update); - if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j]; - } +__global__ void adam_cuda_kernel(T* __restrict__ p, + GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed + T* __restrict__ m, T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1, + const float b2, const float eps, const float grad_scale, const float step_size, + const size_t tsize, adamMode_t mode, const float decay) { + // Assuming 2D grids and 2D blocks + const int blockId = gridDim.x * blockIdx.y + blockIdx.x; + const int threadsPerBlock = blockDim.x * blockDim.y; + const int threadIdInBlock = threadIdx.y * 
blockDim.x + threadIdx.x; + const int i = (blockId * threadsPerBlock + threadIdInBlock); + const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; + + for (int j = i; j < tsize; j += totThreads) { + T scaled_grad = g[j] / grad_scale; + m[j] = b1 * m[j] + (1 - b1) * scaled_grad; + v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad; + float denom; + if (mode == ADAM_MODE_0) + denom = sqrtf(v[j] + eps); + else // Mode 1 + denom = sqrtf(v[j]) + eps; + float update = (m[j] / denom) + (decay * p[j]); + p[j] = p[j] - (step_size * update); + if (p_copy != NULL) p_copy[j] = (GRAD_T)p[j]; + } } template -struct AdamFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata& tl, - const float b1, - const float b2, - const float eps, - const float grad_scale, - const float step_size, - adamMode_t mode, - const float decay) - { - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* p = (T *)tl.addresses[0][tensor_loc]; - p += chunk_idx*chunk_size; - T* m = (T *)tl.addresses[1][tensor_loc]; - m += chunk_idx*chunk_size; - T* v = (T *)tl.addresses[2][tensor_loc]; - v += chunk_idx*chunk_size; - GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc]; - g += chunk_idx*chunk_size; - GRAD_T* p_copy = NULL; - if (DEPTH == 5) { - p_copy = (GRAD_T *)tl.addresses[4][tensor_loc]; - p_copy += chunk_idx*chunk_size; - } +struct AdamFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata& tl, + const float b1, const float b2, const float eps, const float grad_scale, + const float step_size, adamMode_t mode, const float decay) { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* p = (T*)tl.addresses[0][tensor_loc]; + p += chunk_idx * chunk_size; + T* m = (T*)tl.addresses[1][tensor_loc]; + m += chunk_idx * chunk_size; + T* v = (T*)tl.addresses[2][tensor_loc]; + v += chunk_idx * chunk_size; + GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc]; + g += chunk_idx * chunk_size; + GRAD_T* p_copy = NULL; + if (DEPTH == 5) { + p_copy = (GRAD_T*)tl.addresses[4][tensor_loc]; + p_copy += chunk_idx * chunk_size; + } - n -= chunk_idx*chunk_size; - - T incoming_p[ILP]; - T incoming_m[ILP]; - T incoming_v[ILP]; - T incoming_g[ILP]; - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(p) && - is_aligned(m) && - is_aligned(v) && - is_aligned(g) && - is_aligned(p_copy)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - GRAD_T tmp_g[ILP]; - load_store(incoming_p, p, 0, i_start); - load_store(incoming_m, m, 0, i_start); - load_store(incoming_v, v, 0, i_start); - load_store(tmp_g, g, 0, i_start); + n -= chunk_idx * chunk_size; + + T incoming_p[ILP]; + T incoming_m[ILP]; + T incoming_v[ILP]; + T incoming_g[ILP]; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(m) && is_aligned(v) && is_aligned(g) && + is_aligned(p_copy)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) { + // load + GRAD_T tmp_g[ILP]; + load_store(incoming_p, p, 0, i_start); + load_store(incoming_m, m, 0, i_start); + load_store(incoming_v, v, 0, i_start); + 
load_store(tmp_g, g, 0, i_start); #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - incoming_g[ii] = static_cast(tmp_g[ii]); - T scaled_grad = incoming_g[ii]/grad_scale; - incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(incoming_v[ii] + eps); - else // Mode 1 - denom = sqrtf(incoming_v[ii]) + eps; - float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]); - incoming_p[ii] = incoming_p[ii] - (step_size*update); - if (DEPTH == 5) tmp_g[ii] = static_cast(incoming_p[ii]); - } - load_store(p, incoming_p, i_start, 0); - load_store(m, incoming_m, i_start, 0); - load_store(v, incoming_v, i_start, 0); - if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0); + for (int ii = 0; ii < ILP; ii++) { + incoming_g[ii] = static_cast(tmp_g[ii]); + T scaled_grad = incoming_g[ii] / grad_scale; + incoming_m[ii] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad; + incoming_v[ii] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad; + float denom; + if (mode == ADAM_MODE_0) + denom = sqrtf(incoming_v[ii] + eps); + else // Mode 1 + denom = sqrtf(incoming_v[ii]) + eps; + float update = (incoming_m[ii] / denom) + (decay * incoming_p[ii]); + incoming_p[ii] = incoming_p[ii] - (step_size * update); + if (DEPTH == 5) tmp_g[ii] = static_cast(incoming_p[ii]); + } + load_store(p, incoming_p, i_start, 0); + load_store(m, incoming_m, i_start, 0); + load_store(v, incoming_v, i_start, 0); + if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0); + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + incoming_p[ii] = 0; + incoming_m[ii] = 0; + incoming_v[ii] = 0; + incoming_g[ii] = 0; + + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + incoming_p[ii] = p[i]; + incoming_m[ii] = m[i]; + incoming_v[ii] = v[i]; + incoming_g[ii] = static_cast(g[i]); } } - else - { - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) { + // note for clarification to future michael: + // From a pure memory dependency perspective, there's likely no point unrolling + // the write loop, since writes just fire off once their LDGs arrive. + // Put another way, the STGs are dependent on the LDGs, but not on each other. + // There is still compute ILP benefit from unrolling the loop though. #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - incoming_p[ii] = 0; - incoming_m[ii] = 0; - incoming_v[ii] = 0; - incoming_g[ii] = 0; - - int i = i_start + threadIdx.x + ii*blockDim.x; - if (i < n && i < chunk_size) { - incoming_p[ii] = p[i]; - incoming_m[ii] = m[i]; - incoming_v[ii] = v[i]; - incoming_g[ii] = static_cast(g[i]); - } - } - - // note for clarification to future michael: - // From a pure memory dependency perspective, there's likely no point unrolling - // the write loop, since writes just fire off once their LDGs arrive. - // Put another way, the STGs are dependent on the LDGs, but not on each other. - // There is still compute ILP benefit from unrolling the loop though. 
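// --- Reference sketch (illustrative only, not part of this patch or of apex) ---
// Both the vectorized (ILP/load_store) path and the scalar fallback in AdamFunctor above
// apply the same per-element Adam update; only the memory access pattern differs. The
// single-threaded C++ sketch below restates that update. The name adam_reference and its
// standalone form are assumptions made for illustration; REF_ADAM_MODE_0/1 mirror the
// adamMode_t enum defined earlier (eps under vs. outside the square root), and step_size
// is expected to already contain Adam's bias correction, as computed by the host launchers.
#include <cmath>
#include <cstddef>

enum RefAdamMode { REF_ADAM_MODE_0 = 0, REF_ADAM_MODE_1 = 1 };

void adam_reference(float* p, float* m, float* v, const float* g, std::size_t n, float b1, float b2, float eps,
                    float grad_scale, float step_size, RefAdamMode mode, float decay) {
  for (std::size_t j = 0; j < n; ++j) {
    float scaled_grad = g[j] / grad_scale;                       // undo loss scaling
    m[j] = b1 * m[j] + (1.0f - b1) * scaled_grad;                // first moment EMA
    v[j] = b2 * v[j] + (1.0f - b2) * scaled_grad * scaled_grad;  // second moment EMA
    float denom = (mode == REF_ADAM_MODE_0) ? std::sqrt(v[j] + eps) : std::sqrt(v[j]) + eps;
    float update = m[j] / denom + decay * p[j];                  // L2-style weight decay folded into the update
    p[j] -= step_size * update;
  }
}
// The fused kernels perform this same arithmetic per element but keep p, m, v in registers
// and read/write them in ILP-wide chunks, which is where the aligned/unaligned split above
// comes from.
// --- end reference sketch ---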
-#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = i_start + threadIdx.x + ii*blockDim.x; - - if(j < n && j < chunk_size) { - T scaled_grad = incoming_g[ii]/grad_scale; - m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(v[j] + eps); - else // Mode 1 - denom = sqrtf(v[j]) + eps; - float update = (m[j]/denom) + (decay*incoming_p[ii]); - p[j] = incoming_p[ii] - (step_size*update); - if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j]; - } - } + for (int ii = 0; ii < ILP; ii++) { + int j = i_start + threadIdx.x + ii * blockDim.x; + + if (j < n && j < chunk_size) { + T scaled_grad = incoming_g[ii] / grad_scale; + m[j] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad; + v[j] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad; + float denom; + if (mode == ADAM_MODE_0) + denom = sqrtf(v[j] + eps); + else // Mode 1 + denom = sqrtf(v[j]) + eps; + float update = (m[j] / denom) + (decay * incoming_p[ii]); + p[j] = incoming_p[ii] - (step_size * update); + if (DEPTH == 5) p_copy[j] = (GRAD_T)p[j]; } } + } } + } }; -void fused_adam_cuda( - at::Tensor & p, - at::Tensor & p_copy, - at::Tensor & m, - at::Tensor & v, - at::Tensor & g, - float lr, - float beta1, - float beta2, - float eps, - float grad_scale, - int step, - int mode, - int bias_correction, - float decay) -{ -// using namespace at; - - //Get tensor size - int tsize = p.numel(); - //Determine #threads and #blocks - const int threadsPerBlock = 512; - const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); - TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); - //Constants - float step_size = 0; - if (bias_correction == 1) { - const float bias_correction1 = 1 - std::pow(beta1, step); - const float bias_correction2 = 1 - std::pow(beta2, step); - step_size = lr * std::sqrt(bias_correction2)/bias_correction1; - } - else { - step_size = lr; - } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (g.scalar_type() == at::ScalarType::Half) { -//all other values should be fp32 for half gradients - TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); -//dispatch is done on the gradient type - using namespace at; // prevents "toString is undefined" errors - DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", - using accscalar_t = at::acc_type; - adam_cuda_kernel<<>>( - p.data_ptr(), - p_copy.numel() ? 
p_copy.data_ptr() : NULL, - m.data_ptr(), - v.data_ptr(), - g.data_ptr(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } else { - using namespace at; - DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", - adam_cuda_kernel<<>>( - p.data_ptr(), - NULL, //don't output p_copy for fp32, it's wasted write - m.data_ptr(), - v.data_ptr(), - g.data_ptr(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } - C10_CUDA_CHECK(cudaGetLastError()); - +void fused_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, + float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, + float decay) { + // using namespace at; + + // Get tensor size + int tsize = p.numel(); + // Determine #threads and #blocks + const int threadsPerBlock = 512; + const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock); + TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); + // Constants + float step_size = 0; + if (bias_correction == 1) { + const float bias_correction1 = 1 - std::pow(beta1, step); + const float bias_correction2 = 1 - std::pow(beta2, step); + step_size = lr * std::sqrt(bias_correction2) / bias_correction1; + } else { + step_size = lr; + } + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (g.scalar_type() == at::ScalarType::Half) { + // all other values should be fp32 for half gradients + TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); + // dispatch is done on the gradient type + using namespace at; // prevents "toString is undefined" errors + DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = at::acc_type; + adam_cuda_kernel<<>>( + p.data_ptr(), p_copy.numel() ? 
p_copy.data_ptr() : NULL, + m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, + beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); + } else { + using namespace at; + DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", + adam_cuda_kernel<<>>( + p.data_ptr(), + NULL, // don't output p_copy for fp32, it's wasted write + m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, + beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); + } + C10_CUDA_CHECK(cudaGetLastError()); } -void fused_adam_cuda_mt( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, // p, m, v, g, p_copy - float lr, - float beta1, - float beta2, - float eps, - float grad_scale, - int step, - int mode, - int bias_correction, - float decay) { - - //Constants - float step_size = 0; - if (bias_correction == 1) { - const float bias_correction1 = 1 - std::pow(beta1, step); - const float bias_correction2 = 1 - std::pow(beta2, step); - step_size = lr * std::sqrt(bias_correction2)/bias_correction1; - } - else { - step_size = lr; +void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, // p, m, v, g, p_copy + float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, + int bias_correction, float decay) { + // Constants + float step_size = 0; + if (bias_correction == 1) { + const float bias_correction1 = 1 - std::pow(beta1, step); + const float bias_correction2 = 1 - std::pow(beta2, step); + step_size = lr * std::sqrt(bias_correction2) / bias_correction1; + } else { + step_size = lr; + } + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + size_t tl_sz = tensor_lists.size(); + TORCH_CHECK(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); + + if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) { + // alher values should be fp32 for half gradients + TORCH_CHECK(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); + // dich is done on the gradient type + if (tl_sz == 5) { + DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", + using accscalar_t = at::acc_type; + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor<5, accscalar_t, scalar_t_0>(), beta1, beta2, eps, + grad_scale, step_size, (adamMode_t)mode, decay);); + } else { + DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", + using accscalar_t = at::acc_type; + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor<4, accscalar_t, scalar_t_0>(), beta1, beta2, eps, + grad_scale, step_size, (adamMode_t)mode, decay);); } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - size_t tl_sz = tensor_lists.size(); - TORCH_CHECK(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); - - if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) { -//alher values should be fp32 for half gradients - TORCH_CHECK(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); -//dich is done on the gradient type - if (tl_sz == 5) { - DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", - using accscalar_t = at::acc_type; - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor<5, accscalar_t, scalar_t_0>(), - beta1, - beta2, - eps, - grad_scale, - step_size, - (adamMode_t) mode, - decay); - ); - } else { - 
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", - using accscalar_t = at::acc_type; - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor<4, accscalar_t, scalar_t_0>(), - beta1, - beta2, - eps, - grad_scale, - step_size, - (adamMode_t) mode, - decay); - ); - } + } else { + if (tl_sz == 5) { + DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor<5, scalar_t_0, scalar_t_0>(), beta1, beta2, eps, + grad_scale, step_size, (adamMode_t)mode, decay);); } else { - if (tl_sz == 5) { - DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor<5, scalar_t_0, scalar_t_0>(), - beta1, - beta2, - eps, - grad_scale, - step_size, - (adamMode_t) mode, - decay); - ); - } else { - DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor<4, scalar_t_0, scalar_t_0>(), - beta1, - beta2, - eps, - grad_scale, - step_size, - (adamMode_t) mode, - decay); - ); - } + DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor<4, scalar_t_0, scalar_t_0>(), beta1, beta2, eps, + grad_scale, step_size, (adamMode_t)mode, decay);); } - C10_CUDA_CHECK(cudaGetLastError()); + } + C10_CUDA_CHECK(cudaGetLastError()); } -template -__device__ void convert(const FROM_T vi, TO_T& vo) -{ - vo = static_cast(vi); +template +__device__ void convert(const FROM_T vi, TO_T& vo) { + vo = static_cast(vi); } template <> -__device__ void convert(const float vi, uint8_t& vo) -{ - union S - { - float as_float; - int as_int; - }; - S s; - s.as_float = vi; - s.as_int = s.as_int & 0xFF800000; - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_half = static_cast(vi + s.as_float / 8.0f); - vo = t.as_byte[1]; +__device__ void convert(const float vi, uint8_t& vo) { + union S { + float as_float; + int as_int; + }; + S s; + s.as_float = vi; + s.as_int = s.as_int & 0xFF800000; + union T { + at::Half as_half; + uint8_t as_byte[2]; + }; + T t; + t.as_half = static_cast(vi + s.as_float / 8.0f); + vo = t.as_byte[1]; } template <> -__device__ void convert(const uint8_t vi, float& vo) -{ - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_byte[0] = 0; - t.as_byte[1] = vi; - vo = static_cast(t.as_half); +__device__ void convert(const uint8_t vi, float& vo) { + union T { + at::Half as_half; + uint8_t as_byte[2]; + }; + T t; + t.as_byte[0] = 0; + t.as_byte[1] = vi; + vo = static_cast(t.as_half); } template <> -__device__ void convert(const at::Half vi, uint8_t& vo) -{ - union S - { - float as_float; - int as_int; - }; - S s; - s.as_float = static_cast(vi); - s.as_int = s.as_int & 0xFF800000; - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_half = static_cast(vi + s.as_float / 8.0f); - vo = t.as_byte[1]; +__device__ void convert(const at::Half vi, uint8_t& vo) { + union S { + float as_float; + int as_int; + }; + S s; + s.as_float = static_cast(vi); + s.as_int = s.as_int & 0xFF800000; + union T { + at::Half as_half; + uint8_t as_byte[2]; + }; + T t; + t.as_half = static_cast(vi + s.as_float / 8.0f); + vo = t.as_byte[1]; } template <> -__device__ 
void convert(const uint8_t vi, at::Half& vo) -{ - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_byte[0] = 0; - t.as_byte[1] = vi; - vo = t.as_half; +__device__ void convert(const uint8_t vi, at::Half& vo) { + union T { + at::Half as_half; + uint8_t as_byte[2]; + }; + T t; + t.as_byte[0] = 0; + t.as_byte[1] = vi; + vo = t.as_half; } template -__global__ void strided_check_finite_cuda_kernel( - volatile int* noop_gmem, - GRAD_T* __restrict__ p_copy, - const size_t tsize, - int stride, - int clear_overflow_first) -{ - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride; - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride; - - if (clear_overflow_first) { - if (i == 0) { - *noop_gmem = 0; - } - __syncthreads(); +__global__ void strided_check_finite_cuda_kernel(volatile int* noop_gmem, GRAD_T* __restrict__ p_copy, + const size_t tsize, int stride, int clear_overflow_first) { + // Assuming 2D grids and 2D blocks + const int blockId = gridDim.x * blockIdx.y + blockIdx.x; + const int threadsPerBlock = blockDim.x * blockDim.y; + const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; + const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride; + const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride; + + if (clear_overflow_first) { + if (i == 0) { + *noop_gmem = 0; } + __syncthreads(); + } - for (int j = i; j < tsize; j+=totThreads) { - GRAD_T pi = p_copy[j]; - if (!isfinite(pi)) { - *noop_gmem = 1; - } + for (int j = i; j < tsize; j += totThreads) { + GRAD_T pi = p_copy[j]; + if (!isfinite(pi)) { + *noop_gmem = 1; } + } } template <> -__global__ void strided_check_finite_cuda_kernel( - volatile int* noop_gmem, - uint8_t* __restrict__ p_copy, - const size_t tsize, - int stride, - int clear_overflow_first) -{ - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride; - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride; - - if (clear_overflow_first) { - if (i == 0) { - *noop_gmem = 0; - } - __syncthreads(); +__global__ void strided_check_finite_cuda_kernel(volatile int* noop_gmem, uint8_t* __restrict__ p_copy, + const size_t tsize, int stride, int clear_overflow_first) { + // Assuming 2D grids and 2D blocks + const int blockId = gridDim.x * blockIdx.y + blockIdx.x; + const int threadsPerBlock = blockDim.x * blockDim.y; + const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; + const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride; + const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride; + + if (clear_overflow_first) { + if (i == 0) { + *noop_gmem = 0; } - - for (int j = i; j < tsize; j+=totThreads) { - at::Half pi; - convert(p_copy[j], pi); - if (!isfinite(pi)) { - *noop_gmem = 1; - } + __syncthreads(); + } + + for (int j = i; j < tsize; j += totThreads) { + at::Half pi; + convert(p_copy[j], pi); + if (!isfinite(pi)) { + *noop_gmem = 1; } + } } -template -__global__ void maybe_cast_kernel( - volatile int* overflow_flag, - const FROM_T* p_in, - TO_T* p_out, - const size_t tsize) -{ - if 
(overflow_flag && *overflow_flag != 0) return; +template +__global__ void maybe_cast_kernel(volatile int* overflow_flag, const FROM_T* p_in, TO_T* p_out, const size_t tsize) { + if (overflow_flag && *overflow_flag != 0) return; - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock); - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; + // Assuming 2D grids and 2D blocks + const int blockId = gridDim.x * blockIdx.y + blockIdx.x; + const int threadsPerBlock = blockDim.x * blockDim.y; + const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; + const int i = (blockId * threadsPerBlock + threadIdInBlock); + const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; - FROM_T pi[ILP]; - TO_T po[ILP]; + FROM_T pi[ILP]; + TO_T po[ILP]; - for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) { + for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) { #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - pi[ii] = 0; + for (int ii = 0; ii < ILP; ii++) { + pi[ii] = 0; - int j = j_start + i + totThreads*ii; - if (j < tsize) { - pi[ii] = p_in[j]; - } - } + int j = j_start + i + totThreads * ii; + if (j < tsize) { + pi[ii] = p_in[j]; + } + } #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - convert(pi[ii], po[ii]); - } + for (int ii = 0; ii < ILP; ii++) { + convert(pi[ii], po[ii]); + } #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = j_start + i + totThreads*ii; - if (j < tsize) { - p_out[j] = po[ii]; - } - } + for (int ii = 0; ii < ILP; ii++) { + int j = j_start + i + totThreads * ii; + if (j < tsize) { + p_out[j] = po[ii]; + } } + } } template __global__ void reversible_adam_cuda_kernel( - T* __restrict__ p, - REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed - T* __restrict__ m, - T* __restrict__ v, - const GRAD_T * __restrict__ g, - const float b1, - const float b2, - const float eps, - const float grad_scale, - const float step_size, - const size_t tsize, - adamMode_t mode, - const float decay) -{ - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock); - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; - - T mi[ILP]; - T vi[ILP]; - T pi[ILP]; - T gi[ILP]; - - bool overflow = false; - for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) { + T* __restrict__ p, + REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed + T* __restrict__ m, T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1, const float b2, const float eps, + const float grad_scale, const float step_size, const size_t tsize, adamMode_t mode, const float decay) { + // Assuming 2D grids and 2D blocks + const int blockId = gridDim.x * blockIdx.y + blockIdx.x; + const int threadsPerBlock = blockDim.x * blockDim.y; + const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; + const int i = (blockId * threadsPerBlock + threadIdInBlock); + const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; + + T mi[ILP]; + T vi[ILP]; + T pi[ILP]; + T gi[ILP]; + + bool overflow = false; + for (int j_start = 0; j_start < tsize; j_start 
+= totThreads * ILP) { #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - mi[ii] = T(0); - vi[ii] = T(0); - pi[ii] = T(0); - gi[ii] = GRAD_T(0); - - int j = j_start + i + totThreads*ii; - if (j < tsize) { - pi[ii] = p[j]; - mi[ii] = m[j]; - vi[ii] = v[j]; - gi[ii] = static_cast(g[j]); - } - } + for (int ii = 0; ii < ILP; ii++) { + mi[ii] = T(0); + vi[ii] = T(0); + pi[ii] = T(0); + gi[ii] = GRAD_T(0); + + int j = j_start + i + totThreads * ii; + if (j < tsize) { + pi[ii] = p[j]; + mi[ii] = m[j]; + vi[ii] = v[j]; + gi[ii] = static_cast(g[j]); + } + } #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - T scaled_grad = gi[ii]/grad_scale; - if (isfinite(scaled_grad)) { - mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad; - vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(vi[ii] + eps); - else // Mode 1 - denom = sqrtf(vi[ii]) + eps; - float update = (mi[ii]/denom) + (decay*pi[ii]); - pi[ii] = pi[ii] - (step_size*update); - } else { - overflow = true; - } - } + for (int ii = 0; ii < ILP; ii++) { + T scaled_grad = gi[ii] / grad_scale; + if (isfinite(scaled_grad)) { + mi[ii] = b1 * mi[ii] + (1 - b1) * scaled_grad; + vi[ii] = b2 * vi[ii] + (1 - b2) * scaled_grad * scaled_grad; + float denom; + if (mode == ADAM_MODE_0) + denom = sqrtf(vi[ii] + eps); + else // Mode 1 + denom = sqrtf(vi[ii]) + eps; + float update = (mi[ii] / denom) + (decay * pi[ii]); + pi[ii] = pi[ii] - (step_size * update); + } else { + overflow = true; + } + } #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = j_start + i + totThreads*ii; - if (j < tsize) { - m[j] = mi[ii]; - v[j] = vi[ii]; - p[j] = pi[ii]; - if (p_copy != NULL) { - convert(pi[ii], p_copy[j]); - } - } + for (int ii = 0; ii < ILP; ii++) { + int j = j_start + i + totThreads * ii; + if (j < tsize) { + m[j] = mi[ii]; + v[j] = vi[ii]; + p[j] = pi[ii]; + if (p_copy != NULL) { + convert(pi[ii], p_copy[j]); } + } } + } - if (p_copy != NULL) { - __syncthreads(); - if (overflow) { - convert(float(INFINITY), p_copy[0]); - } + if (p_copy != NULL) { + __syncthreads(); + if (overflow) { + convert(float(INFINITY), p_copy[0]); } + } } template -__global__ void maybe_adam_undo_cuda_kernel( - volatile int* overflow_flag, - T* __restrict__ p, - T* __restrict__ m, - T* __restrict__ v, - const GRAD_T * __restrict__ g, - const float b1, - const float b2, - const float eps, - const float grad_scale, - const float step_size, - const size_t tsize, - adamMode_t mode, - const float decay) -{ - // NB! Skip undo kernel when overflow flag is NOT set - if (overflow_flag && *overflow_flag == 0) return; - - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock); - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; - - T mi[ILP]; - T vi[ILP]; - T pi[ILP]; - T gi[ILP]; - - for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) { +__global__ void maybe_adam_undo_cuda_kernel(volatile int* overflow_flag, T* __restrict__ p, T* __restrict__ m, + T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1, + const float b2, const float eps, const float grad_scale, + const float step_size, const size_t tsize, adamMode_t mode, + const float decay) { + // NB! 
Skip undo kernel when overflow flag is NOT set + if (overflow_flag && *overflow_flag == 0) return; + + // Assuming 2D grids and 2D blocks + const int blockId = gridDim.x * blockIdx.y + blockIdx.x; + const int threadsPerBlock = blockDim.x * blockDim.y; + const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; + const int i = (blockId * threadsPerBlock + threadIdInBlock); + const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; + + T mi[ILP]; + T vi[ILP]; + T pi[ILP]; + T gi[ILP]; + + for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) { #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - mi[ii] = T(0); - vi[ii] = T(0); - pi[ii] = T(0); - gi[ii] = GRAD_T(0); - - int j = j_start + i*ILP; - if (j < tsize) { - pi[ii] = p[j]; - mi[ii] = m[j]; - vi[ii] = v[j]; - gi[ii] = static_cast(g[j]); - } - } + for (int ii = 0; ii < ILP; ii++) { + mi[ii] = T(0); + vi[ii] = T(0); + pi[ii] = T(0); + gi[ii] = GRAD_T(0); + + int j = j_start + i * ILP; + if (j < tsize) { + pi[ii] = p[j]; + mi[ii] = m[j]; + vi[ii] = v[j]; + gi[ii] = static_cast(g[j]); + } + } #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - T scaled_grad = gi[ii]/grad_scale; - if (isfinite(scaled_grad)) { - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(vi[ii] + eps); - else // Mode 1 - denom = sqrtf(vi[ii]) + eps; - pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay); - mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1; - vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2; - // Make sure round off errors don't create (small) negative value. - // This can happen if we have to revert the very first step. - vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f; - } - } + for (int ii = 0; ii < ILP; ii++) { + T scaled_grad = gi[ii] / grad_scale; + if (isfinite(scaled_grad)) { + float denom; + if (mode == ADAM_MODE_0) + denom = sqrtf(vi[ii] + eps); + else // Mode 1 + denom = sqrtf(vi[ii]) + eps; + pi[ii] = (pi[ii] + step_size * (mi[ii] / denom)) / (1.0f - step_size * decay); + mi[ii] = (mi[ii] - (1 - b1) * scaled_grad) / b1; + vi[ii] = (vi[ii] - (1 - b2) * scaled_grad * scaled_grad) / b2; + // Make sure round off errors don't create (small) negative value. + // This can happen if we have to revert the very first step. + vi[ii] = vi[ii] >= 0.0f ? 
vi[ii] : 0.0f; + } + } #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = j_start + i*ILP; - if (j < tsize) { - m[j] = mi[ii]; - v[j] = vi[ii]; - p[j] = pi[ii]; - } - } + for (int ii = 0; ii < ILP; ii++) { + int j = j_start + i * ILP; + if (j < tsize) { + m[j] = mi[ii]; + v[j] = vi[ii]; + p[j] = pi[ii]; + } } + } } template -struct MaybeCastFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* overflow_flag, - TensorListMetadata& tl) - { - if (overflow_flag && *overflow_flag != 0) return; +struct MaybeCastFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* overflow_flag, + TensorListMetadata& tl) { + if (overflow_flag && *overflow_flag != 0) return; - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; - FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc]; - p_in += chunk_idx*chunk_size; - TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc]; - p_out += chunk_idx*chunk_size; + FROM_T* p_in = (FROM_T*)tl.addresses[0][tensor_loc]; + p_in += chunk_idx * chunk_size; + TO_T* p_out = (TO_T*)tl.addresses[1][tensor_loc]; + p_out += chunk_idx * chunk_size; - n -= chunk_idx*chunk_size; - int dim = chunk_size < n ? chunk_size : n; + n -= chunk_idx * chunk_size; + int dim = chunk_size < n ? chunk_size : n; - FROM_T pi[ILP]; - TO_T po[ILP]; + FROM_T pi[ILP]; + TO_T po[ILP]; - for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) { + for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) { #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - pi[ii] = FROM_T(0); - int j = j_start + threadIdx.x + ii*blockDim.x; - if (j < dim) { - pi[ii] = p_in[j]; - } - } + for (int ii = 0; ii < ILP; ii++) { + pi[ii] = FROM_T(0); + int j = j_start + threadIdx.x + ii * blockDim.x; + if (j < dim) { + pi[ii] = p_in[j]; + } + } #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - convert(pi[ii], po[ii]); - } + for (int ii = 0; ii < ILP; ii++) { + convert(pi[ii], po[ii]); + } #pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = j_start + threadIdx.x + ii*blockDim.x; - if (j < dim) { - p_out[j] = po[ii]; - } - } + for (int ii = 0; ii < ILP; ii++) { + int j = j_start + threadIdx.x + ii * blockDim.x; + if (j < dim) { + p_out[j] = po[ii]; } + } } + } }; -void fused_strided_check_finite( - at::Tensor & overflow_flag, - at::Tensor & p_copy, - int stride, - int clear_overflow_first) -{ - //Get tensor size - int tsize = p_copy.numel(); - int niter = (tsize + stride - 1) / stride; - - //Determine #threads and #blocks - const int threadsPerBlock = 512; - //In order to avoid race condition, blocks must be 1 when clear_overflow_first flag is set. - const dim3 blocks(clear_overflow_first ? 
1 : (niter+threadsPerBlock-1)/threadsPerBlock); - TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32"); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - using namespace at; // prevents "toString is undefined" errors - DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel", - strided_check_finite_cuda_kernel<<>>( - overflow_flag.data_ptr(), - p_copy.data_ptr(), - tsize, - stride, - clear_overflow_first); - ); - C10_CUDA_CHECK(cudaGetLastError()); +void fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first) { + // Get tensor size + int tsize = p_copy.numel(); + int niter = (tsize + stride - 1) / stride; + + // Determine #threads and #blocks + const int threadsPerBlock = 512; + // In order to avoid race condition, blocks must be 1 when clear_overflow_first flag is set. + const dim3 blocks(clear_overflow_first ? 1 : (niter + threadsPerBlock - 1) / threadsPerBlock); + TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32"); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + using namespace at; // prevents "toString is undefined" errors + DISPATCH_FLOAT_HALF_AND_BYTE( + p_copy.scalar_type(), 0, "check_finite_cuda_kernel", + strided_check_finite_cuda_kernel<<>>( + overflow_flag.data_ptr(), p_copy.data_ptr(), tsize, stride, clear_overflow_first);); + C10_CUDA_CHECK(cudaGetLastError()); } -void fused_reversible_adam_cuda( - at::Tensor & p, - at::Tensor & p_copy, - at::Tensor & m, - at::Tensor & v, - at::Tensor & g, - float lr, - float beta1, - float beta2, - float eps, - float grad_scale, - int step, - int mode, - int bias_correction, - float decay) -{ -// using namespace at; - - //Get tensor size - int tsize = p.numel(); - //Determine #threads and #blocks - const int threadsPerBlock = 512; - const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); - TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); - //Constants - float step_size = 0; - if (bias_correction == 1) { - const float bias_correction1 = 1 - std::pow(beta1, step); - const float bias_correction2 = 1 - std::pow(beta2, step); - step_size = lr * std::sqrt(bias_correction2)/bias_correction1; - } - else { - step_size = lr; - } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (g.scalar_type() == at::ScalarType::Half) { - //all other values should be fp32 for half gradients - TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); - //dispatch is done on the gradient type - using namespace at; // prevents "toString is undefined" errors - if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) { - DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", - using accscalar_t = at::acc_type; - reversible_adam_cuda_kernel<<>>( - p.data_ptr(), - p_copy.numel() ? 
p_copy.data_ptr() : NULL, - m.data_ptr(), - v.data_ptr(), - g.data_ptr(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } else { - TORCH_CHECK(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type"); - DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel", - using accscalar_t = at::acc_type; - reversible_adam_cuda_kernel<<>>( - p.data_ptr(), - p_copy.data_ptr(), - m.data_ptr(), - v.data_ptr(), - g.data_ptr(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } - } else { - using namespace at; - DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", - reversible_adam_cuda_kernel<<>>( - p.data_ptr(), - NULL, //don't output p_copy for fp32, it's wasted write - m.data_ptr(), - v.data_ptr(), - g.data_ptr(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } - C10_CUDA_CHECK(cudaGetLastError()); +void fused_reversible_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, + float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, + int bias_correction, float decay) { + // using namespace at; + + // Get tensor size + int tsize = p.numel(); + // Determine #threads and #blocks + const int threadsPerBlock = 512; + const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock); + TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); + // Constants + float step_size = 0; + if (bias_correction == 1) { + const float bias_correction1 = 1 - std::pow(beta1, step); + const float bias_correction2 = 1 - std::pow(beta2, step); + step_size = lr * std::sqrt(bias_correction2) / bias_correction1; + } else { + step_size = lr; + } + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (g.scalar_type() == at::ScalarType::Half) { + // all other values should be fp32 for half gradients + TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); + // dispatch is done on the gradient type + using namespace at; // prevents "toString is undefined" errors + if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) { + DISPATCH_FLOAT_AND_HALF( + g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = at::acc_type; + reversible_adam_cuda_kernel<<>>( + p.data_ptr(), p_copy.numel() ? 
p_copy.data_ptr() : NULL, + m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, + grad_scale, step_size, tsize, (adamMode_t)mode, decay);); + } else { + TORCH_CHECK(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type"); + DISPATCH_FLOAT_AND_HALF( + g.scalar_type(), 0, "adam_cuda_e5m2_kernel", using accscalar_t = at::acc_type; + reversible_adam_cuda_kernel<<>>( + p.data_ptr(), p_copy.data_ptr(), m.data_ptr(), + v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, step_size, tsize, + (adamMode_t)mode, decay);); + } + } else { + using namespace at; + DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", + reversible_adam_cuda_kernel + <<>>( + p.data_ptr(), + NULL, // don't output p_copy for fp32, it's wasted write + m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, + beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); + } + C10_CUDA_CHECK(cudaGetLastError()); } -void maybe_cast_cuda( - at::Tensor & overflow_flag, - at::Tensor & p_in, - at::Tensor & p_out) -{ - //Get tensor size - int tsize = p_in.numel(); - TORCH_CHECK(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()"); - //Determine #threads and #blocks - const int threadsPerBlock = 512; - const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); - TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32"); - //Constants - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda" - DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, "maybe_cast_cuda", - maybe_cast_kernel<<>>( - overflow_flag.numel() ? overflow_flag.data_ptr() : NULL, - p_in.data_ptr(), - p_out.data_ptr(), - tsize); )) - C10_CUDA_CHECK(cudaGetLastError()); +void maybe_cast_cuda(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) { + // Get tensor size + int tsize = p_in.numel(); + TORCH_CHECK(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()"); + // Determine #threads and #blocks + const int threadsPerBlock = 512; + const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock); + TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32"); + // Constants + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, + "maybe_cast_cuda" DISPATCH_FLOAT_HALF_AND_BYTE( + p_out.scalar_type(), 1, "maybe_cast_cuda", + maybe_cast_kernel<<>>( + overflow_flag.numel() ? 
overflow_flag.data_ptr() : NULL, + p_in.data_ptr(), p_out.data_ptr(), tsize);)) + C10_CUDA_CHECK(cudaGetLastError()); } -void maybe_cast_cuda_mt( - int chunk_size, - at::Tensor overflow_flag, - std::vector> tensor_lists) // p_in, p_out +void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, + std::vector> tensor_lists) // p_in, p_out { - //Constants - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - size_t tl_sz = tensor_lists.size(); - TORCH_CHECK(tl_sz == 2, "expected tensor lists of size 2"); - - DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel", - DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - overflow_flag, - tensor_lists, - MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); )) - C10_CUDA_CHECK(cudaGetLastError()); + // Constants + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + size_t tl_sz = tensor_lists.size(); + TORCH_CHECK(tl_sz == 2, "expected tensor lists of size 2"); + + DISPATCH_FLOAT_HALF_AND_BYTE( + tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel", + DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, overflow_flag, tensor_lists, + MaybeCastFunctor<2, scalar_t_0, scalar_t_1>());)) + C10_CUDA_CHECK(cudaGetLastError()); } -void fused_maybe_adam_undo_cuda( - at::Tensor & overflow_flag, - at::Tensor & p, - at::Tensor & m, - at::Tensor & v, - at::Tensor & g, - float lr, - float beta1, - float beta2, - float eps, - float grad_scale, - int step, - int mode, - int bias_correction, - float decay) -{ - //Get tensor size - int tsize = p.numel(); - //Determine #threads and #blocks - const int threadsPerBlock = 512; - const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); - TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); - //Constants - float step_size = 0; - if (bias_correction == 1) { - const float bias_correction1 = 1 - std::pow(beta1, step); - const float bias_correction2 = 1 - std::pow(beta2, step); - step_size = lr * std::sqrt(bias_correction2)/bias_correction1; - } - else { - step_size = lr; - } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (g.scalar_type() == at::ScalarType::Half) { - //all other values should be fp32 for half gradients - TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); - //dispatch is done on the gradient type - using namespace at; // prevents "toString is undefined" errors - DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", - using accscalar_t = at::acc_type; - maybe_adam_undo_cuda_kernel<<>>( - overflow_flag.numel() ? overflow_flag.data_ptr() : NULL, - p.data_ptr(), - m.data_ptr(), - v.data_ptr(), - g.data_ptr(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } else { - using namespace at; - DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", - maybe_adam_undo_cuda_kernel<<>>( - overflow_flag.numel() ? 
overflow_flag.data_ptr() : NULL, - p.data_ptr(), - m.data_ptr(), - v.data_ptr(), - g.data_ptr(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } - C10_CUDA_CHECK(cudaGetLastError()); +void fused_maybe_adam_undo_cuda(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g, + float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, + int bias_correction, float decay) { + // Get tensor size + int tsize = p.numel(); + // Determine #threads and #blocks + const int threadsPerBlock = 512; + const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock); + TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); + // Constants + float step_size = 0; + if (bias_correction == 1) { + const float bias_correction1 = 1 - std::pow(beta1, step); + const float bias_correction2 = 1 - std::pow(beta2, step); + step_size = lr * std::sqrt(bias_correction2) / bias_correction1; + } else { + step_size = lr; + } + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (g.scalar_type() == at::ScalarType::Half) { + // all other values should be fp32 for half gradients + TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); + // dispatch is done on the gradient type + using namespace at; // prevents "toString is undefined" errors + DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = at::acc_type; + maybe_adam_undo_cuda_kernel + <<>>( + overflow_flag.numel() ? overflow_flag.data_ptr() : NULL, p.data_ptr(), + m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, + beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); + } else { + using namespace at; + DISPATCH_DOUBLE_AND_FLOAT( + g.scalar_type(), 0, "adam_cuda_kernel", + maybe_adam_undo_cuda_kernel<<>>( + overflow_flag.numel() ? 
overflow_flag.data_ptr() : NULL, p.data_ptr(), + m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, + step_size, tsize, (adamMode_t)mode, decay);); + } + C10_CUDA_CHECK(cudaGetLastError()); } diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp index 9b43ed7c9..b0aa92082 100644 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp @@ -1,21 +1,11 @@ #include -void multi_tensor_lamb_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - const float global_grad_norm, - const float max_grad_norm); +void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + const float lr, const float beta1, const float beta2, const float epsilon, const int step, + const int bias_correction, const float weight_decay, const int grad_averaging, + const int mode, const float global_grad_norm, const float max_grad_norm); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", py::call_guard()); + m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", + py::call_guard()); } diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu b/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu index e568635d9..8ca291ce3 100644 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu +++ b/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu @@ -7,43 +7,30 @@ #include -#include "type_shim.h" #include "multi_tensor_apply.cuh" +#include "type_shim.h" #define BLOCK_SIZE 512 #define ILP 4 -typedef enum{ - MOMENT_MODE_0 =0, // L2 regularization mode - MOMENT_MODE_1 =1 // Decoupled weight decay mode +typedef enum { + MOMENT_MODE_0 = 0, // L2 regularization mode + MOMENT_MODE_1 = 1 // Decoupled weight decay mode } adamMode_t; -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python); +std::tuple multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python); using MATH_T = float; -template -struct LAMBStage1Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<4>& tl, - const float beta1, - const float beta2, - const float beta3, - const float beta1_correction, - const float beta2_correction, - const float epsilon, - adamMode_t mode, - const float decay, - const float global_grad_norm, - const float max_global_grad_norm) - { +template +struct LAMBStage1Functor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<4>& tl, + const float beta1, const float beta2, const float beta3, + const float beta1_correction, const float beta2_correction, + const float epsilon, adamMode_t mode, const float decay, + const float global_grad_norm, const float max_global_grad_norm) { // I'd like this kernel to propagate infs/nans. // if(*noop_gmem == 1) // return; @@ -52,43 +39,38 @@ struct LAMBStage1Functor int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; - float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? 
global_grad_norm / max_global_grad_norm : 1.0f; + float clipped_global_grad_norm = + global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f; T* g = (T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; + g += chunk_idx * chunk_size; T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; + p += chunk_idx * chunk_size; T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; + m += chunk_idx * chunk_size; T* v = (T*)tl.addresses[3][tensor_loc]; - v += chunk_idx*chunk_size; + v += chunk_idx * chunk_size; - n -= chunk_idx*chunk_size; + n -= chunk_idx * chunk_size; // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { MATH_T r_g[ILP]; MATH_T r_p[ILP]; MATH_T r_m[ILP]; MATH_T r_v[ILP]; #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { r_g[ii] = g[i]; // special ?optimization? for lamb stage 1 if (decay == 0) { r_p[ii] = MATH_T(0); - } - else { + } else { r_p[ii] = p[i]; } r_m[ii] = m[i]; @@ -101,35 +83,31 @@ struct LAMBStage1Functor } } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { if (mode == MOMENT_MODE_0) { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + // L2 on scaled grad + scaled_grad = scaled_grad + decay * r_p[ii]; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_v_unbiased = r_v[ii] / beta2_correction; MATH_T denom = sqrtf(next_v_unbiased) + epsilon; r_p[ii] = next_m_unbiased / denom; - } - else { + } else { MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_v_unbiased = r_v[ii] / beta2_correction; MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); } } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { g[i] = r_p[ii]; m[i] = r_m[ii]; v[i] = r_v[ii]; @@ -141,18 +119,11 @@ struct LAMBStage1Functor // Step 2 reads in 'update' value and per-tensor param_norm and update_norm. // It computes new parameter value. 
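// Illustrative aside (editor's sketch, not part of this patch): the LAMBStage2Functor below applies the
// per-tensor adaptive "trust ratio" described in the comment above. Assuming hypothetical scalar inputs
// p, update, param_norm, update_norm, lr and decay, the per-element update it performs reduces to:
//
//   MATH_T lamb_stage2_scalar(MATH_T p, MATH_T update, MATH_T param_norm, MATH_T update_norm,
//                             MATH_T lr, MATH_T decay) {
//     MATH_T ratio = lr;  // plain step size by default
//     if (decay != MATH_T(0) && param_norm != MATH_T(0) && update_norm != MATH_T(0)) {
//       ratio = lr * (param_norm / update_norm);  // adaptive rate for tensors with non-zero weight decay
//     }
//     return p - ratio * update;  // new parameter value written back to p
//   }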
-template -struct LAMBStage2Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<2>& tl, - const float* per_tensor_param_norm, - const float* per_tensor_update_norm, - const float learning_rate, - const float decay) - { +template +struct LAMBStage2Functor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl, + const float* per_tensor_param_norm, const float* per_tensor_update_norm, + const float learning_rate, const float decay) { // I'd like this kernel to propagate infs/nans. // if(*noop_gmem == 1) // return; @@ -164,48 +135,39 @@ struct LAMBStage2Functor MATH_T ratio = learning_rate; // apply adaptive learning rate to parameters with non-zero weight decay - if (decay != 0.0) - { + if (decay != 0.0) { float param_norm = per_tensor_param_norm[tensor_num]; float update_norm = per_tensor_update_norm[tensor_num]; ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate; } T* update = (T*)tl.addresses[0][tensor_loc]; - update += chunk_idx*chunk_size; + update += chunk_idx * chunk_size; T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; + p += chunk_idx * chunk_size; - n -= chunk_idx*chunk_size; + n -= chunk_idx * chunk_size; - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { MATH_T r_p[ILP]; MATH_T r_update[ILP]; #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { r_p[ii] = p[i]; r_update[ii] = update[i]; } } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = r_p[ii] - (ratio * r_update[ii]); + for (int ii = 0; ii < ILP; ii++) { + r_p[ii] = r_p[ii] - (ratio * r_update[ii]); } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { p[i] = r_p[ii]; } } @@ -213,23 +175,10 @@ struct LAMBStage2Functor } }; - -void multi_tensor_lamb_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - const float global_grad_norm, - const float max_grad_norm) -{ +void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + const float lr, const float beta1, const float beta2, const float epsilon, const int step, + const int bias_correction, const float weight_decay, const int grad_averaging, + const int mode, const float global_grad_norm, const float max_grad_norm) { using namespace at; // Master weight and 32bit momentum(potentially changing) is not handled by this // So we assume every tensor are all in the same type @@ -245,8 +194,8 @@ void multi_tensor_lamb_cuda( float beta3 = 1.0f; if (grad_averaging == 1) beta3 = 1 - beta1; - std::vector> grad_list(tensor_lists.begin(), tensor_lists.begin()+1); - std::vector> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2); + std::vector> 
grad_list(tensor_lists.begin(), tensor_lists.begin() + 1); + std::vector> param_list(tensor_lists.begin() + 1, tensor_lists.begin() + 2); // Compute per tensor param norm auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true); @@ -255,40 +204,22 @@ void multi_tensor_lamb_cuda( // Generally this is not a issue since people modify grad in step() method all the time // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LAMBStage1Functor(), - beta1, - beta2, - beta3, // 1-beta1 or 1 depends on averaging mode - bias_correction1, - bias_correction2, - epsilon, - (adamMode_t) mode, - weight_decay, - global_grad_norm, - max_grad_norm); ) + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + LAMBStage1Functor(), beta1, beta2, + beta3, // 1-beta1 or 1 depends on averaging mode + bias_correction1, bias_correction2, epsilon, (adamMode_t)mode, + weight_decay, global_grad_norm, max_grad_norm);) // Compute update norms auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true); - std::vector> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2); + std::vector> grad_param_list(tensor_lists.begin(), tensor_lists.begin() + 2); - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - noop_flag, - grad_param_list, - LAMBStage2Functor(), - std::get<1>(param_norm_tuple).data_ptr(), - std::get<1>(update_norm_tuple).data_ptr(), - lr, - weight_decay); ) + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, LAMBStage2Functor(), + std::get<1>(param_norm_tuple).data_ptr(), + std::get<1>(update_norm_tuple).data_ptr(), lr, weight_decay);) AT_CUDA_CHECK(cudaGetLastError()); - } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp index 9c3811f89..29592a6af 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp @@ -1,35 +1,31 @@ #include -void multi_tensor_fused_adam_cuda( - int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, at::Tensor grad_scale, - float lr, float beta1, float beta2, float eps, int step, int mode, - int bias_correction, float weight_decay); +void multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, float lr, + float beta1, float beta2, float eps, int step, int mode, int bias_correction, + float weight_decay); -void multi_tensor_fused_adam_capturable_cuda( - int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, at::Tensor grad_scale, - at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step, - int mode, int bias_correction, float weight_decay); +void multi_tensor_fused_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, + at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step, + int mode, int bias_correction, float weight_decay); -void multi_tensor_fused_adam_with_param_remainders_cuda( - int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, at::Tensor grad_scale, - float lr, float beta1, 
float beta2, float eps, int step, int mode, - int bias_correction, float weight_decay); +void multi_tensor_fused_adam_with_param_remainders_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_scale, float lr, float beta1, float beta2, + float eps, int step, int mode, int bias_correction, + float weight_decay); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, "CUDA kernels for multi-tensor Adam, " "with param copy", py::call_guard()); - m.def("multi_tensor_fused_adam_capturable", - &multi_tensor_fused_adam_capturable_cuda, + m.def("multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable_cuda, "CUDA kernels for multi-tensor Adam, " "with param copy, capturable for CUDA graph", py::call_guard()); - m.def("multi_tensor_fused_adam_with_param_remainders", - &multi_tensor_fused_adam_with_param_remainders_cuda, + m.def("multi_tensor_fused_adam_with_param_remainders", &multi_tensor_fused_adam_with_param_remainders_cuda, "CUDA kernel for multi-tensor Adam, " "with stored param remainders and param copy", py::call_guard()); diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu index 6480275c8..23228ef5c 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu @@ -21,11 +21,8 @@ __device__ __forceinline__ bool is_aligned(const T* p) { } template -__device__ __forceinline__ void load_store(T* dst, const T* src, - int dst_offset = 0, - int src_offset = 0) { - typedef - typename std::aligned_storage::type LT; +__device__ __forceinline__ void load_store(T* dst, const T* src, int dst_offset = 0, int src_offset = 0) { + typedef typename std::aligned_storage::type LT; ((LT*)dst)[dst_offset] = ((const LT*)src)[src_offset]; } @@ -47,11 +44,11 @@ typedef enum { template struct DistAdamFunctor { // Vectorized local compute - __device__ __forceinline__ static void local_step( - T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP], const float grad_scale, - const float beta1, const float beta2, const float beta1_correction, - const float beta2_correction, const float eps, const float lr, - adamMode_t mode, const float weight_decay) { + __device__ __forceinline__ static void local_step(T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP], + const float grad_scale, const float beta1, const float beta2, + const float beta1_correction, const float beta2_correction, + const float eps, const float lr, adamMode_t mode, + const float weight_decay) { if (mode == ADAM_MODE_0) { // L2 #pragma unroll for (int ii = 0; ii < ILP; ii++) { @@ -83,12 +80,11 @@ struct DistAdamFunctor { } } - __device__ __forceinline__ void operator()( - int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, - const float* grad_scale_ptr, const float beta1, const float beta2, - const float beta1_correction, const float beta2_correction, - const float eps, const float lr, adamMode_t mode, - const float weight_decay) const { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, + const float* grad_scale_ptr, const float beta1, const float beta2, + const float beta1_correction, const float beta2_correction, + const float eps, const float lr, adamMode_t mode, + const float weight_decay) const { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = 
tl.sizes[tensor_loc]; @@ -109,11 +105,10 @@ struct DistAdamFunctor { n -= chunk_idx * chunk_size; n = chunk_size < n ? chunk_size : n; - const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && - is_aligned(v) && is_aligned(g) && is_aligned(p_out)); + const bool aligned = + (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out)); - for (int i_start = threadIdx.x * ILP; i_start < n; - i_start += blockDim.x * ILP) { + for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) { T local_p[ILP]; T local_m[ILP]; T local_v[ILP]; @@ -144,9 +139,8 @@ struct DistAdamFunctor { } // Local compute - local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, - beta1_correction, beta2_correction, eps, lr, mode, - weight_decay); + local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, beta1_correction, beta2_correction, eps, + lr, mode, weight_decay); #pragma unroll for (int ii = 0; ii < ILP; ii++) { local_p_out[ii] = static_cast(local_p[ii]); @@ -180,11 +174,11 @@ struct DistAdamFunctor { template struct DistAdamCapturableFunctor { // Vectorized local compute - __device__ __forceinline__ static void local_step( - T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP], const float grad_scale, - const float beta1, const float beta2, const float beta1_correction, - const float beta2_correction, const float eps, const float lr, - adamMode_t mode, const float weight_decay) { + __device__ __forceinline__ static void local_step(T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP], + const float grad_scale, const float beta1, const float beta2, + const float beta1_correction, const float beta2_correction, + const float eps, const float lr, adamMode_t mode, + const float weight_decay) { if (mode == ADAM_MODE_0) { // L2 #pragma unroll for (int ii = 0; ii < ILP; ii++) { @@ -216,11 +210,10 @@ struct DistAdamCapturableFunctor { } } - __device__ __forceinline__ void operator()( - int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, - const float* grad_scale_ptr, const float beta1, const float beta2, - const int* step, const int bias_correction, const float eps, - const float* lr, adamMode_t mode, const float weight_decay) const { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, + const float* grad_scale_ptr, const float beta1, const float beta2, + const int* step, const int bias_correction, const float eps, + const float* lr, adamMode_t mode, const float weight_decay) const { assert(noop_gmem); assert(grad_scale_ptr); assert(step); @@ -254,11 +247,10 @@ struct DistAdamCapturableFunctor { n -= chunk_idx * chunk_size; n = chunk_size < n ? 
chunk_size : n; - const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && - is_aligned(v) && is_aligned(g) && is_aligned(p_out)); + const bool aligned = + (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out)); - for (int i_start = threadIdx.x * ILP; i_start < n; - i_start += blockDim.x * ILP) { + for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) { T local_p[ILP]; T local_m[ILP]; T local_v[ILP]; @@ -289,9 +281,8 @@ struct DistAdamCapturableFunctor { } // Local compute - local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, - beta1_correction, beta2_correction, eps, *lr, mode, - weight_decay); + local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, beta1_correction, beta2_correction, eps, + *lr, mode, weight_decay); #pragma unroll for (int ii = 0; ii < ILP; ii++) { local_p_out[ii] = static_cast(local_p[ii]); @@ -326,12 +317,11 @@ struct DistAdamCapturableFunctor { */ template struct DistAdamWithParamRemaindersFunctor { - __device__ __forceinline__ void operator()( - int chunk_size, volatile int* noop_gmem, TensorListMetadata<6>& tl, - const float* grad_scale_ptr, const float beta1, const float beta2, - const float beta1_correction, const float beta2_correction, - const float eps, const float lr, adamMode_t mode, - const float weight_decay) const { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<6>& tl, + const float* grad_scale_ptr, const float beta1, const float beta2, + const float beta1_correction, const float beta2_correction, + const float eps, const float lr, adamMode_t mode, + const float weight_decay) const { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; @@ -354,12 +344,10 @@ struct DistAdamWithParamRemaindersFunctor { n -= chunk_idx * chunk_size; n = chunk_size < n ? 
chunk_size : n; - const bool aligned = - (n % ILP == 0 && is_aligned(p_in) && is_aligned(p_rem) && - is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out)); + const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(p_rem) && is_aligned(m) && is_aligned(v) && + is_aligned(g) && is_aligned(p_out)); - for (int i_start = threadIdx.x * ILP; i_start < n; - i_start += blockDim.x * ILP) { + for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) { union fp32_or_int162 { float fp32; int16_t int16[2]; @@ -407,10 +395,8 @@ struct DistAdamWithParamRemaindersFunctor { // Local compute using LocalFunctor = DistAdamFunctor; - LocalFunctor::local_step(reinterpret_cast(local_p), local_m, - local_v, local_g, grad_scale, beta1, beta2, - beta1_correction, beta2_correction, eps, lr, - mode, weight_decay); + LocalFunctor::local_step(reinterpret_cast(local_p), local_m, local_v, local_g, grad_scale, beta1, beta2, + beta1_correction, beta2_correction, eps, lr, mode, weight_decay); // Split into BF16 params (rounded-to-nearest) and remainders #pragma unroll @@ -441,11 +427,10 @@ struct DistAdamWithParamRemaindersFunctor { } }; -void multi_tensor_fused_adam_cuda( - int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, // p_in, m, v, g, p_out - at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, - int step, int mode, int bias_correction, float weight_decay) { +void multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, // p_in, m, v, g, p_out + at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, + int mode, int bias_correction, float weight_decay) { using namespace at; // Expect p_in, m, v, g, p_out @@ -467,20 +452,17 @@ void multi_tensor_fused_adam_cuda( g_type, 1, "dist_adam_cuda_kernel", DISPATCH_FLOAT_HALF_AND_BFLOAT( p_out_type, 2, "dist_adam_cuda_kernel", - multi_tensor_apply<5>( - BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - DistAdamFunctor(), - grad_scale.data_ptr(), beta1, beta2, beta1_correction, - beta2_correction, eps, lr, (adamMode_t)mode, - weight_decay);))); + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + DistAdamFunctor(), grad_scale.data_ptr(), + beta1, beta2, beta1_correction, beta2_correction, eps, lr, (adamMode_t)mode, + weight_decay);))); C10_CUDA_CHECK(cudaGetLastError()); } -void multi_tensor_fused_adam_capturable_cuda( - int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, // p_in, m, v, g, p_out - at::Tensor grad_scale, at::Tensor lr, float beta1, float beta2, float eps, - at::Tensor step, int mode, int bias_correction, float weight_decay) { +void multi_tensor_fused_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, // p_in, m, v, g, p_out + at::Tensor grad_scale, at::Tensor lr, float beta1, float beta2, float eps, + at::Tensor step, int mode, int bias_correction, float weight_decay) { using namespace at; // Expect p_in, m, v, g, p_out @@ -496,22 +478,18 @@ void multi_tensor_fused_adam_capturable_cuda( g_type, 1, "dist_adam_capturable_cuda_kernel", DISPATCH_FLOAT_HALF_AND_BFLOAT( p_out_type, 2, "dist_adam_capturable_cuda_kernel", - multi_tensor_apply<5>( - BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - DistAdamCapturableFunctor(), - grad_scale.data_ptr(), beta1, beta2, - step.data_ptr(), bias_correction, eps, - lr.data_ptr(), (adamMode_t)mode, weight_decay);))); + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + 
DistAdamCapturableFunctor(), + grad_scale.data_ptr(), beta1, beta2, step.data_ptr(), bias_correction, + eps, lr.data_ptr(), (adamMode_t)mode, weight_decay);))); C10_CUDA_CHECK(cudaGetLastError()); } void multi_tensor_fused_adam_with_param_remainders_cuda( int chunk_size, at::Tensor noop_flag, - std::vector> - tensor_lists, // p_in, p_rem, m, v, g, p_out - at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, - int step, int mode, int bias_correction, float weight_decay) { + std::vector> tensor_lists, // p_in, p_rem, m, v, g, p_out + at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, int bias_correction, + float weight_decay) { using namespace at; // Expect p_in, p_rem, m, v, g, p_out @@ -528,9 +506,7 @@ void multi_tensor_fused_adam_with_param_remainders_cuda( DISPATCH_FLOAT_HALF_AND_BFLOAT( g_type, 0, "dist_adam_with_param_remainders_cuda_kernel", multi_tensor_apply<6>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - DistAdamWithParamRemaindersFunctor(), - grad_scale.data_ptr(), beta1, beta2, - beta1_correction, beta2_correction, eps, lr, - (adamMode_t)mode, weight_decay);); + DistAdamWithParamRemaindersFunctor(), grad_scale.data_ptr(), beta1, + beta2, beta1_correction, beta2_correction, eps, lr, (adamMode_t)mode, weight_decay);); C10_CUDA_CHECK(cudaGetLastError()); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp index b2431a13b..f74bebfc2 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp @@ -1,32 +1,18 @@ #include -void multi_tensor_lamb_compute_update_term_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_beta3, - at::Tensor per_tensor_bias_correction, - at::Tensor step, - at::Tensor per_tensor_epsilon, - const int mode, - at::Tensor per_tensor_decay, - at::Tensor global_scale, - at::Tensor global_grad_norm, - const float max_grad_norm); +void multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2, + at::Tensor per_tensor_beta3, at::Tensor per_tensor_bias_correction, + at::Tensor step, at::Tensor per_tensor_epsilon, const int mode, + at::Tensor per_tensor_decay, at::Tensor global_scale, + at::Tensor global_grad_norm, const float max_grad_norm); -void multi_tensor_lamb_update_weights_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_param_norm, - at::Tensor per_tensor_update_norm, - at::Tensor update_norm_offset, - at::Tensor learning_rate, - at::Tensor per_tensor_decay, - at::Tensor global_grad_norm, - bool use_nvlamb); +void multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, + at::Tensor update_norm_offset, at::Tensor learning_rate, + at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda, diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu index cdeea8eab..cb51980c0 100644 --- 
a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu @@ -7,126 +7,103 @@ #include -#include "type_shim.h" #include "multi_tensor_apply.cuh" +#include "type_shim.h" #define BLOCK_SIZE 512 #define ILP 4 -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; +template +__device__ __forceinline__ bool is_aligned(T* p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; +template +__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; } -template -__device__ void convert(const FROM_T vi, TO_T& vo) -{ - vo = static_cast(vi); +template +__device__ void convert(const FROM_T vi, TO_T& vo) { + vo = static_cast(vi); } template <> -__device__ void convert(const float vi, uint8_t& vo) -{ - union S - { - float as_float; - int as_int; - }; - S s; - s.as_float = vi; - s.as_int = s.as_int & 0xFF800000; - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_half = static_cast(vi + s.as_float / 8.0f); - vo = t.as_byte[1]; +__device__ void convert(const float vi, uint8_t& vo) { + union S { + float as_float; + int as_int; + }; + S s; + s.as_float = vi; + s.as_int = s.as_int & 0xFF800000; + union T { + at::Half as_half; + uint8_t as_byte[2]; + }; + T t; + t.as_half = static_cast(vi + s.as_float / 8.0f); + vo = t.as_byte[1]; } template <> -__device__ void convert(const uint8_t vi, float& vo) -{ - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_byte[0] = 0; - t.as_byte[1] = vi; - vo = static_cast(t.as_half); +__device__ void convert(const uint8_t vi, float& vo) { + union T { + at::Half as_half; + uint8_t as_byte[2]; + }; + T t; + t.as_byte[0] = 0; + t.as_byte[1] = vi; + vo = static_cast(t.as_half); } template <> -__device__ void convert(const at::Half vi, uint8_t& vo) -{ - union S - { - float as_float; - int as_int; - }; - S s; - s.as_float = static_cast(vi); - s.as_int = s.as_int & 0xFF800000; - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_half = static_cast(vi + s.as_float / 8.0f); - vo = t.as_byte[1]; +__device__ void convert(const at::Half vi, uint8_t& vo) { + union S { + float as_float; + int as_int; + }; + S s; + s.as_float = static_cast(vi); + s.as_int = s.as_int & 0xFF800000; + union T { + at::Half as_half; + uint8_t as_byte[2]; + }; + T t; + t.as_half = static_cast(vi + s.as_float / 8.0f); + vo = t.as_byte[1]; } template <> -__device__ void convert(const uint8_t vi, at::Half& vo) -{ - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_byte[0] = 0; - t.as_byte[1] = vi; - vo = t.as_half; +__device__ void convert(const uint8_t vi, at::Half& vo) { + union T { + at::Half as_half; + uint8_t as_byte[2]; + }; + T t; + t.as_byte[0] = 0; + t.as_byte[1] = vi; + vo = t.as_half; } -typedef enum{ - MOMENT_MODE_0 =0, // L2 regularization mode - MOMENT_MODE_1 =1 // Decoupled weight decay mode +typedef enum { + MOMENT_MODE_0 = 0, // L2 regularization mode + MOMENT_MODE_1 = 1 // Decoupled weight decay mode } adamMode_t; -template -struct DistOptLAMBStage1Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<5>& tl, - const MATH_T* 
per_tensor_beta1, - const MATH_T* per_tensor_beta2, - const MATH_T* per_tensor_beta3, - const int* per_tensor_bias_correction, - const int* step, - const MATH_T* per_tensor_epsilon, - adamMode_t mode, - const MATH_T* per_tensor_decay, - const MATH_T* global_scale, - const MATH_T* global_grad_norm, - const float max_grad_norm) - { +template +struct DistOptLAMBStage1Functor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, + const MATH_T* per_tensor_beta1, const MATH_T* per_tensor_beta2, + const MATH_T* per_tensor_beta3, const int* per_tensor_bias_correction, + const int* step, const MATH_T* per_tensor_epsilon, adamMode_t mode, + const MATH_T* per_tensor_decay, const MATH_T* global_scale, + const MATH_T* global_grad_norm, const float max_grad_norm) { // I'd like this kernel to propagate infs/nans. - if (*noop_gmem == 1) - return; + if (*noop_gmem == 1) return; int tensor_loc = tl.block_to_tensor[blockIdx.x]; int tensor_num = tl.start_tensor_this_launch + tensor_loc; @@ -135,8 +112,8 @@ struct DistOptLAMBStage1Functor float combined_scale = *global_scale; if (max_grad_norm > 0) { - combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6); - combined_scale = *global_scale / std::min((float) 1.0, combined_scale); + combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6); + combined_scale = *global_scale / std::min((float)1.0, combined_scale); } MATH_T beta1 = per_tensor_beta1[tensor_num]; @@ -144,97 +121,84 @@ struct DistOptLAMBStage1Functor MATH_T beta3 = 1 - beta1; MATH_T beta1_correction, beta2_correction; if (per_tensor_bias_correction[tensor_num] == 1) { - beta1_correction = 1 - pow(beta1, *step); - beta2_correction = 1 - pow(beta2, *step); + beta1_correction = 1 - pow(beta1, *step); + beta2_correction = 1 - pow(beta2, *step); } else { - beta1_correction = (MATH_T) 1.0; - beta2_correction = (MATH_T) 1.0; + beta1_correction = (MATH_T)1.0; + beta2_correction = (MATH_T)1.0; } MATH_T epsilon = per_tensor_epsilon[tensor_num]; MATH_T decay = per_tensor_decay[tensor_num]; GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; + g += chunk_idx * chunk_size; T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; + p += chunk_idx * chunk_size; T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; + m += chunk_idx * chunk_size; T* v = (T*)tl.addresses[3][tensor_loc]; - v += chunk_idx*chunk_size; + v += chunk_idx * chunk_size; MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc]; - u += chunk_idx*chunk_size; + u += chunk_idx * chunk_size; - n -= chunk_idx*chunk_size; + n -= chunk_idx * chunk_size; MATH_T r_g[ILP]; MATH_T r_p[ILP]; MATH_T r_m[ILP]; MATH_T r_v[ILP]; // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(g) && - is_aligned(p) && - is_aligned(m) && - is_aligned(v)) - { + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) && is_aligned(p) && is_aligned(m) && is_aligned(v)) { GRAD_T l_g[ILP]; T l_p[ILP]; T l_m[ILP]; T l_v[ILP]; - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) { // load load_store(l_g, g, 0, i_start); - if (decay != 0) - load_store(l_p, p, 0, i_start); + if (decay != 0) load_store(l_p, p, 0, i_start); load_store(l_m, m, 0, i_start); load_store(l_v, v, 0, i_start); // unpack #pragma 
unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { r_g[ii] = l_g[ii]; if (decay == 0) { r_p[ii] = MATH_T(0); - } - else { + } else { r_p[ii] = l_p[ii]; } r_m[ii] = l_m[ii]; r_v[ii] = l_v[ii]; } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { if (mode == MOMENT_MODE_0) { MATH_T scaled_grad = r_g[ii] / combined_scale; // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; + scaled_grad = scaled_grad + decay * r_p[ii]; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_v_unbiased = r_v[ii] / beta2_correction; MATH_T denom = sqrtf(next_v_unbiased) + epsilon; r_p[ii] = next_m_unbiased / denom; - } - else { + } else { MATH_T scaled_grad = r_g[ii] / combined_scale; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_v_unbiased = r_v[ii] / beta2_correction; MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); } } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { l_m[ii] = r_m[ii]; l_v[ii] = r_v[ii]; } @@ -243,30 +207,22 @@ struct DistOptLAMBStage1Functor load_store(m, l_m, i_start, 0); load_store(v, l_v, i_start, 0); } - } - else - { + } else { // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { MATH_T r_g[ILP]; MATH_T r_p[ILP]; MATH_T r_m[ILP]; MATH_T r_v[ILP]; #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { r_g[ii] = g[i]; // special ?optimization? 
for lamb stage 1 if (decay == 0) { r_p[ii] = MATH_T(0); - } - else { + } else { r_p[ii] = p[i]; } r_m[ii] = m[i]; @@ -279,35 +235,31 @@ struct DistOptLAMBStage1Functor } } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { if (mode == MOMENT_MODE_0) { MATH_T scaled_grad = r_g[ii] / combined_scale; // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; + scaled_grad = scaled_grad + decay * r_p[ii]; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_v_unbiased = r_v[ii] / beta2_correction; MATH_T denom = sqrtf(next_v_unbiased) + epsilon; r_p[ii] = next_m_unbiased / denom; - } - else { + } else { MATH_T scaled_grad = r_g[ii] / combined_scale; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_v_unbiased = r_v[ii] / beta2_correction; MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); } } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { u[i] = r_p[ii]; m[i] = r_m[ii]; v[i] = r_v[ii]; @@ -320,24 +272,15 @@ struct DistOptLAMBStage1Functor // Step 2 reads in 'update' value and per-tensor param_norm and update_norm. // It computes new parameter value. -template -struct DistOptLAMBStage2Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<3>& tl, - const MATH_T* per_tensor_param_norm, - const MATH_T* per_tensor_update_norm, - const long* update_norm_offset, - const MATH_T* learning_rate, - const MATH_T* per_tensor_decay, - const MATH_T* global_grad_norm, - bool use_nvlamb) - { +template +struct DistOptLAMBStage2Functor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<3>& tl, + const MATH_T* per_tensor_param_norm, const MATH_T* per_tensor_update_norm, + const long* update_norm_offset, const MATH_T* learning_rate, + const MATH_T* per_tensor_decay, const MATH_T* global_grad_norm, + bool use_nvlamb) { // I'd like this kernel to propagate infs/nans. - if (*noop_gmem == 1) - return; + if (*noop_gmem == 1) return; int tensor_loc = tl.block_to_tensor[blockIdx.x]; int tensor_num = tl.start_tensor_this_launch + tensor_loc; @@ -349,77 +292,61 @@ struct DistOptLAMBStage2Functor MATH_T ratio = *learning_rate; // nvlamb: apply adaptive learning rate to all parameters // otherwise, only apply to those with non-zero weight decay - if (use_nvlamb || (decay != (MATH_T) 0.0)) - { + if (use_nvlamb || (decay != (MATH_T)0.0)) { MATH_T param_norm = per_tensor_param_norm[tensor_num]; MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]]; - ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate); + ratio = + (update_norm != 0.0 && param_norm != 0.0) ? 
(*learning_rate) * (param_norm / update_norm) : (*learning_rate); } MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc]; - update += chunk_idx*chunk_size; + update += chunk_idx * chunk_size; T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; + p += chunk_idx * chunk_size; GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc]; - p_copy += chunk_idx*chunk_size; + p_copy += chunk_idx * chunk_size; - n -= chunk_idx*chunk_size; + n -= chunk_idx * chunk_size; // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(p) && - is_aligned(update)) - { + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(update)) { T r_p[ILP]; MATH_T r_update[ILP]; GRAD_T r_p_copy[ILP]; - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) { // load load_store(r_p, p, 0, i_start); load_store(r_update, update, 0, i_start); #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = static_cast(r_p[ii]) - (ratio * r_update[ii]); + for (int ii = 0; ii < ILP; ii++) { + r_p[ii] = static_cast(r_p[ii]) - (ratio * r_update[ii]); convert(r_p[ii], r_p_copy[ii]); } load_store(p, r_p, i_start, 0); load_store(p_copy, r_p_copy, i_start, 0); } - } - else - { - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { MATH_T r_p[ILP]; MATH_T r_update[ILP]; #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { r_p[ii] = p[i]; r_update[ii] = update[i]; } } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { r_p[ii] = r_p[ii] - (ratio * r_update[ii]); } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { p[i] = r_p[ii]; convert(r_p[ii], p_copy[i]); } @@ -429,78 +356,51 @@ struct DistOptLAMBStage2Functor } }; -void multi_tensor_lamb_compute_update_term_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_beta3, - at::Tensor per_tensor_bias_correction, - at::Tensor step, - at::Tensor per_tensor_epsilon, - const int mode, - at::Tensor per_tensor_decay, - at::Tensor global_scale, - at::Tensor global_grad_norm, - const float max_grad_norm) -{ +void multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2, + at::Tensor per_tensor_beta3, at::Tensor per_tensor_bias_correction, + at::Tensor step, at::Tensor per_tensor_epsilon, const int mode, + at::Tensor per_tensor_decay, at::Tensor global_scale, + at::Tensor global_grad_norm, const float max_grad_norm) { using namespace at; - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_1", - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 1, "lamb_stage_1", - 
DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1", - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistOptLAMBStage1Functor(), - per_tensor_beta1.data_ptr(), - per_tensor_beta2.data_ptr(), - per_tensor_beta3.data_ptr(), - per_tensor_bias_correction.data_ptr(), - step.data_ptr(), - per_tensor_epsilon.data_ptr(), - (adamMode_t) mode, - per_tensor_decay.data_ptr(), - global_scale.data_ptr(), - global_grad_norm.data_ptr(), - max_grad_norm); ))) + DISPATCH_FLOAT_AND_HALF( + tensor_lists[1][0].scalar_type(), 0, "lamb_stage_1", + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 1, "lamb_stage_1", + DISPATCH_FLOAT_AND_HALF( + tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1", + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + DistOptLAMBStage1Functor(), + per_tensor_beta1.data_ptr(), per_tensor_beta2.data_ptr(), + per_tensor_beta3.data_ptr(), per_tensor_bias_correction.data_ptr(), + step.data_ptr(), per_tensor_epsilon.data_ptr(), (adamMode_t)mode, + per_tensor_decay.data_ptr(), global_scale.data_ptr(), + global_grad_norm.data_ptr(), max_grad_norm);))) AT_CUDA_CHECK(cudaGetLastError()); } -void multi_tensor_lamb_update_weights_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_param_norm, - at::Tensor per_tensor_update_norm, - at::Tensor update_norm_offset, - at::Tensor learning_rate, - at::Tensor per_tensor_decay, - at::Tensor global_grad_norm, - bool use_nvlamb) -{ +void multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, + at::Tensor update_norm_offset, at::Tensor learning_rate, + at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb) { using namespace at; - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_2", - DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[2][0].scalar_type(), 1, "lamb_stage_2", - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 2, "lamb_stage_2", - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistOptLAMBStage2Functor(), - per_tensor_param_norm.data_ptr(), - per_tensor_update_norm.data_ptr(), - update_norm_offset.data_ptr(), - learning_rate.data_ptr(), - per_tensor_decay.data_ptr(), - global_grad_norm.data_ptr(), - use_nvlamb); ))) + DISPATCH_FLOAT_AND_HALF( + tensor_lists[1][0].scalar_type(), 0, "lamb_stage_2", + DISPATCH_FLOAT_HALF_AND_BYTE( + tensor_lists[2][0].scalar_type(), 1, "lamb_stage_2", + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 2, "lamb_stage_2", + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + DistOptLAMBStage2Functor(), + per_tensor_param_norm.data_ptr(), + per_tensor_update_norm.data_ptr(), update_norm_offset.data_ptr(), + learning_rate.data_ptr(), per_tensor_decay.data_ptr(), + global_grad_norm.data_ptr(), use_nvlamb);))) AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/apex/contrib/csrc/peer_memory/peer_memory.cpp b/apex/contrib/csrc/peer_memory/peer_memory.cpp index 16e189f49..bc19e6206 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory.cpp +++ b/apex/contrib/csrc/peer_memory/peer_memory.cpp @@ -17,13 +17,20 @@ #include "peer_memory_cuda.cuh" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw", py::call_guard()); - m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw", 
py::call_guard()); - m.def("zero", &apex::contrib::peer_memory::zero, "zero", py::call_guard()); - m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address", py::call_guard()); - m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers", py::call_guard()); - m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half", py::call_guard()); - m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float", py::call_guard()); - m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int", py::call_guard()); - m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d", py::call_guard()); + m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw", + py::call_guard()); + m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw", py::call_guard()); + m.def("zero", &apex::contrib::peer_memory::zero, "zero", py::call_guard()); + m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address", + py::call_guard()); + m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers", + py::call_guard()); + m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half", + py::call_guard()); + m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float", + py::call_guard()); + m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int", + py::call_guard()); + m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d", + py::call_guard()); } diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu index 675405855..1ca258c51 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu @@ -1,22 +1,23 @@ -#include -#include #include -#include -#include -#include +#include #include +#include + +#include +#include +#include + #include "nccl.h" -#define CUDACHECK(cmd) do { \ - cudaError_t err = cmd; \ - if( err != cudaSuccess ) { \ - char hostname[1024]; \ - gethostname(hostname, 1024); \ - printf("%s: CUDA failure %s:%d '%s'\n", \ - hostname, \ - __FILE__,__LINE__,cudaGetErrorString(err)); \ - } \ -} while(0) +#define CUDACHECK(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + char hostname[1024]; \ + gethostname(hostname, 1024); \ + printf("%s: CUDA failure %s:%d '%s'\n", hostname, __FILE__, __LINE__, cudaGetErrorString(err)); \ + } \ + } while (0) namespace { @@ -30,712 +31,652 @@ void deleter(void* ptr) } */ -template -at::Tensor blob_view(T* raw_ptr, std::vector shape, const at::TensorOptions& options, bool channels_last) -{ - size_t size = 1; - std::vector strides(shape.size()); - if (channels_last) { - assert(shape.size() == 4); - strides[0] = shape[1]*shape[2]*shape[3]; - strides[1] = 1; - strides[2] = shape[1]*shape[3]; - strides[3] = shape[1]; - } else { - int idx = strides.size(); - for (auto it = shape.rbegin(); it != shape.rend(); ++it) - { - strides[--idx] = size; - size *= *it; - } +template +at::Tensor blob_view(T* raw_ptr, std::vector shape, const at::TensorOptions& options, bool channels_last) { + size_t size = 1; + std::vector strides(shape.size()); + if (channels_last) { + assert(shape.size() == 4); + strides[0] = shape[1] * shape[2] * shape[3]; + strides[1] = 1; + strides[2] = 
shape[1] * shape[3]; + strides[3] = shape[1]; + } else { + int idx = strides.size(); + for (auto it = shape.rbegin(); it != shape.rend(); ++it) { + strides[--idx] = size; + size *= *it; } - size *= sizeof(T); - // TODO: Implement dynamic reuse of pooled peer memory. - // We provide no deleter function because all peer memory allocations are static in this implementation. - return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options); + } + size *= sizeof(T); + // TODO: Implement dynamic reuse of pooled peer memory. + // We provide no deleter function because all peer memory allocations are static in this implementation. + return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options); } -void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W) -{ - if (t.dim() == 3) { - N = 1; - if (explicit_nhwc) { - C = t.size(2); - H = t.size(0); - W = t.size(1); - } else { - C = t.size(0); - H = t.size(1); - W = t.size(2); - } - } else if (t.dim() == 4) { - if (explicit_nhwc) { - N = t.size(0); - C = t.size(3); - H = t.size(1); - W = t.size(2); - } else { - N = t.size(0); - C = t.size(1); - H = t.size(2); - W = t.size(3); - } +void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W) { + if (t.dim() == 3) { + N = 1; + if (explicit_nhwc) { + C = t.size(2); + H = t.size(0); + W = t.size(1); + } else { + C = t.size(0); + H = t.size(1); + W = t.size(2); + } + } else if (t.dim() == 4) { + if (explicit_nhwc) { + N = t.size(0); + C = t.size(3); + H = t.size(1); + W = t.size(2); } else { - printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,int(t.dim())); - assert(t.dim() == 3 || t.dim() == 4); + N = t.size(0); + C = t.size(1); + H = t.size(2); + W = t.size(3); } + } else { + printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, int(t.dim())); + assert(t.dim() == 3 || t.dim() == 4); + } } -void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W) -{ - if (t.dim() == 3) { - if (explicit_nhwc) { - stride_C = t.stride(2); - stride_H = t.stride(0); - stride_W = t.stride(1); - } else { - stride_C = t.stride(0); - stride_H = t.stride(1); - stride_W = t.stride(2); - } - stride_N = t.size(0)*t.size(1)*t.size(2); - } else if (t.dim() == 4) { - if (explicit_nhwc) { - stride_N = t.stride(0); - stride_C = t.stride(3); - stride_H = t.stride(1); - stride_W = t.stride(2); - } else { - stride_N = t.stride(0); - stride_C = t.stride(1); - stride_H = t.stride(2); - stride_W = t.stride(3); - } +void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W) { + if (t.dim() == 3) { + if (explicit_nhwc) { + stride_C = t.stride(2); + stride_H = t.stride(0); + stride_W = t.stride(1); + } else { + stride_C = t.stride(0); + stride_H = t.stride(1); + stride_W = t.stride(2); + } + stride_N = t.size(0) * t.size(1) * t.size(2); + } else if (t.dim() == 4) { + if (explicit_nhwc) { + stride_N = t.stride(0); + stride_C = t.stride(3); + stride_H = t.stride(1); + stride_W = t.stride(2); } else { - printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim()); - assert(t.dim() == 3 || t.dim() == 4); + stride_N = t.stride(0); + stride_C = t.stride(1); + stride_H = t.stride(2); + stride_W = t.stride(3); } + } else { + printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim()); + assert(t.dim() == 3 || t.dim() == 4); + } } -template -inline __device__ void __zero(T* dst) -{ - *dst 
= T(0); +template +inline __device__ void __zero(T* dst) { + *dst = T(0); } -inline __device__ void __zero(int2* dst) -{ - *dst = {0, 0}; +inline __device__ void __zero(int2* dst) { *dst = {0, 0}; } + +template +inline __device__ void zero_tensor(const int dim0, const int dim1, const int dim2, T* __restrict__ data, + const int data_stride0, const int data_stride1, const int data_stride2, + const int thread_id, const int block_id, const int num_blocks) { + const int global_id = thread_id + block_id * THREADS_PER_CTA; + const int num_threads = num_blocks * THREADS_PER_CTA; + const int count = dim0 * dim1 * dim2; + for (int i = global_id; i < count; i += num_threads) { + int offset; + if (contiguous) { + offset = i; + } else { + const int j2 = i % dim2; + const int k = i / dim2; + const int j1 = k % dim1; + const int j0 = k / dim1; + offset = j0 * data_stride0 + j1 * data_stride1 + j2 * data_stride2; + } + __zero(data + offset); + } } -template -inline __device__ void zero_tensor( - const int dim0, - const int dim1, - const int dim2, - T* __restrict__ data, - const int data_stride0, - const int data_stride1, - const int data_stride2, - const int thread_id, - const int block_id, - const int num_blocks - ) -{ - const int global_id = thread_id + block_id * THREADS_PER_CTA; - const int num_threads = num_blocks * THREADS_PER_CTA; - const int count = dim0 * dim1 * dim2; - for (int i = global_id; i < count; i += num_threads) { - int offset; - if (contiguous) { - offset = i; - } else { - const int j2 = i % dim2; - const int k = i / dim2; - const int j1 = k % dim1; - const int j0 = k / dim1; - offset = j0 * data_stride0 + j1 * data_stride1 + j2 * data_stride2; - } - __zero(data + offset); +template +inline __device__ void push_pull_tensor(const int dim0, const int dim1, const int dim2, const T* __restrict__ data_in, + const int data_in_stride0, const int data_in_stride1, const int data_in_stride2, + T* __restrict__ data_out, const int data_out_stride0, + const int data_out_stride1, const int data_out_stride2, int4* local_peer, + int4* remote_peer, const int thread_id, const int block_id, + const int num_blocks) { + // 128b=16B NVLink flit + // Note: Use last 4B as a semaphore + static_assert(sizeof(T) <= 12); + union Flit { + T payload; + uint uints[4]; + }; + // Communication bit indicates whether flit has been received from + // a remote GPU + constexpr uint communication_mask = 1 << 0; + // Status bit is used to choose the active peer buffer in an + // alternating double buffer scheme. We use buffer 1 if the bits + // match, use buffer 2 if the bits differ, and invert the bit + // after finishing with a buffer. 
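// Illustrative aside (editor's sketch, not part of this patch): the handshake below keeps a semaphore in
// the last 4 bytes of each 16B flit. Assuming hypothetical 32-bit semaphore words sem1 and sem2 read from
// the two local buffers, the buffer selection and reset logic used here reduces to:
//
//   constexpr uint kCommMask   = 1u << 0;  // set by the sender once its flit has landed
//   constexpr uint kStatusMask = 1u << 1;  // alternates to pick the active double buffer
//
//   bool buffer1_is_active(uint sem1, uint sem2) {
//     // buffer 1 is active when the two status bits agree, buffer 2 when they differ
//     return ((sem1 & kStatusMask) ^ (sem2 & kStatusMask)) == 0;
//   }
//
//   uint reset_semaphore(uint status) {
//     // clear the communication bit and invert the status bit for the next round
//     return ~status & kStatusMask;
//   }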
+ constexpr uint status_mask = 1 << 1; + + // Split peer memory into two sets of buffers + // Note: Each block owns a THREADS_PER_CTA*2*16B chunk of peer + // memory + const int peer_offset1 = block_id * THREADS_PER_CTA * 2 + thread_id; + const int peer_offset2 = peer_offset1 + THREADS_PER_CTA; + volatile int* local_peer1 = reinterpret_cast(local_peer + peer_offset1); + volatile int* local_peer2 = reinterpret_cast(local_peer + peer_offset2); + volatile int* remote_peer1 = reinterpret_cast(remote_peer + peer_offset1); + volatile int* remote_peer2 = reinterpret_cast(remote_peer + peer_offset2); + + // Iterate through tensor entries + const int num_threads = num_blocks * THREADS_PER_CTA; + const int count = dim0 * dim1 * dim2; + for (int i0 = block_id * THREADS_PER_CTA; i0 < count; i0 += num_threads) { + const int i = i0 + thread_id; + const bool has_data = i < count; + + // Calculate buffer positions + int data_in_offset, data_out_offset; + if (contiguous) { + data_in_offset = i; + data_out_offset = i; + } else { + const int j2 = i % dim2; + const int k = i / dim2; + const int j1 = k % dim1; + const int j0 = k / dim1; + data_in_offset = j0 * data_in_stride0 + j1 * data_in_stride1 + j2 * data_in_stride2; + data_out_offset = j0 * data_out_stride0 + j1 * data_out_stride1 + j2 * data_out_stride2; } -} -template -inline __device__ void push_pull_tensor( - const int dim0, - const int dim1, - const int dim2, - const T* __restrict__ data_in, - const int data_in_stride0, - const int data_in_stride1, - const int data_in_stride2, - T* __restrict__ data_out, - const int data_out_stride0, - const int data_out_stride1, - const int data_out_stride2, - int4* local_peer, - int4* remote_peer, - const int thread_id, - const int block_id, - const int num_blocks - ) -{ - // 128b=16B NVLink flit - // Note: Use last 4B as a semaphore - static_assert(sizeof(T) <= 12); - union Flit { - T payload; - uint uints[4]; - }; - // Communication bit indicates whether flit has been received from - // a remote GPU - constexpr uint communication_mask = 1 << 0; - // Status bit is used to choose the active peer buffer in an - // alternating double buffer scheme. We use buffer 1 if the bits - // match, use buffer 2 if the bits differ, and invert the bit - // after finishing with a buffer. 
- constexpr uint status_mask = 1 << 1; - - // Split peer memory into two sets of buffers - // Note: Each block owns a THREADS_PER_CTA*2*16B chunk of peer - // memory - const int peer_offset1 = block_id * THREADS_PER_CTA * 2 + thread_id; - const int peer_offset2 = peer_offset1 + THREADS_PER_CTA; - volatile int* local_peer1 = reinterpret_cast(local_peer + peer_offset1); - volatile int* local_peer2 = reinterpret_cast(local_peer + peer_offset2); - volatile int* remote_peer1 = reinterpret_cast(remote_peer + peer_offset1); - volatile int* remote_peer2 = reinterpret_cast(remote_peer + peer_offset2); - - // Iterate through tensor entries - const int num_threads = num_blocks * THREADS_PER_CTA; - const int count = dim0 * dim1 * dim2; - for (int i0 = block_id * THREADS_PER_CTA; i0 < count; i0 += num_threads) { - const int i = i0 + thread_id; - const bool has_data = i < count; - - // Calculate buffer positions - int data_in_offset, data_out_offset; - if (contiguous) { - data_in_offset = i; - data_out_offset = i; - } else { - const int j2 = i % dim2; - const int k = i / dim2; - const int j1 = k % dim1; - const int j0 = k / dim1; - data_in_offset = j0 * data_in_stride0 + j1 * data_in_stride1 + j2 * data_in_stride2; - data_out_offset = j0 * data_out_stride0 + j1 * data_out_stride1 + j2 * data_out_stride2; - } - - // Determine which peer memory buffer to use - // Note: The status bit is not affected by asynchronous - // communication from the remote GPU. - Flit local_message1, local_message2; - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : - "=r"(local_message1.uints[0]), - "=r"(local_message1.uints[1]), - "=r"(local_message1.uints[2]), - "=r"(local_message1.uints[3]) - : "l"(local_peer1) : "memory"); - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : - "=r"(local_message2.uints[0]), - "=r"(local_message2.uints[1]), - "=r"(local_message2.uints[2]), - "=r"(local_message2.uints[3]) - : "l"(local_peer2) : "memory"); - const uint status1 = local_message1.uints[3] & status_mask; - const uint status2 = local_message2.uints[3] & status_mask; - const bool peer1_is_active = (status1 ^ status2) == 0; - volatile int* ox = peer1_is_active ? remote_peer1 : remote_peer2; - volatile int* ix = peer1_is_active ? local_peer1 : local_peer2; - const uint status = peer1_is_active ? status1 : status2; - Flit recv_message = peer1_is_active ? 
local_message1 : local_message2; - - // Send flit to remote GPU - // Note: Set communication bit and keep status bit - Flit send_message; - if (has_data) { - send_message.payload = data_in[data_in_offset]; - } - send_message.uints[3] = communication_mask | status; - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: - "l"(ox), - "r"(send_message.uints[0]), - "r"(send_message.uints[1]), - "r"(send_message.uints[2]), - "r"(send_message.uints[3]) - : "memory"); - - // Recieve flit from peer - while ((recv_message.uints[3] & communication_mask) == 0) { - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : - "=r"(recv_message.uints[0]), - "=r"(recv_message.uints[1]), - "=r"(recv_message.uints[2]), - "=r"(recv_message.uints[3]) - : "l"(ix) : "memory"); - } - if (has_data) { - data_out[data_out_offset] = recv_message.payload; - } - - // Reset semaphore - // Note: Clear communication bit and invert status bit - uint flag = ~status & status_mask; - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: - "l"(ix), - "n"(0), - "n"(0), - "n"(0), - "r"(flag) - : "memory"); - if (i0 + num_threads < count) { - __threadfence_system(); - } + // Determine which peer memory buffer to use + // Note: The status bit is not affected by asynchronous + // communication from the remote GPU. + Flit local_message1, local_message2; + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(local_message1.uints[0]), "=r"(local_message1.uints[1]), "=r"(local_message1.uints[2]), + "=r"(local_message1.uints[3]) + : "l"(local_peer1) + : "memory"); + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(local_message2.uints[0]), "=r"(local_message2.uints[1]), "=r"(local_message2.uints[2]), + "=r"(local_message2.uints[3]) + : "l"(local_peer2) + : "memory"); + const uint status1 = local_message1.uints[3] & status_mask; + const uint status2 = local_message2.uints[3] & status_mask; + const bool peer1_is_active = (status1 ^ status2) == 0; + volatile int* ox = peer1_is_active ? remote_peer1 : remote_peer2; + volatile int* ix = peer1_is_active ? local_peer1 : local_peer2; + const uint status = peer1_is_active ? status1 : status2; + Flit recv_message = peer1_is_active ? 
local_message1 : local_message2; + + // Send flit to remote GPU + // Note: Set communication bit and keep status bit + Flit send_message; + if (has_data) { + send_message.payload = data_in[data_in_offset]; + } + send_message.uints[3] = communication_mask | status; + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(ox), "r"(send_message.uints[0]), + "r"(send_message.uints[1]), "r"(send_message.uints[2]), "r"(send_message.uints[3]) + : "memory"); + + // Recieve flit from peer + while ((recv_message.uints[3] & communication_mask) == 0) { + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(recv_message.uints[0]), "=r"(recv_message.uints[1]), "=r"(recv_message.uints[2]), + "=r"(recv_message.uints[3]) + : "l"(ix) + : "memory"); + } + if (has_data) { + data_out[data_out_offset] = recv_message.payload; } + + // Reset semaphore + // Note: Clear communication bit and invert status bit + uint flag = ~status & status_mask; + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(ix), "n"(0), "n"(0), "n"(0), "r"(flag) + : "memory"); + if (i0 + num_threads < count) { + __threadfence_system(); + } + } } -template +template #if __CUDA_ARCH__ >= 700 __launch_bounds__(THREADS_PER_CTA) #endif -__global__ void push_pull_halos_1d_kernel( + __global__ void push_pull_halos_1d_kernel( // top halo, - T* toh, int toh_stride0, int toh_stride1, int toh_stride2, // top output halo (local) - const T* tih, int tih_stride0, int tih_stride1, int tih_stride2, // top input halo (local) - int4* tox, // top output transfer buffer (remote peer) - int4* tix, // top input transfer buffer (local peer) + T* toh, int toh_stride0, int toh_stride1, int toh_stride2, // top output halo (local) + const T* tih, int tih_stride0, int tih_stride1, int tih_stride2, // top input halo (local) + int4* tox, // top output transfer buffer (remote peer) + int4* tix, // top input transfer buffer (local peer) // btm halo - T* boh, int boh_stride0, int boh_stride1, int boh_stride2, // btm output halo (local) - const T* bih, int bih_stride0, int bih_stride1, int bih_stride2, // btm input halo (local) - int4* box, // btm output transfer buffer (remote peer) - int4* bix, // btm input transfer buffer (local peer) + T* boh, int boh_stride0, int boh_stride1, int boh_stride2, // btm output halo (local) + const T* bih, int bih_stride0, int bih_stride1, int bih_stride2, // btm input halo (local) + int4* box, // btm output transfer buffer (remote peer) + int4* bix, // btm input transfer buffer (local peer) // dimensions int dim0, int dim1, int dim2, - bool top_first // whether to launch communicate top halo first - ) -{ - const int num_blocks_side = gridDim.x / 2; - const int block_id_side = (blockIdx.x < num_blocks_side - ? blockIdx.x - : blockIdx.x - num_blocks_side); - const bool in_top_block = top_first == (blockIdx.x < num_blocks_side); - if (in_top_block) { - if (top_zero) { - zero_tensor( - dim0, dim1, dim2, - toh, toh_stride0, toh_stride1, toh_stride2, - threadIdx.x, block_id_side, num_blocks_side); - } else { - push_pull_tensor( - dim0, dim1, dim2, - tih, tih_stride0, tih_stride1, tih_stride2, - toh, toh_stride0, toh_stride1, toh_stride2, - tix, tox, - threadIdx.x, block_id_side, num_blocks_side); - } + bool top_first // whether to launch communicate top halo first + ) { + const int num_blocks_side = gridDim.x / 2; + const int block_id_side = (blockIdx.x < num_blocks_side ? 
blockIdx.x : blockIdx.x - num_blocks_side); + const bool in_top_block = top_first == (blockIdx.x < num_blocks_side); + if (in_top_block) { + if (top_zero) { + zero_tensor(dim0, dim1, dim2, toh, toh_stride0, toh_stride1, toh_stride2, threadIdx.x, + block_id_side, num_blocks_side); } else { - if (btm_zero) { - zero_tensor( - dim0, dim1, dim2, - boh, boh_stride0, boh_stride1, boh_stride2, - threadIdx.x, block_id_side, num_blocks_side); - } else { - push_pull_tensor( - dim0, dim1, dim2, - bih, bih_stride0, bih_stride1, bih_stride2, - boh, boh_stride0, boh_stride1, boh_stride2, - bix, box, - threadIdx.x, block_id_side, num_blocks_side); - } + push_pull_tensor(dim0, dim1, dim2, tih, tih_stride0, tih_stride1, tih_stride2, toh, toh_stride0, + toh_stride1, toh_stride2, tix, tox, threadIdx.x, block_id_side, num_blocks_side); } -} - -__global__ void delay_kernel(int delay_nanoseconds, int* counter) -{ - if (blockIdx.x == 0 && threadIdx.x == 0) { - // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code. - int new_counter = 0; - double elapsed = 0; - clock_t start = clock(); - do { - clock_t now = clock(); - elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC; - ++new_counter; - } while (elapsed < (double)delay_nanoseconds); - *counter = new_counter; + } else { + if (btm_zero) { + zero_tensor(dim0, dim1, dim2, boh, boh_stride0, boh_stride1, boh_stride2, threadIdx.x, + block_id_side, num_blocks_side); + } else { + push_pull_tensor(dim0, dim1, dim2, bih, bih_stride0, bih_stride1, bih_stride2, boh, boh_stride0, + boh_stride1, boh_stride2, bix, box, threadIdx.x, block_id_side, num_blocks_side); } + } } +__global__ void delay_kernel(int delay_nanoseconds, int* counter) { + if (blockIdx.x == 0 && threadIdx.x == 0) { + // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code. 
+ int new_counter = 0; + double elapsed = 0; + clock_t start = clock(); + do { + clock_t now = clock(); + elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC; + ++new_counter; + } while (elapsed < (double)delay_nanoseconds); + *counter = new_counter; + } } -namespace apex { namespace contrib { namespace peer_memory { +} // namespace -int64_t allocate_raw(int64_t size) -{ - float* ptr = 0L; - cudaMalloc(&ptr, size); - cudaMemset(ptr, 0, size); - return (int64_t)ptr; -} +namespace apex { +namespace contrib { +namespace peer_memory { -void free_raw(int64_t raw) -{ - cudaFree((void*)raw); +int64_t allocate_raw(int64_t size) { + float* ptr = 0L; + cudaMalloc(&ptr, size); + cudaMemset(ptr, 0, size); + return (int64_t)ptr; } -void zero(int64_t raw, int64_t size) -{ - cudaMemset((void*)raw, 0, size); -} +void free_raw(int64_t raw) { cudaFree((void*)raw); } -at::Tensor get_raw_ipc_address(int64_t raw) -{ - cudaIpcMemHandle_t mem_handle; - CUDACHECK( cudaIpcGetMemHandle(&mem_handle, (void*)raw) ); - const int n = sizeof(cudaIpcMemHandle_t); - auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8)); - auto address_tensor_p = address_tensor.data_ptr(); - memcpy(address_tensor_p, (uint8_t*)&mem_handle, n); - return address_tensor; +void zero(int64_t raw, int64_t size) { cudaMemset((void*)raw, 0, size); } + +at::Tensor get_raw_ipc_address(int64_t raw) { + cudaIpcMemHandle_t mem_handle; + CUDACHECK(cudaIpcGetMemHandle(&mem_handle, (void*)raw)); + const int n = sizeof(cudaIpcMemHandle_t); + auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8)); + auto address_tensor_p = address_tensor.data_ptr(); + memcpy(address_tensor_p, (uint8_t*)&mem_handle, n); + return address_tensor; } -std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw) -{ - int peer_group_size = ipc_addresses.size(0); - std::vector results(peer_group_size); - for (int i = 0; i < peer_group_size; ++i) { - if (i != peer_rank) { - cudaIpcMemHandle_t mem_handle; - memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr(), sizeof(cudaIpcMemHandle_t)); - void* p = 0L; - CUDACHECK( cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess) ); - results[i] = (int64_t)p; - } else { - results[i] = (int64_t)raw; - } +std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw) { + int peer_group_size = ipc_addresses.size(0); + std::vector results(peer_group_size); + for (int i = 0; i < peer_group_size; ++i) { + if (i != peer_rank) { + cudaIpcMemHandle_t mem_handle; + memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr(), sizeof(cudaIpcMemHandle_t)); + void* p = 0L; + CUDACHECK(cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess)); + results[i] = (int64_t)p; + } else { + results[i] = (int64_t)raw; } - return results; + } + return results; } -at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last) -{ - return blob_view((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last); +at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last) { + return blob_view((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last); } -at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last) -{ - return blob_view((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last); +at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last) { + return blob_view((float*)raw, shape, 
torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last); } -at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last) -{ - return blob_view((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last); +at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last) { + return blob_view((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last); } void push_pull_halos_1d( - bool diagnostics, - bool explicit_nhwc, - int numSM, // number of SMs to use (zero corresponds to all SMs) - int rank, // rank in spatial parallel group - bool top_zero, // if top halo should be zeroed - at::Tensor top_in_halo, // top input halo buffer (in local device memory, sent to top neighbor) - at::Tensor top_in_transfer, // top input transfer buffer (in local peer memory) - at::Tensor top_out_transfer, // top output transfer buffer (in top neighbor peer memory) - at::Tensor top_out_halo, // top output halo buffer (in local device memory, received from top neighbor) - bool btm_zero, // if btm halo should be zeroed - at::Tensor btm_in_halo, // btm input halo buffer (in local device memory, sent to btm neighbor) - at::Tensor btm_in_transfer, // btm input transfer buffer (in local peer memory) - at::Tensor btm_out_transfer, // btm output transfer buffer (in btm neighbor peer memory) - at::Tensor btm_out_halo // btm output halo buffer (in local device memory, received from btm neighbor) - ) -{ - // basic checks of inputs - TORCH_CHECK(!(top_zero && btm_zero)); - TORCH_CHECK(top_in_halo.is_cuda()); - TORCH_CHECK(top_out_transfer.is_cuda()); - TORCH_CHECK(top_in_transfer.is_cuda()); - TORCH_CHECK(top_out_halo.is_cuda()); - TORCH_CHECK(btm_in_halo.is_cuda()); - TORCH_CHECK(btm_out_transfer.is_cuda()); - TORCH_CHECK(btm_in_transfer.is_cuda()); - TORCH_CHECK(btm_out_halo.is_cuda()); - - // tensor shapes - int tih_N, tih_C, tih_H, tih_W; - tensor_shape(top_in_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W); - int toh_N, toh_C, toh_H, toh_W; - tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W); - int bih_N, bih_C, bih_H, bih_W; - tensor_shape(btm_in_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W); - int boh_N, boh_C, boh_H, boh_W; - tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W); - TORCH_CHECK(toh_N == tih_N && tih_N == boh_N && boh_N == bih_N && - toh_C == tih_C && tih_C == boh_C && boh_C == bih_C && - toh_H == tih_H && tih_H == boh_H && boh_H == bih_H && - toh_W == tih_W && tih_W == boh_W && boh_W == bih_W); - int NN=toh_N, NC=toh_C, NH=toh_H, NW=toh_W; - if (diagnostics) { - printf("rank %d: NN=%d, NC=%d, NH=%d, NW=%d\n", rank, NN, NC, NH, NW); + bool diagnostics, bool explicit_nhwc, + int numSM, // number of SMs to use (zero corresponds to all SMs) + int rank, // rank in spatial parallel group + bool top_zero, // if top halo should be zeroed + at::Tensor top_in_halo, // top input halo buffer (in local device memory, sent to top neighbor) + at::Tensor top_in_transfer, // top input transfer buffer (in local peer memory) + at::Tensor top_out_transfer, // top output transfer buffer (in top neighbor peer memory) + at::Tensor top_out_halo, // top output halo buffer (in local device memory, received from top neighbor) + bool btm_zero, // if btm halo should be zeroed + at::Tensor btm_in_halo, // btm input halo buffer (in local device memory, sent to btm neighbor) + at::Tensor btm_in_transfer, // btm input transfer buffer (in local peer memory) + at::Tensor btm_out_transfer, // btm output transfer 
buffer (in btm neighbor peer memory) + at::Tensor btm_out_halo // btm output halo buffer (in local device memory, received from btm neighbor) +) { + // basic checks of inputs + TORCH_CHECK(!(top_zero && btm_zero)); + TORCH_CHECK(top_in_halo.is_cuda()); + TORCH_CHECK(top_out_transfer.is_cuda()); + TORCH_CHECK(top_in_transfer.is_cuda()); + TORCH_CHECK(top_out_halo.is_cuda()); + TORCH_CHECK(btm_in_halo.is_cuda()); + TORCH_CHECK(btm_out_transfer.is_cuda()); + TORCH_CHECK(btm_in_transfer.is_cuda()); + TORCH_CHECK(btm_out_halo.is_cuda()); + + // tensor shapes + int tih_N, tih_C, tih_H, tih_W; + tensor_shape(top_in_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W); + int toh_N, toh_C, toh_H, toh_W; + tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W); + int bih_N, bih_C, bih_H, bih_W; + tensor_shape(btm_in_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W); + int boh_N, boh_C, boh_H, boh_W; + tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W); + TORCH_CHECK(toh_N == tih_N && tih_N == boh_N && boh_N == bih_N && toh_C == tih_C && tih_C == boh_C && + boh_C == bih_C && toh_H == tih_H && tih_H == boh_H && boh_H == bih_H && toh_W == tih_W && + tih_W == boh_W && boh_W == bih_W); + int NN = toh_N, NC = toh_C, NH = toh_H, NW = toh_W; + if (diagnostics) { + printf("rank %d: NN=%d, NC=%d, NH=%d, NW=%d\n", rank, NN, NC, NH, NW); + } + TORCH_CHECK(NN == 1); + + // tensor strides + int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W; + tensor_strides(top_in_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W); + int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W; + tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W); + int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W; + tensor_strides(btm_in_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W); + int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W; + tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W); + if (diagnostics) { + printf("rank %d: tih_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, tih_stride_N, tih_stride_C, tih_stride_H, + tih_stride_W); + printf("rank %d: toh_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, toh_stride_N, toh_stride_C, toh_stride_H, + toh_stride_W); + printf("rank %d: bih_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, bih_stride_N, bih_stride_C, bih_stride_H, + bih_stride_W); + printf("rank %d: boh_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, boh_stride_N, boh_stride_C, boh_stride_H, + boh_stride_W); + } + + // determine if nhwc + bool is_nhwc = (toh_stride_C == 1); + if (diagnostics) { + printf("rank %d: is_nhwc = %s\n", rank, is_nhwc ? 
"true" : "false"); + } + + // determine if contiguous + bool contiguous = true; + if ((NN - 1) * toh_stride_N + (NC - 1) * toh_stride_C + (NH - 1) * toh_stride_H + (NW - 1) * toh_stride_W != + NN * NC * NH * NW - 1) { + contiguous = false; + } + if ((NN - 1) * boh_stride_N + (NC - 1) * boh_stride_C + (NH - 1) * boh_stride_H + (NW - 1) * boh_stride_W != + NN * NC * NH * NW - 1) { + contiguous = false; + } + if (!top_zero) { + if (toh_stride_N != tih_stride_N || toh_stride_C != tih_stride_C || toh_stride_H != tih_stride_H || + toh_stride_W != tih_stride_W) { + contiguous = false; } - TORCH_CHECK(NN == 1); - - // tensor strides - int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W; - tensor_strides(top_in_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W); - int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W; - tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W); - int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W; - tensor_strides(btm_in_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W); - int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W; - tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W); - if (diagnostics) { - printf("rank %d: tih_stride :: N=%d, C=%d, H=%d, W=%d\n", - rank, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W); - printf("rank %d: toh_stride :: N=%d, C=%d, H=%d, W=%d\n", - rank, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W); - printf("rank %d: bih_stride :: N=%d, C=%d, H=%d, W=%d\n", - rank, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W); - printf("rank %d: boh_stride :: N=%d, C=%d, H=%d, W=%d\n", - rank, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W); - } - - // determine if nhwc - bool is_nhwc = (toh_stride_C == 1); - if (diagnostics) { - printf("rank %d: is_nhwc = %s\n", rank, is_nhwc ? "true" : "false"); - } - - // determine if contiguous - bool contiguous = true; - if ((NN-1)*toh_stride_N + (NC-1)*toh_stride_C + - (NH-1)*toh_stride_H + (NW-1)*toh_stride_W - != NN*NC*NH*NW - 1) { - contiguous = false; - } - if ((NN-1)*boh_stride_N + (NC-1)*boh_stride_C + - (NH-1)*boh_stride_H + (NW-1)*boh_stride_W - != NN*NC*NH*NW - 1) { - contiguous = false; - } - if (!top_zero) { - if (toh_stride_N != tih_stride_N || toh_stride_C != tih_stride_C || - toh_stride_H != tih_stride_H || toh_stride_W != tih_stride_W) { - contiguous = false; - } - } - if (!btm_zero) { - if (boh_stride_N != bih_stride_N || boh_stride_C != bih_stride_C || - boh_stride_H != bih_stride_H || boh_stride_W != bih_stride_W) { - contiguous = false; - } + } + if (!btm_zero) { + if (boh_stride_N != bih_stride_N || boh_stride_C != bih_stride_C || boh_stride_H != bih_stride_H || + boh_stride_W != bih_stride_W) { + contiguous = false; } + } + if (diagnostics) { + printf("rank %d: contiguous = %s\n", rank, contiguous ? "true" : "false"); + } + + // determine whether to communicate top halo first + bool top_first = rank % 2 != 0; + if (diagnostics) { + printf("rank %d: top_first = %s\n", rank, top_first ? 
"true" : "false"); + } + + // peer memory buffers + int tox_size = top_out_transfer.numel() * top_out_transfer.element_size(); + int tix_size = top_in_transfer.numel() * top_in_transfer.element_size(); + int box_size = btm_out_transfer.numel() * btm_out_transfer.element_size(); + int bix_size = btm_in_transfer.numel() * btm_in_transfer.element_size(); + if (!top_zero) { + TORCH_CHECK(top_out_transfer.is_contiguous()); + TORCH_CHECK(top_in_transfer.is_contiguous()); + TORCH_CHECK(tox_size == tix_size); + } + if (!btm_zero) { + TORCH_CHECK(btm_out_transfer.is_contiguous()); + TORCH_CHECK(btm_in_transfer.is_contiguous()); + TORCH_CHECK(box_size == bix_size); + } + + // figure out launch parameters + int device; + cudaGetDevice(&device); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, device); + if (numSM <= 0 || numSM > prop.multiProcessorCount) { + numSM = prop.multiProcessorCount; + } + auto current_stream = at::cuda::getCurrentCUDAStream(); + dim3 block(THREADS_PER_CTA, 1, 1); + + // helper macros to launch templated kernel +#define LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, TOP_ZERO, BTM_ZERO, KERNEL_ARGS, NUM_ELEMENTS) \ + do { \ + /* kernel configuration */ \ + int numBlocksPerSm; \ + cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ + &numBlocksPerSm, push_pull_halos_1d_kernel, THREADS_PER_CTA, 0); \ + dim3 grid(numSM* numBlocksPerSm, 1, 1); \ + if (grid.x % 2 != 0) { \ + /* require even number of blocks (half for top, half for bottom) */ \ + grid.x -= 1; \ + } \ + if ((grid.x / 2) * THREADS_PER_CTA > NUM_ELEMENTS) { \ + /* only need enough blocks to cover top and bottom halo elements */ \ + grid.x = 2 * ((NUM_ELEMENTS + THREADS_PER_CTA - 1) / THREADS_PER_CTA); \ + } \ + if (!TOP_ZERO) { \ + /* require 2*128b=32B peer memory per thread */ \ + if ((grid.x / 2) * THREADS_PER_CTA * 32 > tox_size) { \ + grid.x = 2 * (tox_size / (THREADS_PER_CTA * 32)); \ + } \ + } \ + if (!BTM_ZERO) { \ + /* require 2*128b=32B peer memory per thread */ \ + if ((grid.x / 2) * THREADS_PER_CTA * 32 > box_size) { \ + grid.x = 2 * (box_size / (THREADS_PER_CTA * 32)); \ + } \ + } \ + TORCH_CHECK(grid.x >= 2); \ + \ + /* launch kernel */ \ + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, \ + KERNEL_ARGS, 0, current_stream); \ + } while (false) +#define LAUNCH_PUSH_PULL_HALO_KERNEL(T, CONTIGUOUS, KERNEL_ARGS, NUM_ELEMENTS) \ + do { \ + if (top_zero) { \ + LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, true, false, KERNEL_ARGS, NUM_ELEMENTS); \ + } else if (btm_zero) { \ + LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, false, true, KERNEL_ARGS, NUM_ELEMENTS); \ + } else { \ + LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, false, false, KERNEL_ARGS, NUM_ELEMENTS); \ + } \ + } while (false) + + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&] { if (diagnostics) { - printf("rank %d: contiguous = %s\n", rank, contiguous ? "true" : "false"); + printf("rank %d: size(scalar_t) = %ld\n", rank, sizeof(scalar_t)); } - - // determine whether to communicate top halo first - bool top_first = rank % 2 != 0; - if (diagnostics) { - printf("rank %d: top_first = %s\n", rank, top_first ? 
"true" : "false"); - } - - // peer memory buffers - int tox_size = top_out_transfer.numel() * top_out_transfer.element_size(); - int tix_size = top_in_transfer.numel() * top_in_transfer.element_size(); - int box_size = btm_out_transfer.numel() * btm_out_transfer.element_size(); - int bix_size = btm_in_transfer.numel() * btm_in_transfer.element_size(); - if (!top_zero) { - TORCH_CHECK(top_out_transfer.is_contiguous()); - TORCH_CHECK(top_in_transfer.is_contiguous()); - TORCH_CHECK(tox_size == tix_size); - } - if (!btm_zero) { - TORCH_CHECK(btm_out_transfer.is_contiguous()); - TORCH_CHECK(btm_in_transfer.is_contiguous()); - TORCH_CHECK(box_size == bix_size); - } - - // figure out launch parameters - int device; - cudaGetDevice(&device); - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, device); - if (numSM <= 0 || numSM > prop.multiProcessorCount) { - numSM = prop.multiProcessorCount; + scalar_t* toh_p = top_out_halo.data_ptr(); + scalar_t* tih_p = top_in_halo.data_ptr(); + int4* tox_p = reinterpret_cast(top_out_transfer.data_ptr()); + int4* tix_p = reinterpret_cast(top_in_transfer.data_ptr()); + scalar_t* boh_p = btm_out_halo.data_ptr(); + scalar_t* bih_p = btm_in_halo.data_ptr(); + int4* box_p = reinterpret_cast(btm_out_transfer.data_ptr()); + int4* bix_p = reinterpret_cast(btm_in_transfer.data_ptr()); + if (diagnostics) printf("rank %d: choosing halo exchange kernel\n", rank); + + // do int2 vector loads if channel count permits + if (contiguous && (NN * NH * NW * NC * sizeof(scalar_t)) % sizeof(int2) == 0) { + // can do contiguous int2 transfers + if (diagnostics) { + } + toh_stride_N = toh_stride_H = toh_stride_W = toh_stride_C = 1; + tih_stride_N = tih_stride_H = tih_stride_W = tih_stride_C = 1; + boh_stride_N = boh_stride_H = boh_stride_W = boh_stride_C = 1; + bih_stride_N = bih_stride_H = bih_stride_W = bih_stride_C = 1; + NC = (NN * NH * NW * NC * sizeof(scalar_t)) / sizeof(int2); + NN = NH = NW = 1; + if (diagnostics) { + printf("rank %d: launching contiguous int2 halo exchange kernel\n", rank); + printf("rank %d: NC=%d, NH=%d, NW=%d\n", rank, NC, NH, NW); + } + void* kernel_args[] = {(int2**)&toh_p, + &toh_stride_H, + &toh_stride_W, + &toh_stride_C, + (int2**)&tih_p, + &tih_stride_H, + &tih_stride_W, + &tih_stride_C, + &tox_p, + &tix_p, + (int2**)&boh_p, + &boh_stride_H, + &boh_stride_W, + &boh_stride_C, + (int2**)&bih_p, + &bih_stride_H, + &bih_stride_W, + &bih_stride_C, + &box_p, + &bix_p, + &NH, + &NW, + &NC, + &top_first}; + int num_elem = NN * NH * NW * NC; + LAUNCH_PUSH_PULL_HALO_KERNEL(int2, true, kernel_args, num_elem); + } else if (is_nhwc && (NC * sizeof(scalar_t)) % sizeof(int2) == 0) { + // can do strided int2 transfers + int divisor = sizeof(int2) / sizeof(scalar_t); + if (diagnostics) { + printf("rank %d: launching strided int2 halo exchange kernel\n", rank); + } + toh_stride_N /= divisor; + toh_stride_H /= divisor; + toh_stride_W /= divisor; + tih_stride_N /= divisor; + tih_stride_H /= divisor; + tih_stride_W /= divisor; + boh_stride_N /= divisor; + boh_stride_H /= divisor; + boh_stride_W /= divisor; + bih_stride_N /= divisor; + bih_stride_H /= divisor; + bih_stride_W /= divisor; + NC /= divisor; + if (diagnostics) { + printf("rank %d: divisor=%d\n", rank, divisor); + printf("rank %d: tih_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, tih_stride_N, tih_stride_C, tih_stride_H, + tih_stride_W); + printf("rank %d: toh_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, toh_stride_N, toh_stride_C, toh_stride_H, + toh_stride_W); + printf("rank %d: bih_stride :: N=%d, C=%d, H=%d, 
W=%d\n", rank, bih_stride_N, bih_stride_C, bih_stride_H, + bih_stride_W); + printf("rank %d: boh_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, boh_stride_N, boh_stride_C, boh_stride_H, + boh_stride_W); + printf("rank %d: NC=%d, NH=%d, NW=%d\n", rank, NC, NH, NW); + } + void* kernel_args[] = {(int2**)&toh_p, + &toh_stride_H, + &toh_stride_W, + &toh_stride_C, + (int2**)&tih_p, + &tih_stride_H, + &tih_stride_W, + &tih_stride_C, + &tox_p, + &tix_p, + (int2**)&boh_p, + &boh_stride_H, + &boh_stride_W, + &boh_stride_C, + (int2**)&bih_p, + &bih_stride_H, + &bih_stride_W, + &bih_stride_C, + &box_p, + &bix_p, + &NH, + &NW, + &NC, + &top_first}; + int num_elem = NH * NW * NC; + LAUNCH_PUSH_PULL_HALO_KERNEL(int2, false, kernel_args, num_elem); + } else { + // cannot do int2 transfers + if (diagnostics) { + printf("rank %d: launching non-int2 halo exchange kernel\n", rank); + } + int num_elem = NC * NH * NW; + if (is_nhwc) { + void* kernel_args[] = {&toh_p, &toh_stride_H, &toh_stride_W, &toh_stride_C, &tih_p, &tih_stride_H, + &tih_stride_W, &tih_stride_C, &tox_p, &tix_p, &boh_p, &boh_stride_H, + &boh_stride_W, &boh_stride_C, &bih_p, &bih_stride_H, &bih_stride_W, &bih_stride_C, + &box_p, &bix_p, &NH, &NW, &NC, &top_first}; + LAUNCH_PUSH_PULL_HALO_KERNEL(scalar_t, false, kernel_args, num_elem); + } else { + void* kernel_args[] = {&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W, &tih_p, &tih_stride_C, + &tih_stride_H, &tih_stride_W, &tox_p, &tix_p, &boh_p, &boh_stride_C, + &boh_stride_H, &boh_stride_W, &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W, + &box_p, &bix_p, &NC, &NH, &NW, &top_first}; + LAUNCH_PUSH_PULL_HALO_KERNEL(scalar_t, false, kernel_args, num_elem); + } } - auto current_stream = at::cuda::getCurrentCUDAStream(); - dim3 block(THREADS_PER_CTA, 1, 1); - - // helper macros to launch templated kernel -#define LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, TOP_ZERO, BTM_ZERO, KERNEL_ARGS, NUM_ELEMENTS) \ - do { \ - /* kernel configuration */ \ - int numBlocksPerSm; \ - cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ - &numBlocksPerSm, \ - push_pull_halos_1d_kernel, \ - THREADS_PER_CTA, \ - 0); \ - dim3 grid(numSM*numBlocksPerSm,1,1); \ - if (grid.x % 2 != 0) { \ - /* require even number of blocks (half for top, half for bottom) */ \ - grid.x -= 1; \ - } \ - if ((grid.x / 2) * THREADS_PER_CTA > NUM_ELEMENTS) { \ - /* only need enough blocks to cover top and bottom halo elements */ \ - grid.x = 2 * ((NUM_ELEMENTS + THREADS_PER_CTA - 1) / THREADS_PER_CTA); \ - } \ - if (!TOP_ZERO) { \ - /* require 2*128b=32B peer memory per thread */ \ - if ((grid.x / 2) * THREADS_PER_CTA * 32 > tox_size) { \ - grid.x = 2 * (tox_size / (THREADS_PER_CTA * 32)); \ - } \ - } \ - if (!BTM_ZERO) { \ - /* require 2*128b=32B peer memory per thread */ \ - if ((grid.x / 2) * THREADS_PER_CTA * 32 > box_size) { \ - grid.x = 2 * (box_size / (THREADS_PER_CTA * 32)); \ - } \ - } \ - TORCH_CHECK(grid.x >= 2); \ - \ - /* launch kernel */ \ - cudaLaunchCooperativeKernel( \ - (void*)push_pull_halos_1d_kernel, \ - grid, \ - block, \ - KERNEL_ARGS, \ - 0, \ - current_stream); \ - } while (false) -#define LAUNCH_PUSH_PULL_HALO_KERNEL(T, CONTIGUOUS, KERNEL_ARGS, NUM_ELEMENTS) \ - do { \ - if (top_zero) { \ - LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, true, false, KERNEL_ARGS, NUM_ELEMENTS); \ - } else if (btm_zero) { \ - LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, false, true, KERNEL_ARGS, NUM_ELEMENTS); \ - } else { \ - LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, false, false, KERNEL_ARGS, NUM_ELEMENTS); \ - } 
\ - } while (false) - - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{ - if (diagnostics) { - printf("rank %d: size(scalar_t) = %ld\n", rank, sizeof(scalar_t)); - } - scalar_t* toh_p = top_out_halo.data_ptr(); - scalar_t* tih_p = top_in_halo.data_ptr(); - int4* tox_p = reinterpret_cast(top_out_transfer.data_ptr()); - int4* tix_p = reinterpret_cast(top_in_transfer.data_ptr()); - scalar_t* boh_p = btm_out_halo.data_ptr(); - scalar_t* bih_p = btm_in_halo.data_ptr(); - int4* box_p = reinterpret_cast(btm_out_transfer.data_ptr()); - int4* bix_p = reinterpret_cast(btm_in_transfer.data_ptr()); - if (diagnostics) printf("rank %d: choosing halo exchange kernel\n", rank); - - // do int2 vector loads if channel count permits - if (contiguous && - (NN*NH*NW*NC * sizeof(scalar_t)) % sizeof(int2) == 0) { - // can do contiguous int2 transfers - if (diagnostics) { - } - toh_stride_N = toh_stride_H = toh_stride_W = toh_stride_C = 1; - tih_stride_N = tih_stride_H = tih_stride_W = tih_stride_C = 1; - boh_stride_N = boh_stride_H = boh_stride_W = boh_stride_C = 1; - bih_stride_N = bih_stride_H = bih_stride_W = bih_stride_C = 1; - NC = (NN*NH*NW*NC * sizeof(scalar_t)) / sizeof(int2); - NN = NH = NW = 1; - if (diagnostics) { - printf("rank %d: launching contiguous int2 halo exchange kernel\n", - rank); - printf("rank %d: NC=%d, NH=%d, NW=%d\n", rank, NC, NH, NW); - } - void *kernel_args[] = { - (int2**)&toh_p, &toh_stride_H, &toh_stride_W, &toh_stride_C, - (int2**)&tih_p, &tih_stride_H, &tih_stride_W, &tih_stride_C, - &tox_p, &tix_p, - (int2**)&boh_p, &boh_stride_H, &boh_stride_W, &boh_stride_C, - (int2**)&bih_p, &bih_stride_H, &bih_stride_W, &bih_stride_C, - &box_p, &bix_p, - &NH, &NW, &NC, - &top_first - }; - int num_elem = NN*NH*NW*NC; - LAUNCH_PUSH_PULL_HALO_KERNEL(int2, true, kernel_args, num_elem); - } else if (is_nhwc && (NC * sizeof(scalar_t)) % sizeof(int2) == 0) { - // can do strided int2 transfers - int divisor = sizeof(int2) / sizeof(scalar_t); - if (diagnostics) { - printf("rank %d: launching strided int2 halo exchange kernel\n", - rank); - } - toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor; - tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor; - boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor; - bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor; - NC /= divisor; - if (diagnostics) { - printf("rank %d: divisor=%d\n", rank, divisor); - printf("rank %d: tih_stride :: N=%d, C=%d, H=%d, W=%d\n", - rank, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W); - printf("rank %d: toh_stride :: N=%d, C=%d, H=%d, W=%d\n", - rank, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W); - printf("rank %d: bih_stride :: N=%d, C=%d, H=%d, W=%d\n", - rank, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W); - printf("rank %d: boh_stride :: N=%d, C=%d, H=%d, W=%d\n", - rank, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W); - printf("rank %d: NC=%d, NH=%d, NW=%d\n", rank, NC, NH, NW); - } - void *kernel_args[] = { - (int2**)&toh_p, &toh_stride_H, &toh_stride_W, &toh_stride_C, - (int2**)&tih_p, &tih_stride_H, &tih_stride_W, &tih_stride_C, - &tox_p, &tix_p, - (int2**)&boh_p, &boh_stride_H, &boh_stride_W, &boh_stride_C, - (int2**)&bih_p, &bih_stride_H, &bih_stride_W, &bih_stride_C, - &box_p, &bix_p, - &NH, &NW, &NC, - &top_first - }; - int num_elem = NH*NW*NC; - LAUNCH_PUSH_PULL_HALO_KERNEL(int2, false, kernel_args, num_elem); - } else { - // 
cannot do int2 transfers - if (diagnostics) { - printf("rank %d: launching non-int2 halo exchange kernel\n", - rank); - } - int num_elem = NC*NH*NW; - if (is_nhwc) { - void *kernel_args[] = { - &toh_p, &toh_stride_H, &toh_stride_W, &toh_stride_C, - &tih_p, &tih_stride_H, &tih_stride_W, &tih_stride_C, - &tox_p, &tix_p, - &boh_p, &boh_stride_H, &boh_stride_W, &boh_stride_C, - &bih_p, &bih_stride_H, &bih_stride_W, &bih_stride_C, - &box_p, &bix_p, - &NH, &NW, &NC, - &top_first - }; - LAUNCH_PUSH_PULL_HALO_KERNEL(scalar_t, false, kernel_args, num_elem); - } else { - void *kernel_args[] = { - &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W, - &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W, - &tox_p, &tix_p, - &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W, - &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W, - &box_p, &bix_p, - &NC, &NH, &NW, - &top_first - }; - LAUNCH_PUSH_PULL_HALO_KERNEL(scalar_t, false, kernel_args, num_elem); - } - } - } ); + }); #undef LAUNCH_PUSH_PULL_HALO_KERNEL_BASE #undef LAUNCH_PUSH_PULL_HALO_KERNEL } -} } } +} // namespace peer_memory +} // namespace contrib +} // namespace apex diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh index 83d11a2ca..29cf7a108 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh @@ -19,30 +19,33 @@ #ifndef _peer_memory_h_ #define _peer_memory_h_ -namespace apex { namespace contrib { namespace peer_memory { - int64_t allocate_raw(int64_t size); - void free_raw(int64_t raw); - void zero(int64_t raw, int64_t size); - at::Tensor get_raw_ipc_address(int64_t raw); - std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw); - at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last); - at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last); - at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last); - void push_pull_halos_1d( - bool diagnostics, - bool explicit_nhwc, - int numSM, // number of SMs to use - int peer_rank, // rank in spatial parallel group - bool top_zero, // if top halo should be zeroed - at::Tensor top_out_halo, // top output halo buffer (in local device memory, received from top neighbor) - at::Tensor top_inp_transfer, // top input transfer buffer (in local peer memory) - at::Tensor top_out_transfer, // top output transfer buffer (in top neighbor peer memory) - at::Tensor top_inp_halo, // top input halo buffer (in local device memory, sent to top neighbor) - bool btm_zero, // if btm halo should be zeroed - at::Tensor btm_out_halo, // btm output halo buffer (in local device memory, received from btm neighbor) - at::Tensor btm_inp_transfer, // btm input transfer buffer (in local peer memory) - at::Tensor btm_out_transfer, // btm output transfer buffer (in btm neighbor peer memory) - at::Tensor btm_inp_halo // btm input halo buffer (in local device memory, sent to btm neighbor) - ); -} } } +namespace apex { +namespace contrib { +namespace peer_memory { +int64_t allocate_raw(int64_t size); +void free_raw(int64_t raw); +void zero(int64_t raw, int64_t size); +at::Tensor get_raw_ipc_address(int64_t raw); +std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw); +at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last); +at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last); +at::Tensor blob_view_int(int64_t raw, std::vector shape, 
bool channels_last); +void push_pull_halos_1d( + bool diagnostics, bool explicit_nhwc, + int numSM, // number of SMs to use + int peer_rank, // rank in spatial parallel group + bool top_zero, // if top halo should be zeroed + at::Tensor top_out_halo, // top output halo buffer (in local device memory, received from top neighbor) + at::Tensor top_inp_transfer, // top input transfer buffer (in local peer memory) + at::Tensor top_out_transfer, // top output transfer buffer (in top neighbor peer memory) + at::Tensor top_inp_halo, // top input halo buffer (in local device memory, sent to top neighbor) + bool btm_zero, // if btm halo should be zeroed + at::Tensor btm_out_halo, // btm output halo buffer (in local device memory, received from btm neighbor) + at::Tensor btm_inp_transfer, // btm input transfer buffer (in local peer memory) + at::Tensor btm_out_transfer, // btm output transfer buffer (in btm neighbor peer memory) + at::Tensor btm_inp_halo // btm input halo buffer (in local device memory, sent to btm neighbor) +); +} // namespace peer_memory +} // namespace contrib +} // namespace apex #endif diff --git a/apex/contrib/csrc/transducer/transducer_joint.cpp b/apex/contrib/csrc/transducer/transducer_joint.cpp old mode 100755 new mode 100644 index 4b8f0dbd0..1175c1676 --- a/apex/contrib/csrc/transducer/transducer_joint.cpp +++ b/apex/contrib/csrc/transducer/transducer_joint.cpp @@ -1,98 +1,49 @@ -#include #include +#include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector transducer_joint_cuda_forward( - torch::Tensor f, - torch::Tensor g, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int64_t packedBatch, - int opt, - bool packOutput, - bool relu, - bool dropout, - float dropoutProb, - int tileSize); - - -std::vector transducer_joint_cuda_backward( - std::vector in, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int maxFLen, - int maxGLen, - bool packOutput, - float scale); - -std::vector transducer_joint_forward( - torch::Tensor f, - torch::Tensor g, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int64_t packedBatch, - int opt, - bool packOutput, - bool relu, - bool dropout, - float dropoutProb, - int tileSize) { - CHECK_INPUT(f); - CHECK_INPUT(g); - CHECK_INPUT(fLen); - CHECK_INPUT(gLen); - if (packOutput) - CHECK_INPUT(batchOffset); - return transducer_joint_cuda_forward( - f, - g, - fLen, - gLen, - batchOffset, - packedBatch, - opt, - packOutput, - relu, - dropout, - dropoutProb, - tileSize); +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, + torch::Tensor gLen, torch::Tensor batchOffset, + int64_t packedBatch, int opt, bool packOutput, bool relu, + bool dropout, float dropoutProb, int tileSize); + +std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, + torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, + int maxGLen, bool packOutput, float scale); + +std::vector transducer_joint_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, + torch::Tensor gLen, torch::Tensor batchOffset, int64_t packedBatch, + int opt, bool packOutput, bool relu, bool dropout, + float dropoutProb, int tileSize) { + CHECK_INPUT(f); + CHECK_INPUT(g); + CHECK_INPUT(fLen); + 
CHECK_INPUT(gLen); + if (packOutput) CHECK_INPUT(batchOffset); + return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, relu, dropout, + dropoutProb, tileSize); } -std::vector transducer_joint_backward( - std::vector in, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int maxFLen, - int maxGLen, - bool packOutput, - float scale) { - for (auto t : in){ - CHECK_INPUT(t); - } - CHECK_INPUT(fLen); - CHECK_INPUT(gLen); - if (packOutput) - CHECK_INPUT(batchOffset); - return transducer_joint_cuda_backward( - in, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - packOutput, - scale); +std::vector transducer_joint_backward(std::vector in, torch::Tensor fLen, + torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, + int maxGLen, bool packOutput, float scale) { + for (auto t : in) { + CHECK_INPUT(t); + } + CHECK_INPUT(fLen); + CHECK_INPUT(gLen); + if (packOutput) CHECK_INPUT(batchOffset); + return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); } - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)", py::call_guard()); - m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)", py::call_guard()); + m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)", + py::call_guard()); + m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)", + py::call_guard()); } diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu old mode 100755 new mode 100644 index 1e6a465de..0d643fb02 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu @@ -1,9 +1,8 @@ +#include #include -#include #include - +#include #include -#include #ifdef OLD_GENERATOR_PATH #include @@ -12,99 +11,90 @@ #endif #include -#include #include +#include + #include "philox.cuh" // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. // width should be a power of 2 and should be less than warpSize. template -__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){ - for (unsigned offset = width/2; offset > 0; offset /= 2){ - x += __shfl_down_sync(0xffffffff, x, offset, width); - } - return x; +__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width = C10_WARP_SIZE) { + for (unsigned offset = width / 2; offset > 0; offset /= 2) { + x += __shfl_down_sync(0xffffffff, x, offset, width); + } + return x; } -inline int largestPowerOfTwo(int x){ - int y = 1; - while (y <= x) - y <<= 1; - return y >> 1; +inline int largestPowerOfTwo(int x) { + int y = 1; + while (y <= x) y <<= 1; + return y >> 1; } /* Figure out vectorization type for masks. 
Similar to how PyTorch figures out acc_t here: -aten/src/ATen/AccumulateType.h +aten/src/ATen/AccumulateType.h */ template -struct MaskVecType { }; +struct MaskVecType {}; -template <> struct MaskVecType<1> { using type = uint8_t; }; -template <> struct MaskVecType<2> { using type = uint16_t; }; -template <> struct MaskVecType<4> { using type = uint32_t; }; +template <> +struct MaskVecType<1> { + using type = uint8_t; +}; +template <> +struct MaskVecType<2> { + using type = uint16_t; +}; +template <> +struct MaskVecType<4> { + using type = uint32_t; +}; -template +template using mvec_type = typename MaskVecType::type; // Helper class to calculate pointer offset that can be shared by different flavors of kernels. // For fwd, batch offset and stride are different for packing and non-packing mode. -struct OffsetCalFwd{ - __device__ __forceinline__ OffsetCalFwd( - int64_t batch, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t gLen, - int64_t hiddenSize, - bool packOutput) : - batch(batch), +struct OffsetCalFwd { + __device__ __forceinline__ OffsetCalFwd(int64_t batch, const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, + int64_t gLen, int64_t hiddenSize, bool packOutput) + : batch(batch), batchOffset(batchOffset), maxFLen(maxFLen), maxGLen(maxGLen), gLen(gLen), hiddenSize(hiddenSize), - packOutput(packOutput) - {} - - int64_t batch; - const int64_t *batchOffset; - int64_t maxFLen; - int64_t maxGLen; - int64_t gLen; - int64_t hiddenSize; - bool packOutput; - - __device__ __forceinline__ int64_t getBatchOffset(){ - return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize - : batch*maxFLen*maxGLen*hiddenSize; - } - - __device__ __forceinline__ int64_t getStrideF(){ - return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize; - } - - + packOutput(packOutput) {} + + int64_t batch; + const int64_t *batchOffset; + int64_t maxFLen; + int64_t maxGLen; + int64_t gLen; + int64_t hiddenSize; + bool packOutput; + + __device__ __forceinline__ int64_t getBatchOffset() { + return packOutput ? ((batch == 0) ? 0 : batchOffset[batch - 1]) * hiddenSize + : batch * maxFLen * maxGLen * hiddenSize; + } + + __device__ __forceinline__ int64_t getStrideF() { return packOutput ? gLen * hiddenSize : maxGLen * hiddenSize; } }; // Helper class to calculate pointer offset that can be shared by different flavors of kernels // For bwd, batch offset and stride are different for packing and non-packing mode. // The reducion is done for two input tensors. Therefore, generating two sets of offsets // according to bwdFasterDim can lead to a unified implementation in the actual kernel. 
-struct OffsetCalBwd{ - __device__ __forceinline__ OffsetCalBwd( - int64_t batch, - const int64_t *batchOffset, - const int *fLen, - const int *gLen, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - bool bwdFasterDim) : - batch(batch), +struct OffsetCalBwd { + __device__ __forceinline__ OffsetCalBwd(int64_t batch, const int64_t *batchOffset, const int *fLen, const int *gLen, + int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize, bool packOutput, + bool bwdFasterDim) + : batch(batch), batchOffset(batchOffset), maxFLen(maxFLen), maxGLen(maxGLen), @@ -112,48 +102,44 @@ struct OffsetCalBwd{ gLen(gLen), hiddenSize(hiddenSize), packOutput(packOutput), - bwdFasterDim(bwdFasterDim) - {} - - int64_t batch; - const int64_t *batchOffset; - const int *fLen; - const int *gLen; - int64_t maxFLen; - int64_t maxGLen; - int64_t hiddenSize; - bool packOutput; - bool bwdFasterDim; // whether doing bwd on the faster moving dimension - - __device__ __forceinline__ int64_t getBatchOffset(){ - return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize - : batch*maxFLen*maxGLen*hiddenSize; - } - - __device__ __forceinline__ int64_t getMaxXLen(){ - return bwdFasterDim ? maxGLen : maxFLen; - } - - __device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]){ - return bwdFasterDim ? gLen[batch] : fLen[batch]; - } - - __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){ - return bwdFasterDim ? fLen[batch] : gLen[batch]; - } - - __device__ __forceinline__ int64_t getStrideX(){ - return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize); - } - - __device__ __forceinline__ int64_t getStrideY(){ - return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize; - } + bwdFasterDim(bwdFasterDim) {} + + int64_t batch; + const int64_t *batchOffset; + const int *fLen; + const int *gLen; + int64_t maxFLen; + int64_t maxGLen; + int64_t hiddenSize; + bool packOutput; + bool bwdFasterDim; // whether doing bwd on the faster moving dimension + + __device__ __forceinline__ int64_t getBatchOffset() { + return packOutput ? ((batch == 0) ? 0 : batchOffset[batch - 1]) * hiddenSize + : batch * maxFLen * maxGLen * hiddenSize; + } + + __device__ __forceinline__ int64_t getMaxXLen() { return bwdFasterDim ? maxGLen : maxFLen; } + + __device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]) { + return bwdFasterDim ? gLen[batch] : fLen[batch]; + } + + __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]) { + return bwdFasterDim ? fLen[batch] : gLen[batch]; + } + + __device__ __forceinline__ int64_t getStrideX() { + return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize); + } + + __device__ __forceinline__ int64_t getStrideY() { + return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize; + } }; - // Vanila transducer joint forward kernel -// Detail of this joint function can be found in: +// Detail of this joint function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. // f is a tensor of shape [batch, T, H] @@ -168,59 +154,48 @@ struct OffsetCalBwd{ // Don't-care region (t > fLen) or (u > gLen) is removed. // To enable packing, the starting offset for each batch need to be specified with batchOffset. 
template -__global__ void transducer_joint_forward( - const scalar_t *f, - const scalar_t *g, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - scalar_t *sum) { - - - const int batch = blockIdx.z; - const int t = blockIdx.y; - const int u = blockIdx.x; - const auto myFLen = fLen[batch]; - const auto myGLen = gLen[batch]; - - OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput); - const auto myBatchOffset = offsetCal.getBatchOffset(); - const auto strideF = offsetCal.getStrideF(); - scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize; - scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize; - scalar_t *mySum = sum + myBatchOffset + t*strideF + u * hiddenSize; - - if (t < myFLen and u < myGLen){ - #pragma unroll - for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){ - if (h < hiddenSize){ - mySum[h] = myF[h] + myG[h]; - } - } +__global__ void transducer_joint_forward(const scalar_t *f, const scalar_t *g, const int *fLen, const int *gLen, + const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, + int64_t hiddenSize, bool packOutput, scalar_t *sum) { + const int batch = blockIdx.z; + const int t = blockIdx.y; + const int u = blockIdx.x; + const auto myFLen = fLen[batch]; + const auto myGLen = gLen[batch]; + + OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput); + const auto myBatchOffset = offsetCal.getBatchOffset(); + const auto strideF = offsetCal.getStrideF(); + scalar_t const *myF = f + batch * maxFLen * hiddenSize + t * hiddenSize; + scalar_t const *myG = g + batch * maxGLen * hiddenSize + u * hiddenSize; + scalar_t *mySum = sum + myBatchOffset + t * strideF + u * hiddenSize; + + if (t < myFLen and u < myGLen) { +#pragma unroll + for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x) { + if (h < hiddenSize) { + mySum[h] = myF[h] + myG[h]; + } } - else if (packOutput == false and t < maxFLen and u < maxGLen){ - // Need to write finite data to don't-care region because we instantiate the result tensor - // with torch::empty for performance reasons. Even though it is don't-care region, the - // contents need to be finite, otherwise could lead to NaN in WGRAD. - // In packing mode, this write is no longer necessary as we remove the don't-care region - // from the output. - // Picking -1 (over 0) here for ease of testing. - #pragma unroll - for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){ - if (h < hiddenSize){ - mySum[h] = -1; - } - } + } else if (packOutput == false and t < maxFLen and u < maxGLen) { +// Need to write finite data to don't-care region because we instantiate the result tensor +// with torch::empty for performance reasons. Even though it is don't-care region, the +// contents need to be finite, otherwise could lead to NaN in WGRAD. +// In packing mode, this write is no longer necessary as we remove the don't-care region +// from the output. +// Picking -1 (over 0) here for ease of testing. +#pragma unroll + for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x) { + if (h < hiddenSize) { + mySum[h] = -1; + } } + } } /* Tiled version of the joint forward kernel -Detail of this joint function can be found in: +Detail of this joint function can be found in: [1] Sequence Transduction with Recurrent Neural Networks. 
f is a tensor of shape [batch, T, H] @@ -229,133 +204,115 @@ the transducer joint does sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1) The resultant tensor is of shape [batch, T, U, H] Each thread is working on a tile of the shape of tileF x tileG in the result tensor. -The input for the tile is first loaded in the register and is reused tileG and tileF times. +The input for the tile is first loaded in the register and is reused tileG and tileF times. This joint function can optionally pack the output where the output tensor with a shape of [B, T, U, H] is packed into [B_packed, H]. Don't-care region (t > fLen) or (u > gLen) is removed. To enable packing, the starting offset for each batch need to be specified with batchOffset. -Optionally this joint function performs ReLU and/or dropout on the joint output, which is +Optionally this joint function performs ReLU and/or dropout on the joint output, which is controlled by arguments relu and dropout, respectively. philoxArgs is argument used for generating pseudorandom number. When at least one of operations in ReLU and dropout is activated, the joint -function is a masked operation, which is controlled by the template argument masked. In this case, +function is a masked operation, which is controlled by the template argument masked. In this case, masks are saved to backward. */ template -__global__ void transducer_joint_tiled_forward( - const scalar_t *f, - const scalar_t *g, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - int64_t hiddenPerBlock, - bool packOutput, - bool relu, - bool dropout, - float p, - at::PhiloxCudaState philoxArgs, - scalar_t *sum, - uint8_t *mask) { - - static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4"); - - const int batch = blockIdx.z; - const int t = blockIdx.y * tileF; - const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; - const int u = blockIdx.x / hiddenBlock * tileG; - const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock; - const int h = threadIdx.x; - const auto myFLen = fLen[batch]; - const auto myGLen = gLen[batch]; - - OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput); - const auto myBatchOffset = offsetCal.getBatchOffset(); - const auto strideF = offsetCal.getStrideF(); - - scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset; - scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset; - scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset; - uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset; - - // The following code is only needed for dropout. We try to bypass them as much as possible. - auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) - : std::make_tuple(static_cast(0), static_cast(0)); - uint64_t tid = masked ? (static_cast(blockIdx.z)*gridDim.y*gridDim.x + - blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x - : 0; - Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); - scalar_t scale = masked ? ((p == 0) ? 
0 : 1 / p) : 0; - bool dropoutMask[U]; - - if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){ - // register buffers for tiled input reuse - scalar_t fBuffer[tileF], gBuffer[tileG]; - for (int i = 0; i < tileF; ++i){ - if (t + i < myFLen) - fBuffer[i] = myF[i*hiddenSize + h]; - } - for (int j = 0; j < tileG; ++j){ - if (u + j < myGLen) - gBuffer[j] = myG[j*hiddenSize + h]; - } - #pragma unroll - for (int i = 0; i < tileF; ++i){ - if (t + i < myFLen){ - #pragma unroll - for (int j = 0; j < tileG; ++j){ - int idx = i*tileG + j; - if (masked and dropout and idx % U == 0){ - // For performance, generate 4 random numbers in one shot - // auto rand4 = curand_uniform4(&state); - auto rand4 = uniform4(ph()); - dropoutMask[0] = rand4.x < p; - dropoutMask[1] = rand4.y < p; - dropoutMask[2] = rand4.z < p; - dropoutMask[3] = rand4.w < p; - } - - if (u + j < myGLen){ - scalar_t out = fBuffer[i] + gBuffer[j]; - if (masked){ - // Apply ReLU here when relu is True - bool localMask = relu ? (out>0) : 1; - localMask = dropout ? localMask & dropoutMask[idx%U] : localMask; - out = dropout ? out*localMask*scale : out*localMask; - myMask[i*strideF + j*hiddenSize + h] = static_cast(localMask); - } - mySum[i*strideF + j*hiddenSize + h] = out; - } - else if (packOutput == false and u + j < maxGLen) - mySum[i*strideF + j*hiddenSize + h] = -1; - } - } - else if (packOutput == false and t + i < maxFLen){ - // Again need to write finite data to don't-care region - #pragma unroll - for (int j = 0; j < tileG; ++j){ - if (u + j < maxGLen) - mySum[i*strideF + j*hiddenSize + h] = -1; - } +__global__ void transducer_joint_tiled_forward(const scalar_t *f, const scalar_t *g, const int *fLen, const int *gLen, + const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, + int64_t hiddenSize, int64_t hiddenPerBlock, bool packOutput, bool relu, + bool dropout, float p, at::PhiloxCudaState philoxArgs, scalar_t *sum, + uint8_t *mask) { + static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4"); + + const int batch = blockIdx.z; + const int t = blockIdx.y * tileF; + const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; + const int u = blockIdx.x / hiddenBlock * tileG; + const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock; + const int h = threadIdx.x; + const auto myFLen = fLen[batch]; + const auto myGLen = gLen[batch]; + + OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput); + const auto myBatchOffset = offsetCal.getBatchOffset(); + const auto strideF = offsetCal.getStrideF(); + + scalar_t const *myF = f + batch * maxFLen * hiddenSize + t * hiddenSize + hOffset; + scalar_t const *myG = g + batch * maxGLen * hiddenSize + u * hiddenSize + hOffset; + scalar_t *mySum = sum + myBatchOffset + t * strideF + u * hiddenSize + hOffset; + uint8_t *myMask = mask + myBatchOffset + t * strideF + u * hiddenSize + hOffset; + + // The following code is only needed for dropout. We try to bypass them as much as possible. + auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) + : std::make_tuple(static_cast(0), static_cast(0)); + uint64_t tid = + masked ? (static_cast(blockIdx.z) * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.x + + threadIdx.x + : 0; + Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); + scalar_t scale = masked ? ((p == 0) ? 
0 : 1 / p) : 0; + bool dropoutMask[U]; + + if (t < myFLen and u < myGLen and hOffset + h < hiddenSize) { + // register buffers for tiled input reuse + scalar_t fBuffer[tileF], gBuffer[tileG]; + for (int i = 0; i < tileF; ++i) { + if (t + i < myFLen) fBuffer[i] = myF[i * hiddenSize + h]; + } + for (int j = 0; j < tileG; ++j) { + if (u + j < myGLen) gBuffer[j] = myG[j * hiddenSize + h]; + } +#pragma unroll + for (int i = 0; i < tileF; ++i) { + if (t + i < myFLen) { +#pragma unroll + for (int j = 0; j < tileG; ++j) { + int idx = i * tileG + j; + if (masked and dropout and idx % U == 0) { + // For performance, generate 4 random numbers in one shot + // auto rand4 = curand_uniform4(&state); + auto rand4 = uniform4(ph()); + dropoutMask[0] = rand4.x < p; + dropoutMask[1] = rand4.y < p; + dropoutMask[2] = rand4.z < p; + dropoutMask[3] = rand4.w < p; + } + + if (u + j < myGLen) { + scalar_t out = fBuffer[i] + gBuffer[j]; + if (masked) { + // Apply ReLU here when relu is True + bool localMask = relu ? (out > 0) : 1; + localMask = dropout ? localMask & dropoutMask[idx % U] : localMask; + out = dropout ? out * localMask * scale : out * localMask; + myMask[i * strideF + j * hiddenSize + h] = static_cast(localMask); } + mySum[i * strideF + j * hiddenSize + h] = out; + } else if (packOutput == false and u + j < maxGLen) + mySum[i * strideF + j * hiddenSize + h] = -1; } + } else if (packOutput == false and t + i < maxFLen) { +// Again need to write finite data to don't-care region +#pragma unroll + for (int j = 0; j < tileG; ++j) { + if (u + j < maxGLen) mySum[i * strideF + j * hiddenSize + h] = -1; + } + } } - else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset+h < hiddenSize){ - // Only need to ensure the finity in normal mode - #pragma unroll - for (int i = 0; i < tileF; ++i){ - if (t + i < maxFLen){ - #pragma unroll - for (int j = 0; j < tileG; ++j){ - if (u + j < maxGLen) - mySum[i*strideF + j*hiddenSize + h] = -1; - } - } + } else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset + h < hiddenSize) { +// Only need to ensure the finity in normal mode +#pragma unroll + for (int i = 0; i < tileF; ++i) { + if (t + i < maxFLen) { +#pragma unroll + for (int j = 0; j < tileG; ++j) { + if (u + j < maxGLen) mySum[i * strideF + j * hiddenSize + h] = -1; } + } } + } } /* @@ -363,524 +320,375 @@ Bwd operation (reduction) on one input tensor. Since the operation performed for tensors are exactly the same, only one kernel is needed, and the different indexing offsets and strides are handled by OffsetCalBwd. -When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a +When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a non-packed form. When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, and mask contains the mask information. */ template -__device__ void transducer_joint_single_backward( - const scalar_t *grad, - const uint8_t *mask, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - bool bwdFasterDim, // whether bwd on the faster moving dimension (u) - float scale, - scalar_t *inGrad, - int yBlockOffset=0) { - - - const int batch = blockIdx.z; - // For the second input tensor, this offset need to be subtracted because the first yBlockOffset - // sets of thread blocks are for the first input tensor. 
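Stepping back to the tiled forward above, the per-element masked math it applies is summarized in this sketch (not the extension's API). Note that p is the keep probability here, since the host launcher passes 1.0f - dropoutProb, and the saved mask byte is what the backward kernels later multiply into the gradient together with scale.

#include <cstdint>

// Masked joint element (sketch): ReLU folded into the mask, inverted dropout with keep
// probability p, and the mask saved so backward can apply grad * mask * scale.
inline float masked_joint_element(float f, float g, bool relu, bool dropout, float p, float rand01,
                                  std::uint8_t *saved_mask) {
  float out = f + g;
  bool keep = relu ? (out > 0.f) : true;     // ReLU decides whether the element survives
  if (dropout) keep = keep && (rand01 < p);  // rand01 stands in for one Philox uniform
  float scale = dropout ? (p == 0.f ? 0.f : 1.f / p) : 1.f;
  *saved_mask = static_cast<std::uint8_t>(keep);
  return keep ? out * scale : 0.f;
}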
- const int x = blockIdx.y-yBlockOffset; - const int hOffset = blockIdx.x*C10_WARP_SIZE; - const int wid = threadIdx.y; - const int lid = threadIdx.x; - const int numWarp = blockDim.y; - extern __shared__ char smem8[]; - auto smem = reinterpret_cast(smem8); - - OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, - bwdFasterDim); - const auto maxXLen = offsetCal.getMaxXLen(); - const auto myXLen = offsetCal.getMyXLen(); - const auto myYLen = offsetCal.getMyYLen(); - scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset; - - if (x < myXLen){ - - const auto myBatchOffset = offsetCal.getBatchOffset(); - const auto strideX = offsetCal.getStrideX(); - const auto strideY = offsetCal.getStrideY(); - const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset; - const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr; - - // Each warp reduces numYPerWarp "y" first - acc_t warpSum = 0; - auto numYPerWarp = (myYLen+numWarp-1)/numWarp; - #pragma unroll - for (int warpY = 0; warpY < numYPerWarp; ++warpY){ - auto y = wid*numYPerWarp + warpY; - if (y < myYLen and (hOffset+lid) < hiddenSize) - if (masked) - warpSum += static_cast(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale; - else - warpSum += myGrad[y*strideY + lid]; - } - - // transpose partial sum in SMEM and reduce further using warpReduce - smem[lid*numWarp + wid] = warpSum; - __syncthreads(); - auto sum = smem[wid*C10_WARP_SIZE + lid]; - sum = warpReduce(sum, numWarp); - - // a a b b c c d d - // a a b b c c d d - // a a b b c c d d - // a a b b c c d d - // example of 4 warps (a, b, c, d) with 8 threads per warp - // Each warp need 8 / 4 = 2 threads to write the results. - if (hOffset+wid*C10_WARP_SIZE/numWarp+lid/numWarp < hiddenSize){ - if (lid % numWarp == 0){ - myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum; - } - } +__device__ void transducer_joint_single_backward(const scalar_t *grad, const uint8_t *mask, const int *fLen, + const int *gLen, const int64_t *batchOffset, int64_t maxFLen, + int64_t maxGLen, int64_t hiddenSize, bool packOutput, + bool bwdFasterDim, // whether bwd on the faster moving dimension (u) + float scale, scalar_t *inGrad, int yBlockOffset = 0) { + const int batch = blockIdx.z; + // For the second input tensor, this offset need to be subtracted because the first yBlockOffset + // sets of thread blocks are for the first input tensor. + const int x = blockIdx.y - yBlockOffset; + const int hOffset = blockIdx.x * C10_WARP_SIZE; + const int wid = threadIdx.y; + const int lid = threadIdx.x; + const int numWarp = blockDim.y; + extern __shared__ char smem8[]; + auto smem = reinterpret_cast(smem8); + + OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim); + const auto maxXLen = offsetCal.getMaxXLen(); + const auto myXLen = offsetCal.getMyXLen(); + const auto myYLen = offsetCal.getMyYLen(); + scalar_t *myInGrad = inGrad + batch * maxXLen * hiddenSize + x * hiddenSize + hOffset; + + if (x < myXLen) { + const auto myBatchOffset = offsetCal.getBatchOffset(); + const auto strideX = offsetCal.getStrideX(); + const auto strideY = offsetCal.getStrideY(); + const scalar_t *myGrad = grad + myBatchOffset + x * strideX + hOffset; + const uint8_t *myMask = masked ? 
mask + myBatchOffset + x * strideX + hOffset : nullptr; + + // Each warp reduces numYPerWarp "y" first + acc_t warpSum = 0; + auto numYPerWarp = (myYLen + numWarp - 1) / numWarp; +#pragma unroll + for (int warpY = 0; warpY < numYPerWarp; ++warpY) { + auto y = wid * numYPerWarp + warpY; + if (y < myYLen and (hOffset + lid) < hiddenSize) + if (masked) + warpSum += static_cast(myGrad[y * strideY + lid]) * myMask[y * strideY + lid] * scale; + else + warpSum += myGrad[y * strideY + lid]; } - else if (wid == 0 and hOffset + lid < hiddenSize){ - // Need to ensure the grad is zero for don't care region - myInGrad[lid] = 0; + + // transpose partial sum in SMEM and reduce further using warpReduce + smem[lid * numWarp + wid] = warpSum; + __syncthreads(); + auto sum = smem[wid * C10_WARP_SIZE + lid]; + sum = warpReduce(sum, numWarp); + + // a a b b c c d d + // a a b b c c d d + // a a b b c c d d + // a a b b c c d d + // example of 4 warps (a, b, c, d) with 8 threads per warp + // Each warp need 8 / 4 = 2 threads to write the results. + if (hOffset + wid * C10_WARP_SIZE / numWarp + lid / numWarp < hiddenSize) { + if (lid % numWarp == 0) { + myInGrad[wid * C10_WARP_SIZE / numWarp + lid / numWarp] = sum; + } } + } else if (wid == 0 and hOffset + lid < hiddenSize) { + // Need to ensure the grad is zero for don't care region + myInGrad[lid] = 0; + } } /* Actual bwd (reduction) kernel get launched. -Call transducer_joint_single_backward twice on two input tensors. -The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op +Call transducer_joint_single_backward twice on two input tensors. +The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op uses the rest. When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, and mask contains the mask information. */ template -__global__ void transducer_joint_combined_backward( - const scalar_t *grad, - const uint8_t *mask, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - float scale, - scalar_t *fGrad, - scalar_t *gGrad) { - if (blockIdx.y < maxFLen){ - transducer_joint_single_backward( - grad, - mask, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - false, - scale, - fGrad); - } - else{ - transducer_joint_single_backward( - grad, - mask, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - true, - scale, - gGrad, - maxFLen); - } +__global__ void transducer_joint_combined_backward(const scalar_t *grad, const uint8_t *mask, const int *fLen, + const int *gLen, const int64_t *batchOffset, int64_t maxFLen, + int64_t maxGLen, int64_t hiddenSize, bool packOutput, float scale, + scalar_t *fGrad, scalar_t *gGrad) { + if (blockIdx.y < maxFLen) { + transducer_joint_single_backward( + grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, false, scale, fGrad); + } else { + transducer_joint_single_backward( + grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, true, scale, gGrad, maxFLen); + } } /* Vectorized version of transducer_joint_single_backward Doing exact same operation as transducer_joint_single_backward except the load and store are vectorized. 
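The combined backward above folds both reductions into a single launch by splitting gridDim.y: rows below maxFLen reduce the joint gradient over u to produce fGrad, and the remaining rows reduce over t to produce gGrad, which is why the second call passes yBlockOffset = maxFLen. A toy C++ sketch of that partition:

#include <cstdio>

int main() {
  const int maxFLen = 3, maxGLen = 2;  // illustrative sizes; gridDim.y = maxFLen + maxGLen
  for (int blockIdxY = 0; blockIdxY < maxFLen + maxGLen; ++blockIdxY) {
    if (blockIdxY < maxFLen)
      std::printf("y-block %d -> fGrad row t = %d (reduce over u)\n", blockIdxY, blockIdxY);
    else
      std::printf("y-block %d -> gGrad row u = %d (reduce over t)\n", blockIdxY, blockIdxY - maxFLen);
  }
  return 0;
}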
-When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a +When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a non-packed form. When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, and mask contains the mask information. */ template -__device__ void transducer_joint_single_vec_backward( - const scalar_t *grad, - const uint8_t *mask, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - bool bwdFasterDim, - float scale, - scalar_t *inGrad, - int yBlockOffset=0){ - - const int batch = blockIdx.z; - const int x = blockIdx.y - yBlockOffset; - const int hOffset = blockIdx.x*C10_WARP_SIZE*V; - const int wid = threadIdx.y; - const int lid = threadIdx.x; - const int numWarp = blockDim.y; - - // Figure out the vectorization type for mask - using mvec_t = mvec_type; - - OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, - bwdFasterDim); - const auto maxXLen = offsetCal.getMaxXLen(); - const auto myXLen = offsetCal.getMyXLen(); - const auto myYLen = offsetCal.getMyYLen(); - scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset; - extern __shared__ char smem8[]; - auto smem = reinterpret_cast(smem8); - - acc_t warpSum[V]; - scalar_t inBuffer[V]; - uint8_t maskBuffer[V]; - scalar_t outBuffer[V]; - auto myInGradVec = reinterpret_cast(myInGrad); - auto outBufferVec = reinterpret_cast(outBuffer); - - if (x < myXLen){ - const auto myBatchOffset = offsetCal.getBatchOffset(); - const auto strideX = offsetCal.getStrideX(); - const auto strideY = offsetCal.getStrideY(); - const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset; - const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset - :nullptr; - - for (int i = 0; i < V; ++i) - warpSum[i] = 0; - - // Each warp reduces numYPerWarp "y" first - auto numYPerWarp = (myYLen+numWarp-1)/numWarp; - for (int warpY = 0; warpY < numYPerWarp; ++warpY){ - auto y = wid*numYPerWarp + warpY; - auto myGradVec = reinterpret_cast(myGrad + y*strideY); - auto myMaskVec = masked ? 
reinterpret_cast(myMask + y*strideY) - : nullptr; - auto inBufferVec = reinterpret_cast(inBuffer); - auto maskBufferVec = reinterpret_cast(maskBuffer); - if (hOffset + lid*V < hiddenSize and y < myYLen){ - *inBufferVec = myGradVec[lid]; // vectorized load - if (masked){ - *maskBufferVec = myMaskVec[lid]; - #pragma unroll - for (int i = 0; i < V; ++i) - warpSum[i] += static_cast(inBuffer[i]) * maskBuffer[i] * scale; - } - else{ - #pragma unroll - for (int i = 0; i < V; ++i) - warpSum[i] += inBuffer[i]; - } - } - } - - // transpose partial sum in SMEM and reduce further using warpReduce - for (int i = 0; i < V; ++i){ - smem[lid*numWarp + wid] = warpSum[i]; - __syncthreads(); - auto sum = smem[wid*C10_WARP_SIZE + lid]; - - if (hOffset+(wid*C10_WARP_SIZE/numWarp)*V < hiddenSize){ - sum = warpReduce(sum, numWarp); - if (lid % numWarp == 0){ - outBuffer[i] = sum; - } - } - __syncthreads(); +__device__ void transducer_joint_single_vec_backward(const scalar_t *grad, const uint8_t *mask, const int *fLen, + const int *gLen, const int64_t *batchOffset, int64_t maxFLen, + int64_t maxGLen, int64_t hiddenSize, bool packOutput, + bool bwdFasterDim, float scale, scalar_t *inGrad, + int yBlockOffset = 0) { + const int batch = blockIdx.z; + const int x = blockIdx.y - yBlockOffset; + const int hOffset = blockIdx.x * C10_WARP_SIZE * V; + const int wid = threadIdx.y; + const int lid = threadIdx.x; + const int numWarp = blockDim.y; + + // Figure out the vectorization type for mask + using mvec_t = mvec_type; + + OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim); + const auto maxXLen = offsetCal.getMaxXLen(); + const auto myXLen = offsetCal.getMyXLen(); + const auto myYLen = offsetCal.getMyYLen(); + scalar_t *myInGrad = inGrad + batch * maxXLen * hiddenSize + x * hiddenSize + hOffset; + extern __shared__ char smem8[]; + auto smem = reinterpret_cast(smem8); + + acc_t warpSum[V]; + scalar_t inBuffer[V]; + uint8_t maskBuffer[V]; + scalar_t outBuffer[V]; + auto myInGradVec = reinterpret_cast(myInGrad); + auto outBufferVec = reinterpret_cast(outBuffer); + + if (x < myXLen) { + const auto myBatchOffset = offsetCal.getBatchOffset(); + const auto strideX = offsetCal.getStrideX(); + const auto strideY = offsetCal.getStrideY(); + const scalar_t *myGrad = grad + myBatchOffset + x * strideX + hOffset; + const uint8_t *myMask = masked ? mask + myBatchOffset + x * strideX + hOffset : nullptr; + + for (int i = 0; i < V; ++i) warpSum[i] = 0; + + // Each warp reduces numYPerWarp "y" first + auto numYPerWarp = (myYLen + numWarp - 1) / numWarp; + for (int warpY = 0; warpY < numYPerWarp; ++warpY) { + auto y = wid * numYPerWarp + warpY; + auto myGradVec = reinterpret_cast(myGrad + y * strideY); + auto myMaskVec = masked ? reinterpret_cast(myMask + y * strideY) : nullptr; + auto inBufferVec = reinterpret_cast(inBuffer); + auto maskBufferVec = reinterpret_cast(maskBuffer); + if (hOffset + lid * V < hiddenSize and y < myYLen) { + *inBufferVec = myGradVec[lid]; // vectorized load + if (masked) { + *maskBufferVec = myMaskVec[lid]; +#pragma unroll + for (int i = 0; i < V; ++i) warpSum[i] += static_cast(inBuffer[i]) * maskBuffer[i] * scale; + } else { +#pragma unroll + for (int i = 0; i < V; ++i) warpSum[i] += inBuffer[i]; } - - // a a b b c c d d - // a a b b c c d d - // a a b b c c d d - // a a b b c c d d - // example of 4 warps (a, b, c, d) with 8 threads per warp - // Each warp need 8 / 4 = 2 threads to write the results. 
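A concrete instance of the write-back mapping pictured above (plain C++, toy sizes; not part of the kernel): lanes with lid % numWarp == 0 perform the stores, and warp wid covers output slots wid * warpSize / numWarp + lid / numWarp.

#include <cstdio>

int main() {
  const int warpSize = 8, numWarp = 4;  // the 4-warp, 8-lane example from the comment, not CUDA's 32
  for (int wid = 0; wid < numWarp; ++wid)
    for (int lid = 0; lid < warpSize; ++lid)
      if (lid % numWarp == 0)
        std::printf("warp %d lane %d -> output %d\n", wid, lid, wid * warpSize / numWarp + lid / numWarp);
  return 0;  // each warp writes 8 / 4 = 2 results, matching the diagram
}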
- if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize) - myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec; + } } - else if (wid == 0 and hOffset + lid*V < hiddenSize){ - // Need to ensure the grad is zero for don't care region - myInGradVec[lid] = 0; + + // transpose partial sum in SMEM and reduce further using warpReduce + for (int i = 0; i < V; ++i) { + smem[lid * numWarp + wid] = warpSum[i]; + __syncthreads(); + auto sum = smem[wid * C10_WARP_SIZE + lid]; + + if (hOffset + (wid * C10_WARP_SIZE / numWarp) * V < hiddenSize) { + sum = warpReduce(sum, numWarp); + if (lid % numWarp == 0) { + outBuffer[i] = sum; + } + } + __syncthreads(); } + + // a a b b c c d d + // a a b b c c d d + // a a b b c c d d + // a a b b c c d d + // example of 4 warps (a, b, c, d) with 8 threads per warp + // Each warp need 8 / 4 = 2 threads to write the results. + if (lid % numWarp == 0 and hOffset + (wid * C10_WARP_SIZE / numWarp + lid / numWarp) * V < hiddenSize) + myInGradVec[wid * C10_WARP_SIZE / numWarp + lid / numWarp] = *outBufferVec; + } else if (wid == 0 and hOffset + lid * V < hiddenSize) { + // Need to ensure the grad is zero for don't care region + myInGradVec[lid] = 0; + } } /* Vecotrized version of transducer_joint_combined_backward -Call transducer_joint_single_vec_backward twice on two input tensors. -The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op +Call transducer_joint_single_vec_backward twice on two input tensors. +The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op uses the rest. When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, and mask contains the mask information. 
*/ template -__global__ void transducer_joint_combined_vec_backward( - const scalar_t *grad, - const uint8_t *mask, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - float scale, - scalar_t *fGrad, - scalar_t *gGrad) { - if (blockIdx.y < maxFLen){ - transducer_joint_single_vec_backward( - grad, - mask, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - false, - scale, - fGrad); - } - else{ - transducer_joint_single_vec_backward( - grad, - mask, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - true, - scale, - gGrad, - maxFLen); - } +__global__ void transducer_joint_combined_vec_backward(const scalar_t *grad, const uint8_t *mask, const int *fLen, + const int *gLen, const int64_t *batchOffset, int64_t maxFLen, + int64_t maxGLen, int64_t hiddenSize, bool packOutput, + float scale, scalar_t *fGrad, scalar_t *gGrad) { + if (blockIdx.y < maxFLen) { + transducer_joint_single_vec_backward( + grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, false, scale, fGrad); + } else { + transducer_joint_single_vec_backward( + grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, true, scale, gGrad, maxFLen); + } } - - - -std::vector transducer_joint_cuda_forward( - torch::Tensor f, - torch::Tensor g, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int64_t packedBatch, - int opt, - bool packOutput, - bool relu, - bool dropout, - float dropoutProb, - int tileSize){ - - - auto tensorOpt = f.options(); - auto dtype = f.scalar_type(); - const auto batchSize = f.size(0); - const auto maxFLen = f.size(1); - const auto maxGLen = g.size(1); - const auto hiddenSize = f.size(2); - bool masked = dropout or relu; - - int64_t *batchOffsetPtr = nullptr; - torch::Tensor sum, mask; - auto maskOpt = tensorOpt.dtype(torch::kUInt8); - if (!packOutput){ - sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); - batchOffsetPtr = nullptr; - if (masked) - mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); - } - else{ - sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); - batchOffsetPtr = batchOffset.data_ptr(); - if (masked) - mask = torch::empty({packedBatch, hiddenSize}, maskOpt); - } - uint8_t *maskPtr = masked ? 
mask.data_ptr() : nullptr; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt); - // Simple heuristics - const int numThread = std::min(128, (static_cast(hiddenSize)+C10_WARP_SIZE-1) - / C10_WARP_SIZE * C10_WARP_SIZE); - - if (opt == 0){ - // vanilla kernel - const int threads = numThread; - const dim3 blocks(maxGLen, maxFLen, batchSize); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { - transducer_joint_forward - <<>>( - f.data_ptr(), - g.data_ptr(), - fLen.data_ptr(), - gLen.data_ptr(), - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - sum.data_ptr()); - })); +std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, + torch::Tensor gLen, torch::Tensor batchOffset, + int64_t packedBatch, int opt, bool packOutput, bool relu, + bool dropout, float dropoutProb, int tileSize) { + auto tensorOpt = f.options(); + auto dtype = f.scalar_type(); + const auto batchSize = f.size(0); + const auto maxFLen = f.size(1); + const auto maxGLen = g.size(1); + const auto hiddenSize = f.size(2); + bool masked = dropout or relu; + + int64_t *batchOffsetPtr = nullptr; + torch::Tensor sum, mask; + auto maskOpt = tensorOpt.dtype(torch::kUInt8); + if (!packOutput) { + sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); + batchOffsetPtr = nullptr; + if (masked) mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); + } else { + sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); + batchOffsetPtr = batchOffset.data_ptr(); + if (masked) mask = torch::empty({packedBatch, hiddenSize}, maskOpt); + } + uint8_t *maskPtr = masked ? mask.data_ptr() : nullptr; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt); + // Simple heuristics + const int numThread = + std::min(128, (static_cast(hiddenSize) + C10_WARP_SIZE - 1) / C10_WARP_SIZE * C10_WARP_SIZE); + + if (opt == 0) { + // vanilla kernel + const int threads = numThread; + const dim3 blocks(maxGLen, maxFLen, batchSize); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + dtype, "transducer_joint_forward", ([&] { + transducer_joint_forward<<>>( + f.data_ptr(), g.data_ptr(), fLen.data_ptr(), gLen.data_ptr(), + batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, sum.data_ptr()); + })); + } + if (opt == 1) { + // tiled version. For simplicity, assume tileF == tileG, even though the kernel can + // support more general cases. + const int threads = numThread; + const int hiddenPerBlock = numThread; + const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; + const dim3 blocks((maxGLen + tileSize - 1) / tileSize * hiddenBlock, (maxFLen + tileSize - 1) / tileSize, + batchSize); + + TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, "Expected tileSize to be in [1, 2, 4], but got ", + tileSize); + + at::PhiloxCudaState rng_engine_inputs; + if (masked) { + // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler + // for non-masked calls. + // Therefore no need to initialize. + c10::optional gen_; + auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, + // each thread processes tileF * tileG output elements. 
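To make the counter budget concrete (a sketch; Philox details are simplified): the tiled kernel asserts U == 4 and draws uniforms four at a time, so a thread that owns tileSize * tileSize outputs makes tileSize * tileSize / 4 generator calls for the supported tile sizes, and the assignment just below reserves tileSize * tileSize counter slots per thread.

#include <cstdio>

int main() {
  const int U = 4;  // uniforms produced per generator call (uniform4)
  for (int tileSize : {2, 4}) {
    int outputs = tileSize * tileSize;
    int calls = (outputs + U - 1) / U;
    std::printf("tileSize=%d: %d outputs -> %d calls, counterOffset=%d\n", tileSize, outputs, calls, outputs);
  }
  return 0;
}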
+ int64_t counterOffset = tileSize * tileSize; + { + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counterOffset); + } } - if (opt == 1){ - // tiled version. For simplicity, assume tileF == tileG, even though the kernel can - // support more general cases. - const int threads = numThread; - const int hiddenPerBlock = numThread; - const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; - const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock, - (maxFLen+tileSize-1)/tileSize, - batchSize); - - TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, - "Expected tileSize to be in [1, 2, 4], but got ", tileSize); - - at::PhiloxCudaState rng_engine_inputs; - if (masked){ - // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler - // for non-masked calls. - // Therefore no need to initialize. - c10::optional gen_; - auto gen = at::get_generator_or_default(gen_, - at::cuda::detail::getDefaultCUDAGenerator()); - // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, - // each thread processes tileF * tileG output elements. - int64_t counterOffset = tileSize * tileSize; - { - std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_cuda_state(counterOffset); - } - } - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { - void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, - int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float, - at::PhiloxCudaState, scalar_t*, uint8_t*); - if (masked){ - switch (tileSize){ - case 2: - kernel = &transducer_joint_tiled_forward; - break; - case 4: - kernel = &transducer_joint_tiled_forward; - break; - } + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + dtype, "transducer_joint_forward", ([&] { + void (*kernel)(const scalar_t *, const scalar_t *, const int *, const int *, const int64_t *, int64_t, + int64_t, int64_t, int64_t, bool, bool, bool, float, at::PhiloxCudaState, scalar_t *, + uint8_t *); + if (masked) { + switch (tileSize) { + case 2: + kernel = &transducer_joint_tiled_forward; + break; + case 4: + kernel = &transducer_joint_tiled_forward; + break; } - else{ - switch (tileSize){ - case 1: - kernel = &transducer_joint_tiled_forward; - break; - case 2: - kernel = &transducer_joint_tiled_forward; - break; - case 4: - kernel = &transducer_joint_tiled_forward; - break; - } + } else { + switch (tileSize) { + case 1: + kernel = &transducer_joint_tiled_forward; + break; + case 2: + kernel = &transducer_joint_tiled_forward; + break; + case 4: + kernel = &transducer_joint_tiled_forward; + break; } - - kernel<<>>( - f.data_ptr(), - g.data_ptr(), - fLen.data_ptr(), - gLen.data_ptr(), - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - hiddenPerBlock, - packOutput, - relu, - dropout, - 1.0f - dropoutProb, - rng_engine_inputs, - sum.data_ptr(), - maskPtr); - })); - } - - C10_CUDA_CHECK(cudaGetLastError()); - if (masked) - return {sum, mask}; - else - return {sum}; + } + + kernel<<>>(f.data_ptr(), g.data_ptr(), fLen.data_ptr(), + gLen.data_ptr(), batchOffsetPtr, maxFLen, maxGLen, hiddenSize, + hiddenPerBlock, packOutput, relu, dropout, 1.0f - dropoutProb, + rng_engine_inputs, sum.data_ptr(), maskPtr); + })); + } + + C10_CUDA_CHECK(cudaGetLastError()); + if (masked) + return {sum, mask}; + else + return {sum}; } -std::vector transducer_joint_cuda_backward( - std::vector in, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int maxFLen, - int 
maxGLen, - bool packOutput, - float scale){ - - auto grad = in[0]; - bool masked = (in.size() == 2); - uint8_t *maskPtr = masked ? in[1].data_ptr() : nullptr; - - auto tensorOpt = grad.options(); - auto dtype = grad.scalar_type(); - const int batchSize = fLen.size(0); - const int hiddenSize = grad.size(-1); - - const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); - const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE; - - torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); - torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); - - int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr(); - - // The number "y" I would like each thread to work on - const int workPerThread = 32; - // Since the bwd for f and g have the same thread block size, we need to use the max of the two. - int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread); - // Would like to have at least 2 warps - numWarp = std::max(2, numWarp); - // cap on the maximum number of warps allowed - numWarp = std::min(maxNumWarp, numWarp); - - // Need smem for transposing the partial sum. The partial sum is in a matrix of the shape - // numWarp x warpSize - const int smemSize = numWarp * C10_WARP_SIZE; - const dim3 threads(C10_WARP_SIZE, numWarp, 1); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] { +std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, + torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, + int maxGLen, bool packOutput, float scale) { + auto grad = in[0]; + bool masked = (in.size() == 2); + uint8_t *maskPtr = masked ? in[1].data_ptr() : nullptr; + + auto tensorOpt = grad.options(); + auto dtype = grad.scalar_type(); + const int batchSize = fLen.size(0); + const int hiddenSize = grad.size(-1); + + const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); + const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE; + + torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); + torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); + + int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr(); + + // The number "y" I would like each thread to work on + const int workPerThread = 32; + // Since the bwd for f and g have the same thread block size, we need to use the max of the two. + int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread - 1) / workPerThread); + // Would like to have at least 2 warps + numWarp = std::max(2, numWarp); + // cap on the maximum number of warps allowed + numWarp = std::min(maxNumWarp, numWarp); + + // Need smem for transposing the partial sum. 
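A worked instance of this launch heuristic (plain C++ sketch; largest_pow2 is a local stand-in assumed to behave like the helper defined earlier in this file, returning the largest power of two not exceeding its argument):

#include <algorithm>
#include <cstdio>

static int largest_pow2(int v) {
  int p = 1;
  while (p * 2 <= v) p *= 2;
  return p;
}

int main() {
  const int warpSize = 32, workPerThread = 32;
  const int maxNumWarp = 1024 / warpSize;  // 1024 = maxThreadsPerBlock on current NVIDIA GPUs
  const int maxFLen = 300, maxGLen = 120;  // illustrative sequence lengths
  int numWarp = largest_pow2((std::max(maxFLen, maxGLen) + workPerThread - 1) / workPerThread);
  numWarp = std::min(maxNumWarp, std::max(2, numWarp));
  std::printf("numWarp = %d, smem accumulators = %d\n", numWarp, numWarp * warpSize);
  return 0;  // prints numWarp = 8 ((300+31)/32 = 10 rounds down to 8), i.e. 8 * 32 = 256 partial sums
}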
The partial sum is in a matrix of the shape + // numWarp x warpSize + const int smemSize = numWarp * C10_WARP_SIZE; + const dim3 threads(C10_WARP_SIZE, numWarp, 1); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + dtype, "transducer_joint_cuda_backward_kernel", ([&] { auto gradPtr = grad.data_ptr(); auto fLenPtr = fLen.data_ptr(); - auto gLenPtr = gLen.data_ptr(); + auto gLenPtr = gLen.data_ptr(); auto fGradPtr = fGrad.data_ptr(); auto gGradPtr = gGrad.data_ptr(); @@ -892,88 +700,41 @@ std::vector transducer_joint_cuda_backward( constexpr int vecAlignment = std::alignment_of::value; // if all input and output tensors meet the alignment requirement - bool memAlign = (reinterpret_cast(gradPtr) % vecAlignment == 0) - and (reinterpret_cast(fGradPtr) % vecAlignment == 0) - and (reinterpret_cast(gGradPtr) % vecAlignment == 0); - - if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){ - // If vectorization helps and the alignment requirement is met, use the vectorized - // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. - const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor), - maxFLen+maxGLen, - batchSize); - if (masked){ - transducer_joint_combined_vec_backward - - <<>>( - gradPtr, - maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - scale, - fGradPtr, - gGradPtr); - } - else{ - transducer_joint_combined_vec_backward - - <<>>( - gradPtr, - maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - scale, - fGradPtr, - gGradPtr); - } - } - else{ - const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE, - maxFLen + maxGLen, batchSize); - if (masked){ - transducer_joint_combined_backward - <<>>( - gradPtr, - maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - scale, - fGradPtr, - gGradPtr); - } - else{ - transducer_joint_combined_backward - <<>>( - gradPtr, - maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - scale, - fGradPtr, - gGradPtr); - } + bool memAlign = (reinterpret_cast(gradPtr) % vecAlignment == 0) and + (reinterpret_cast(fGradPtr) % vecAlignment == 0) and + (reinterpret_cast(gGradPtr) % vecAlignment == 0); + + if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) { + // If vectorization helps and the alignment requirement is met, use the vectorized + // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. 
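Restated as a standalone check (a sketch assuming a 4-wide, 16-byte-aligned vector type for float; the actual vec_t and vectFactor are chosen elsewhere in the extension):

#include <cstddef>
#include <cstdint>

struct alignas(16) vec4f {
  float v[4];
};

// Mirrors the dispatch condition above: vectorize only if the hidden dimension is a
// multiple of the vector width and all three pointers meet the alignment requirement.
bool can_use_vectorized(const float *grad, const float *fGrad, const float *gGrad, std::int64_t hiddenSize) {
  constexpr int vectFactor = 4;
  constexpr std::size_t vecAlignment = alignof(vec4f);  // 16 bytes
  const bool memAlign = reinterpret_cast<std::uintptr_t>(grad) % vecAlignment == 0 &&
                        reinterpret_cast<std::uintptr_t>(fGrad) % vecAlignment == 0 &&
                        reinterpret_cast<std::uintptr_t>(gGrad) % vecAlignment == 0;
  return vectFactor > 1 && hiddenSize % vectFactor == 0 && memAlign;
}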
+ const dim3 blocks((hiddenSize + C10_WARP_SIZE * vectFactor - 1) / (C10_WARP_SIZE * vectFactor), + maxFLen + maxGLen, batchSize); + if (masked) { + transducer_joint_combined_vec_backward + <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, + maxFLen, maxGLen, hiddenSize, packOutput, scale, + fGradPtr, gGradPtr); + } else { + transducer_joint_combined_vec_backward + <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, + maxFLen, maxGLen, hiddenSize, packOutput, scale, + fGradPtr, gGradPtr); + } + } else { + const dim3 blocks((hiddenSize + C10_WARP_SIZE - 1) / C10_WARP_SIZE, maxFLen + maxGLen, batchSize); + if (masked) { + transducer_joint_combined_backward + <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, + maxFLen, maxGLen, hiddenSize, packOutput, scale, + fGradPtr, gGradPtr); + } else { + transducer_joint_combined_backward + <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, + maxFLen, maxGLen, hiddenSize, packOutput, scale, + fGradPtr, gGradPtr); + } } - })); + })); - return {fGrad, gGrad}; + return {fGrad, gGrad}; } diff --git a/apex/contrib/csrc/transducer/transducer_loss.cpp b/apex/contrib/csrc/transducer/transducer_loss.cpp index 91c956239..4c124edac 100644 --- a/apex/contrib/csrc/transducer/transducer_loss.cpp +++ b/apex/contrib/csrc/transducer/transducer_loss.cpp @@ -1,109 +1,53 @@ #include + #include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector transducer_loss_cuda_forward( - torch::Tensor x, - torch::Tensor label, - torch::Tensor audLen, - torch::Tensor txtLen, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool packedInput); - -torch::Tensor transducer_loss_cuda_backward( - torch::Tensor x, - torch::Tensor lossGrad, - torch::Tensor alpha, - torch::Tensor beta, - torch::Tensor audLen, - torch::Tensor txtLen, - torch::Tensor label, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool fuseSoftmaxBackward, - bool packedInput); - - -std::vector transducer_loss_forward( - torch::Tensor x, - torch::Tensor label, - torch::Tensor fLen, - torch::Tensor yLen, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool packedInput - ) { - - CHECK_INPUT(x); - CHECK_INPUT(label); - CHECK_INPUT(fLen); - CHECK_INPUT(yLen); - if (packedInput) - CHECK_INPUT(batchOffset); - return transducer_loss_cuda_forward( - x, - label, - fLen, - yLen, - batchOffset, - maxFLen, - blankIdx, - opt, - packedInput); +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen, + torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, + int blankIdx, int opt, bool packedInput); + +torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, + torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen, + torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, + int opt, bool fuseSoftmaxBackward, bool packedInput); + +std::vector transducer_loss_forward(torch::Tensor x, torch::Tensor label, torch::Tensor fLen, + torch::Tensor yLen, torch::Tensor batchOffset, int maxFLen, + int blankIdx, int opt, bool packedInput) { + CHECK_INPUT(x); + CHECK_INPUT(label); + CHECK_INPUT(fLen); + CHECK_INPUT(yLen); + if (packedInput) CHECK_INPUT(batchOffset); + return 
transducer_loss_cuda_forward(x, label, fLen, yLen, batchOffset, maxFLen, blankIdx, opt, packedInput); } -torch::Tensor transducer_loss_backward( - torch::Tensor x, - torch::Tensor lossGrad, - torch::Tensor alpha, - torch::Tensor beta, - torch::Tensor fLen, - torch::Tensor yLen, - torch::Tensor label, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool fuseSoftmaxBackward, - bool packedInput){ - - CHECK_INPUT(x); - CHECK_INPUT(label); - CHECK_INPUT(lossGrad); - CHECK_INPUT(alpha); - CHECK_INPUT(beta); - CHECK_INPUT(fLen); - CHECK_INPUT(yLen); - if (packedInput) - CHECK_INPUT(batchOffset); - - return transducer_loss_cuda_backward( - x, - lossGrad, - alpha, - beta, - fLen, - yLen, - label, - batchOffset, - maxFLen, - blankIdx, - opt, - fuseSoftmaxBackward, - packedInput); +torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta, + torch::Tensor fLen, torch::Tensor yLen, torch::Tensor label, + torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, + bool fuseSoftmaxBackward, bool packedInput) { + CHECK_INPUT(x); + CHECK_INPUT(label); + CHECK_INPUT(lossGrad); + CHECK_INPUT(alpha); + CHECK_INPUT(beta); + CHECK_INPUT(fLen); + CHECK_INPUT(yLen); + if (packedInput) CHECK_INPUT(batchOffset); + + return transducer_loss_cuda_backward(x, lossGrad, alpha, beta, fLen, yLen, label, batchOffset, maxFLen, blankIdx, opt, + fuseSoftmaxBackward, packedInput); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)", py::call_guard()); - m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)", py::call_guard()); + m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)", + py::call_guard()); + m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)", + py::call_guard()); } diff --git a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu old mode 100755 new mode 100644 index 295e14b3f..f6f7e4ca0 --- a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu @@ -1,136 +1,108 @@ -#include - -#include -#include - -#include #include #include #include +#include +#include +#include -template +#include + +template __device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) { - // standard log-sum-exp trick is used here to provide better numerical stability - return (a >= b) ? a + std::log1p(exp(b-a)) : b + std::log1p(exp(a-b)); + // standard log-sum-exp trick is used here to provide better numerical stability + return (a >= b) ? a + std::log1p(exp(b - a)) : b + std::log1p(exp(a - b)); } // Vanilla transducer loss function (i.e. forward-backward algorithm) -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. // Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted // into log scale by the preceding log_softmax layer -// Diagonal wavefront advancing usually used in dynamic programming is leveraged here. +// Diagonal wavefront advancing usually used in dynamic programming is leveraged here. // alpha and beta are of acc_t type, as they are essentially accumulators. 
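In the comments' notation (everything already in log space after the preceding log_softmax, with T = audLen and U = txtLen + 1), the kernels below implement the stabilized log-sum-exp and the log-domain forms of Eq(16) and Eq(18) from [1]:

\[ \log\left(e^{a} + e^{b}\right) = \max(a, b) + \log\!\left(1 + e^{-\lvert a - b\rvert}\right) \]
\[ \alpha(t, u) = \operatorname{logsumexp}\bigl(\alpha(t-1, u) + \mathrm{null}(t-1, u),\; \alpha(t, u-1) + y(t, u-1)\bigr), \qquad \alpha(0, 0) = 0 \]
\[ \beta(t, u) = \operatorname{logsumexp}\bigl(\beta(t+1, u) + \mathrm{null}(t, u),\; \beta(t, u+1) + y(t, u)\bigr), \qquad \beta(T-1, U-1) = \mathrm{null}(T-1, U-1) \]
\[ \mathrm{loss} = -\ln \Pr(y^{*} \mid x) = -\beta(0, 0) \]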
-// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into +// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into // [B_packed, H]. // Don't-care region (t > audLen) or (u > txtLen) is removed. // To support the packed input, the starting offsets for each batch need to be specified with // batchOffset. template -__global__ void transducer_loss_forward( - const scalar_t* x, - const int* label, - const int* audLen, - const int* txtLen, - const int64_t* batchOffset, - int64_t dictSize, // 64-bit indexing for data tensor - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - acc_t* alpha, - acc_t* beta, - scalar_t* loss) { - - const int batch = blockIdx.y; - const int tid = threadIdx.x; - const auto myFLen = audLen[batch]; - // Note that start of the sentence is added as 1 here - const auto myGLen = txtLen[batch] + 1; - const auto myLabel = label + batch * (maxGLen-1); - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - const scalar_t* myX = x + myBatchOffset * dictSize; - int u = tid; - - if (blockIdx.x == 0){ - // alpha path - acc_t* myAlpha = alpha + batch*maxFLen*maxGLen; - if (u == 0) - myAlpha[0] = 0; - __syncthreads(); - - for (int64_t step = 1; step < myFLen+myGLen-1; ++step){ - // Move along the diagonal wavefront to leverage available parallelism - for (u = tid; u < myGLen; u += blockDim.x){ - int64_t t = step - u; - if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){ - // Eq(16) in [1] - if (u == 0){ - // alpha(t, u) = alpha(t-1, u) * null(t-1, u) - myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen] - + myX[((t-1)*myStrideT) * dictSize + blankIdx]; - } - else if (t == 0){ - // alpha(t, u-1) = alpha(t, u-1) * y(t, u-1) - myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]]; - } - else{ - // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1) - acc_t current = myAlpha[(t-1)*maxGLen + u] - + myX[((t-1)*myStrideT + u) * dictSize + blankIdx]; - acc_t next = myAlpha[t*maxGLen + u - 1] - + myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]]; - myAlpha[t*maxGLen + u] = logSumExp(next, current); - } - } - } - __syncthreads(); +__global__ void transducer_loss_forward(const scalar_t* x, const int* label, const int* audLen, const int* txtLen, + const int64_t* batchOffset, + int64_t dictSize, // 64-bit indexing for data tensor + int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput, + acc_t* alpha, acc_t* beta, scalar_t* loss) { + const int batch = blockIdx.y; + const int tid = threadIdx.x; + const auto myFLen = audLen[batch]; + // Note that start of the sentence is added as 1 here + const auto myGLen = txtLen[batch] + 1; + const auto myLabel = label + batch * (maxGLen - 1); + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen; + const int64_t myStrideT = packedInput ? 
myGLen : maxGLen; + const scalar_t* myX = x + myBatchOffset * dictSize; + int u = tid; + + if (blockIdx.x == 0) { + // alpha path + acc_t* myAlpha = alpha + batch * maxFLen * maxGLen; + if (u == 0) myAlpha[0] = 0; + __syncthreads(); + + for (int64_t step = 1; step < myFLen + myGLen - 1; ++step) { + // Move along the diagonal wavefront to leverage available parallelism + for (u = tid; u < myGLen; u += blockDim.x) { + int64_t t = step - u; + if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) { + // Eq(16) in [1] + if (u == 0) { + // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + myAlpha[t * maxGLen + u] = myAlpha[(t - 1) * maxGLen] + myX[((t - 1) * myStrideT) * dictSize + blankIdx]; + } else if (t == 0) { + // alpha(t, u-1) = alpha(t, u-1) * y(t, u-1) + myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]]; + } else { + // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1) + acc_t current = myAlpha[(t - 1) * maxGLen + u] + myX[((t - 1) * myStrideT + u) * dictSize + blankIdx]; + acc_t next = myAlpha[t * maxGLen + u - 1] + myX[(t * myStrideT + u - 1) * dictSize + myLabel[u - 1]]; + myAlpha[t * maxGLen + u] = logSumExp(next, current); + } } + } + __syncthreads(); } - else if (blockIdx.x == 1){ - // beta path - acc_t* myBeta = beta + batch*maxFLen*maxGLen; - if (u == 0){ - myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT - + myGLen - 1) * dictSize + blankIdx]; - } - __syncthreads(); - - for (int64_t step = myFLen+myGLen - 3; step >= 0; --step){ - for (u = tid; u < myGLen; u += blockDim.x){ - int64_t t = step - u; - if (t >= 0 and t < myFLen and u >=0 and u < myGLen){ - // Eq(18) in [1] - if (u == myGLen - 1){ - // beta(t, u) = beta(t+1, u) * null(t, u) - myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u] - + myX[(t*myStrideT + u) * dictSize + blankIdx]; - } - else if (t == myFLen - 1){ - // beta(t, u) = beta(t, u+1) * y(t, u) - myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1] - + myX[(t*myStrideT + u) * dictSize + myLabel[u]]; - } - else{ - // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u) - acc_t current = myBeta[(t+1)*maxGLen + u] - + myX[(t*myStrideT + u) * dictSize + blankIdx]; - acc_t next = myBeta[t*maxGLen + u + 1] - + myX[(t*myStrideT + u) * dictSize + myLabel[u]]; - myBeta[t*maxGLen + u] = logSumExp(next, current); - } - } - } - __syncthreads(); + } else if (blockIdx.x == 1) { + // beta path + acc_t* myBeta = beta + batch * maxFLen * maxGLen; + if (u == 0) { + myBeta[(myFLen - 1) * maxGLen + myGLen - 1] = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx]; + } + __syncthreads(); + + for (int64_t step = myFLen + myGLen - 3; step >= 0; --step) { + for (u = tid; u < myGLen; u += blockDim.x) { + int64_t t = step - u; + if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) { + // Eq(18) in [1] + if (u == myGLen - 1) { + // beta(t, u) = beta(t+1, u) * null(t, u) + myBeta[t * maxGLen + u] = myBeta[(t + 1) * maxGLen + u] + myX[(t * myStrideT + u) * dictSize + blankIdx]; + } else if (t == myFLen - 1) { + // beta(t, u) = beta(t, u+1) * y(t, u) + myBeta[t * maxGLen + u] = myBeta[t * maxGLen + u + 1] + myX[(t * myStrideT + u) * dictSize + myLabel[u]]; + } else { + // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u) + acc_t current = myBeta[(t + 1) * maxGLen + u] + myX[(t * myStrideT + u) * dictSize + blankIdx]; + acc_t next = myBeta[t * maxGLen + u + 1] + myX[(t * myStrideT + u) * dictSize + myLabel[u]]; + myBeta[t * maxGLen + u] = logSumExp(next, current); + } } - if (tid == 0) - loss[batch] = 
-myBeta[0]; + } + __syncthreads(); } - + if (tid == 0) loss[batch] = -myBeta[0]; + } } // transudcer loss function (i.e. forward-backward algorithm) with batch loading optimization. @@ -140,183 +112,159 @@ __global__ void transducer_loss_forward( // For simplicity, this kernel currently only supports U <= maxThread, which should be the common // case. For cases where U > maxThread, the vanilla kernel is used as a fallback option. -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. // Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted // into log scale by the preceding log_softmax layer // Diagonal wavefront advancing usually used in dynamic programming is leveraged here. // alpha and beta are of acc_t type, as they are essentially accumulators. -// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into +// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into // [B_packed, H]. // Don't-care region (t > audLen) or (u > txtLen) is removed. // To support the packed input, the starting offsets for each batch need to be specified with // batchOffset. template -__global__ void transducer_loss_batch_load_forward( - const scalar_t* x, - const int* label, - const int* audLen, - const int* txtLen, - const int64_t* batchOffset, - int64_t dictSize, - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - acc_t* alpha, - acc_t* beta, - scalar_t* loss) { - - const int batch = blockIdx.y; - int u = threadIdx.x; - const auto myFLen = audLen[batch]; - const auto myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - const scalar_t* myX = x + myBatchOffset * dictSize; - scalar_t next[batchLdSize], current[batchLdSize]; - extern __shared__ char smem8[]; - auto smem = reinterpret_cast(smem8); - - if (blockIdx.x == 0){ - // alpha path - acc_t* myAlpha = alpha + batch*maxFLen*maxGLen; - // two SMEM regions for double buffering read and write data to avoid data race - acc_t * const sharedAlpha[2] = {smem, smem+maxGLen}; - - sharedAlpha[0][u] = 0; - __syncthreads(); - - if (u == 0) - myAlpha[0] = 0; - - auto myAlphaLabel = (u == 0) ? 
0 : label[batch*(maxGLen-1) + u - 1]; - // register used to pass value to the next step for the same thread - acc_t prvStepAlpha = 0; - for (int64_t step = 1; step < myFLen+myGLen-1+batchLdSize; step += batchLdSize){ - // Move along the diagonal wavefront to leverage available parallelism - // Batch loading X through loop unrolling - #pragma unroll - for (int i = 0; i < batchLdSize; ++i){ - if (step+i= 0 and t < myFLen and u >= 0 and u < myGLen){ - if (u == 0){ - current[i] = myX[currentId]; - } - else if (t == 0){ - next[i] = myX[nextId]; - } - else{ - current[i] = myX[currentId]; - next[i] = myX[nextId]; - } - } - } - } - // main computing loop - for (int i = 0; i < batchLdSize; ++i){ - // swap the pointer for double buffering - auto sharedAlphaRd = sharedAlpha[(step+i-1)%2]; - auto sharedAlphaWr = sharedAlpha[(step+i)%2]; - if (step+i= 0 and t < myFLen and u >= 0 and u < myGLen){ - // Eq(16) in [1] - if (u == 0) - prvStepAlpha = prvStepAlpha+current[i]; - else if (t == 0) - prvStepAlpha = sharedAlphaRd[u-1]+next[i]; - else - prvStepAlpha = logSumExp(prvStepAlpha+current[i], sharedAlphaRd[u-1] - + next[i]); - sharedAlphaWr[u] = prvStepAlpha; - myAlpha[t*maxGLen + u] = prvStepAlpha; - } - } - __syncthreads(); +__global__ void transducer_loss_batch_load_forward(const scalar_t* x, const int* label, const int* audLen, + const int* txtLen, const int64_t* batchOffset, int64_t dictSize, + int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput, + acc_t* alpha, acc_t* beta, scalar_t* loss) { + const int batch = blockIdx.y; + int u = threadIdx.x; + const auto myFLen = audLen[batch]; + const auto myGLen = txtLen[batch] + 1; + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen; + const int64_t myStrideT = packedInput ? myGLen : maxGLen; + const scalar_t* myX = x + myBatchOffset * dictSize; + scalar_t next[batchLdSize], current[batchLdSize]; + extern __shared__ char smem8[]; + auto smem = reinterpret_cast(smem8); + + if (blockIdx.x == 0) { + // alpha path + acc_t* myAlpha = alpha + batch * maxFLen * maxGLen; + // two SMEM regions for double buffering read and write data to avoid data race + acc_t* const sharedAlpha[2] = {smem, smem + maxGLen}; + + sharedAlpha[0][u] = 0; + __syncthreads(); + + if (u == 0) myAlpha[0] = 0; + + auto myAlphaLabel = (u == 0) ? 
0 : label[batch * (maxGLen - 1) + u - 1]; + // register used to pass value to the next step for the same thread + acc_t prvStepAlpha = 0; + for (int64_t step = 1; step < myFLen + myGLen - 1 + batchLdSize; step += batchLdSize) { +// Move along the diagonal wavefront to leverage available parallelism +// Batch loading X through loop unrolling +#pragma unroll + for (int i = 0; i < batchLdSize; ++i) { + if (step + i < myFLen + myGLen - 1) { + // index computing + int64_t t = step + i - u; + int64_t currentId = ((t - 1) * myStrideT + u) * dictSize + blankIdx; + int64_t nextId = (t * myStrideT + u - 1) * dictSize + myAlphaLabel; + // main loading loop + if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) { + if (u == 0) { + current[i] = myX[currentId]; + } else if (t == 0) { + next[i] = myX[nextId]; + } else { + current[i] = myX[currentId]; + next[i] = myX[nextId]; } + } + } + } + // main computing loop + for (int i = 0; i < batchLdSize; ++i) { + // swap the pointer for double buffering + auto sharedAlphaRd = sharedAlpha[(step + i - 1) % 2]; + auto sharedAlphaWr = sharedAlpha[(step + i) % 2]; + if (step + i < myFLen + myGLen - 1) { + int64_t t = step + i - u; + if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) { + // Eq(16) in [1] + if (u == 0) + prvStepAlpha = prvStepAlpha + current[i]; + else if (t == 0) + prvStepAlpha = sharedAlphaRd[u - 1] + next[i]; + else + prvStepAlpha = logSumExp(prvStepAlpha + current[i], sharedAlphaRd[u - 1] + next[i]); + sharedAlphaWr[u] = prvStepAlpha; + myAlpha[t * maxGLen + u] = prvStepAlpha; + } } - } - else if (blockIdx.x == 1){ - // beta path - acc_t* myBeta = beta + batch*maxFLen*maxGLen; - // two SMEM regions for double buffering read and write data to avoid data race - acc_t * const sharedBeta[2] = {smem, smem + maxGLen}; - sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx]; __syncthreads(); - - auto myBetaLabel = (u == maxGLen - 1) ? 
0 : label[batch*(maxGLen-1) + u]; - // register used to pass value to the next step for the same thread - acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx]; - if (u == 0) - myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta; - - for (int64_t step = 1; step < myFLen+myGLen-1; step += batchLdSize){ - // Move along the diagonal wavefront to leverage available parallelism - // Batch loading X - #pragma unroll - for (int i = 0; i < batchLdSize; ++i){ - if (step+i= 0 and t < myFLen and u >= 0 and u < myGLen){ - if (u == myGLen - 1){ - current[i] = myX[currentId]; - } - else if (t == myFLen - 1){ - next[i] = myX[nextId]; - } - else{ - current[i] = myX[currentId]; - next[i] = myX[nextId]; - } - } - } - } - // main computing loop - for (int i = 0; i < batchLdSize; ++i){ - // swap the pointer for double buffering - auto sharedBetaRd = sharedBeta[(step+i-1)%2]; - auto sharedBetaWr = sharedBeta[(step+i)%2]; - if (step+i= 0 and t < myFLen and u >= 0 and u < myGLen){ - // Eq(18) in [1] - if (u == myGLen - 1) - prvStepBeta = prvStepBeta+current[i]; - else if (t == myFLen - 1) - prvStepBeta = sharedBetaRd[u+1]+next[i]; - else - prvStepBeta = logSumExp(prvStepBeta+current[i], sharedBetaRd[u+1] - + next[i]); - sharedBetaWr[u] = prvStepBeta; - myBeta[t*maxGLen + u] = prvStepBeta; - } - - } - __syncthreads(); + } + } + } else if (blockIdx.x == 1) { + // beta path + acc_t* myBeta = beta + batch * maxFLen * maxGLen; + // two SMEM regions for double buffering read and write data to avoid data race + acc_t* const sharedBeta[2] = {smem, smem + maxGLen}; + sharedBeta[0][u] = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx]; + __syncthreads(); + + auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch * (maxGLen - 1) + u]; + // register used to pass value to the next step for the same thread + acc_t prvStepBeta = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx]; + if (u == 0) myBeta[(myFLen - 1) * maxGLen + myGLen - 1] = prvStepBeta; + + for (int64_t step = 1; step < myFLen + myGLen - 1; step += batchLdSize) { +// Move along the diagonal wavefront to leverage available parallelism +// Batch loading X +#pragma unroll + for (int i = 0; i < batchLdSize; ++i) { + if (step + i < myFLen + myGLen - 1) { + // index computing + int64_t t = myFLen + myGLen - (step + i) - 2 - u; + int64_t currentId = (t * myStrideT + u) * dictSize + blankIdx; + int64_t nextId = (t * myStrideT + u) * dictSize + myBetaLabel; + // main loading loop + if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) { + if (u == myGLen - 1) { + current[i] = myX[currentId]; + } else if (t == myFLen - 1) { + next[i] = myX[nextId]; + } else { + current[i] = myX[currentId]; + next[i] = myX[nextId]; } + } } - if (u == 0) - loss[batch] = -prvStepBeta; + } + // main computing loop + for (int i = 0; i < batchLdSize; ++i) { + // swap the pointer for double buffering + auto sharedBetaRd = sharedBeta[(step + i - 1) % 2]; + auto sharedBetaWr = sharedBeta[(step + i) % 2]; + if (step + i < myFLen + myGLen - 1) { + int64_t t = myFLen + myGLen - (step + i) - 2 - u; + if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) { + // Eq(18) in [1] + if (u == myGLen - 1) + prvStepBeta = prvStepBeta + current[i]; + else if (t == myFLen - 1) + prvStepBeta = sharedBetaRd[u + 1] + next[i]; + else + prvStepBeta = logSumExp(prvStepBeta + current[i], sharedBetaRd[u + 1] + next[i]); + sharedBetaWr[u] = prvStepBeta; + myBeta[t * maxGLen + u] = prvStepBeta; + } + } + __syncthreads(); + } } - + if (u == 0) loss[batch] = 
-prvStepBeta; + } } // Vanilla transudcer loss backward operation. -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. -// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere, +// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere, // hence only Eq(20) in [1] is implemented in this kernel. // Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time @@ -326,271 +274,213 @@ __global__ void transducer_loss_batch_load_forward( // To support the packed input, the starting offsets for each batch need to be specified with // batchOffset. template -__global__ void transducer_loss_backward( - const scalar_t* x, - const scalar_t* lossGrad, - const int* audLen, - const int* txtLen, - const int* label, - const acc_t* alpha, - const acc_t* beta, - const int64_t* batchOffset, - int64_t dictSize, - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - scalar_t* xGrad) { - - const int tid = threadIdx.x; - const int t = blockIdx.x; - const int batch = blockIdx.y; - const int64_t myFLen = audLen[batch]; - const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - auto myX = x + (myBatchOffset + t*myStrideT)*dictSize; - auto myAlpha = alpha + batch*maxFLen*maxGLen; - auto myBeta = beta + batch*maxFLen*maxGLen; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize; - auto myLabel = label + batch*(maxGLen-1); - - int64_t u = tid; - while (t < myFLen and u < myGLen){ - // Do the update - // loss = -ln(Pr(y*|x)) - acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; - if (u != myGLen - 1) - myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1] - + myX[u*dictSize + myLabel[u]]); - if (t == myFLen - 1 and u == myGLen - 1) - myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]); - else if (t != myFLen - 1) - myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u] - + myX[u*dictSize + blankIdx]); - - u += blockDim.x; - } +__global__ void transducer_loss_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen, + const int* txtLen, const int* label, const acc_t* alpha, const acc_t* beta, + const int64_t* batchOffset, int64_t dictSize, int64_t blankIdx, + int64_t maxFLen, int64_t maxGLen, bool packedInput, scalar_t* xGrad) { + const int tid = threadIdx.x; + const int t = blockIdx.x; + const int batch = blockIdx.y; + const int64_t myFLen = audLen[batch]; + const int64_t myGLen = txtLen[batch] + 1; + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen; + const int64_t myStrideT = packedInput ? 
myGLen : maxGLen; + auto myX = x + (myBatchOffset + t * myStrideT) * dictSize; + auto myAlpha = alpha + batch * maxFLen * maxGLen; + auto myBeta = beta + batch * maxFLen * maxGLen; + auto myXGrad = xGrad + (myBatchOffset + t * myStrideT) * dictSize; + auto myLabel = label + batch * (maxGLen - 1); + + int64_t u = tid; + while (t < myFLen and u < myGLen) { + // Do the update + // loss = -ln(Pr(y*|x)) + acc_t grad = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0]; + if (u != myGLen - 1) + myXGrad[u * dictSize + myLabel[u]] = + -std::exp(grad + myBeta[t * maxGLen + u + 1] + myX[u * dictSize + myLabel[u]]); + if (t == myFLen - 1 and u == myGLen - 1) + myXGrad[u * dictSize + blankIdx] = -std::exp(grad + myX[u * dictSize + blankIdx]); + else if (t != myFLen - 1) + myXGrad[u * dictSize + blankIdx] = -std::exp(grad + myBeta[(t + 1) * maxGLen + u] + myX[u * dictSize + blankIdx]); + + u += blockDim.x; + } } // Fused transudcer loss backward operation. -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. -// The bwd op of the preceding softmax layer is fused in this kernel. +// The bwd op of the preceding softmax layer is fused in this kernel. // Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time // To support the packed input, the starting offsets for each batch need to be specified with // batchOffset. template -__global__ void transducer_loss_fused_backward( - const scalar_t* x, - const scalar_t* lossGrad, - const int* audLen, - const int* txtLen, - const int* label, - const acc_t* alpha, - const acc_t* beta, - const int64_t* batchOffset, - int64_t dictSize, - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - scalar_t* xGrad) { - - const int tid = threadIdx.x; - const int u = blockIdx.x; - const int t = blockIdx.y; - const int batch = blockIdx.z; - const int64_t myFLen = audLen[batch]; - const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - - __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; - - if (t < myFLen and u < myGLen){ - auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; - auto myAlpha = alpha + batch*maxFLen*maxGLen; - auto myBeta = beta + batch*maxFLen*maxGLen; - auto myLabel = label + batch*(maxGLen-1); - - // load and store shared variables in SMEM - if (tid == 0){ - commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; - myBetaTU = myBeta[t*maxGLen + u]; - myBetaTUp1 = myBeta[t*maxGLen + u + 1]; - myBetaTp1U = myBeta[(t+1)*maxGLen + u]; - myLabelShared = myLabel[u]; - } +__global__ void transducer_loss_fused_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen, + const int* txtLen, const int* label, const acc_t* alpha, + const acc_t* beta, const int64_t* batchOffset, int64_t dictSize, + int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput, + scalar_t* xGrad) { + const int tid = threadIdx.x; + const int u = blockIdx.x; + const int t = blockIdx.y; + const int batch = blockIdx.z; + const int64_t myFLen = audLen[batch]; + const int64_t myGLen = txtLen[batch] + 1; + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 
0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen; + const int64_t myStrideT = packedInput ? myGLen : maxGLen; + + __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; + auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize; + + if (t < myFLen and u < myGLen) { + auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize; + auto myAlpha = alpha + batch * maxFLen * maxGLen; + auto myBeta = beta + batch * maxFLen * maxGLen; + auto myLabel = label + batch * (maxGLen - 1); + + // load and store shared variables in SMEM + if (tid == 0) { + commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0]; + myBetaTU = myBeta[t * maxGLen + u]; + myBetaTUp1 = myBeta[t * maxGLen + u + 1]; + myBetaTp1U = myBeta[(t + 1) * maxGLen + u]; + myLabelShared = myLabel[u]; + } - __syncthreads(); + __syncthreads(); - for (int64_t h = tid; h < dictSize; h += blockDim.x){ - // Do the update - acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x)) - acc_t myGrad = std::exp(grad + myBetaTU); - if (u != myGLen - 1 and h == myLabelShared){ - myGrad -= std::exp(grad + myBetaTUp1); - } - else if (h == blankIdx){ - if (t == myFLen - 1 and u == myGLen - 1) - myGrad -= std::exp(grad); - else if (t != myFLen - 1) - myGrad -= std::exp(grad + myBetaTp1U); - } - myXGrad[h] = myGrad; - } + for (int64_t h = tid; h < dictSize; h += blockDim.x) { + // Do the update + acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x)) + acc_t myGrad = std::exp(grad + myBetaTU); + if (u != myGLen - 1 and h == myLabelShared) { + myGrad -= std::exp(grad + myBetaTUp1); + } else if (h == blankIdx) { + if (t == myFLen - 1 and u == myGLen - 1) + myGrad -= std::exp(grad); + else if (t != myFLen - 1) + myGrad -= std::exp(grad + myBetaTp1U); + } + myXGrad[h] = myGrad; } - else if (!packedInput){ - // In non-pack mode, need to make sure the gradients for don't-care regions are zero. - for (int64_t h = tid; h < dictSize; h += blockDim.x){ - myXGrad[h] = 0; - } + } else if (!packedInput) { + // In non-pack mode, need to make sure the gradients for don't-care regions are zero. + for (int64_t h = tid; h < dictSize; h += blockDim.x) { + myXGrad[h] = 0; } + } } - // Vectorized version of fused transudcer loss backward operation. -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. -// The bwd op of the preceding softmax layer is fused in this kernel. +// The bwd op of the preceding softmax layer is fused in this kernel. // Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time // To support the packed input, the starting offsets for each batch need to be specified with // batchOffset. template -__global__ void transducer_loss_fused_vec_backward( - const scalar_t* x, - const scalar_t* lossGrad, - const int* audLen, - const int* txtLen, - const int* label, - const acc_t* alpha, - const acc_t* beta, - const int64_t* batchOffset, - int64_t dictSize, - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - scalar_t* xGrad) { - - const int tid = threadIdx.x; - const int u = blockIdx.x; - const int t = blockIdx.y; - const int batch = blockIdx.z; - const int64_t myFLen = audLen[batch]; - const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? 
myGLen : maxGLen; - - __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; - auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; - auto myAlpha = alpha + batch*maxFLen*maxGLen; - auto myBeta = beta + batch*maxFLen*maxGLen; - auto myLabel = label + batch*(maxGLen-1); - - // Variabels for vectorization - scalar_t myXBuffer[V], myXGradBuffer[V]; - auto myXVec = reinterpret_cast(myX); - auto myXGradVec = reinterpret_cast(myXGrad); - auto myXBufferVec = reinterpret_cast(myXBuffer); - auto myXGradBufferVec = reinterpret_cast(myXGradBuffer); - if (t < myFLen and u < myGLen){ - // load and store shared variables in SMEM - if (tid == 0){ - commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; - myBetaTU = myBeta[t*maxGLen + u]; - if (t != myFLen - 1) - myBetaTp1U = myBeta[(t+1)*maxGLen + u]; - if (u != myGLen - 1){ - myBetaTUp1 = myBeta[t*maxGLen + u + 1]; - myLabelShared = myLabel[u]; - } - } - - __syncthreads(); - - #pragma unroll - for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){ - // Load myX in a vector form - *myXBufferVec = myXVec[h0/V]; - // Do the update for a vector of input - #pragma unroll - for (int i = 0; i < V; ++i){ - auto h = h0 + i; - acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x)) - acc_t myGrad = std::exp(grad + myBetaTU); - if (u != myGLen - 1 and h == myLabelShared){ - myGrad -= std::exp(grad + myBetaTUp1); - } - else if (h == blankIdx){ - if (t == myFLen - 1 and u == myGLen - 1) - myGrad -= std::exp(grad); - else if (t != myFLen - 1) - myGrad -= std::exp(grad + myBetaTp1U); - } - myXGradBuffer[i] = myGrad; - } +__global__ void transducer_loss_fused_vec_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen, + const int* txtLen, const int* label, const acc_t* alpha, + const acc_t* beta, const int64_t* batchOffset, int64_t dictSize, + int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput, + scalar_t* xGrad) { + const int tid = threadIdx.x; + const int u = blockIdx.x; + const int t = blockIdx.y; + const int batch = blockIdx.z; + const int64_t myFLen = audLen[batch]; + const int64_t myGLen = txtLen[batch] + 1; + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen; + const int64_t myStrideT = packedInput ? 
myGLen : maxGLen; + + __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; + auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize; + auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize; + auto myAlpha = alpha + batch * maxFLen * maxGLen; + auto myBeta = beta + batch * maxFLen * maxGLen; + auto myLabel = label + batch * (maxGLen - 1); + + // Variabels for vectorization + scalar_t myXBuffer[V], myXGradBuffer[V]; + auto myXVec = reinterpret_cast(myX); + auto myXGradVec = reinterpret_cast(myXGrad); + auto myXBufferVec = reinterpret_cast(myXBuffer); + auto myXGradBufferVec = reinterpret_cast(myXGradBuffer); + if (t < myFLen and u < myGLen) { + // load and store shared variables in SMEM + if (tid == 0) { + commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0]; + myBetaTU = myBeta[t * maxGLen + u]; + if (t != myFLen - 1) myBetaTp1U = myBeta[(t + 1) * maxGLen + u]; + if (u != myGLen - 1) { + myBetaTUp1 = myBeta[t * maxGLen + u + 1]; + myLabelShared = myLabel[u]; + } + } - // Store myXGrad in a vector form - myXGradVec[h0/V] = *myXGradBufferVec; - + __syncthreads(); + +#pragma unroll + for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) { + // Load myX in a vector form + *myXBufferVec = myXVec[h0 / V]; +// Do the update for a vector of input +#pragma unroll + for (int i = 0; i < V; ++i) { + auto h = h0 + i; + acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x)) + acc_t myGrad = std::exp(grad + myBetaTU); + if (u != myGLen - 1 and h == myLabelShared) { + myGrad -= std::exp(grad + myBetaTUp1); + } else if (h == blankIdx) { + if (t == myFLen - 1 and u == myGLen - 1) + myGrad -= std::exp(grad); + else if (t != myFLen - 1) + myGrad -= std::exp(grad + myBetaTp1U); } + myXGradBuffer[i] = myGrad; + } + + // Store myXGrad in a vector form + myXGradVec[h0 / V] = *myXGradBufferVec; } - else if (!packedInput){ - // In non-pack mode, need to make sure the gradients for don't-care regions are zero. - for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){ - myXGradVec[h0/V] = 0; - } + } else if (!packedInput) { + // In non-pack mode, need to make sure the gradients for don't-care regions are zero. + for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) { + myXGradVec[h0 / V] = 0; } + } } - -std::vector transducer_loss_cuda_forward( - torch::Tensor x, - torch::Tensor label, - torch::Tensor audLen, - torch::Tensor txtLen, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool packedInput){ - - auto scalarType = x.scalar_type(); - auto tensorOpt = x.options(); - const int batchSize = label.size(0); - const int maxGLen = label.size(1) + 1; - const int dictSize = x.size(-1); - - TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, - "Expected blank index to be in the range of 0 to ", - dictSize-1, - ", but got ", - blankIdx); - TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, - "Got an invalid optimization level ", - opt); - - // The data type of alpha and beta will be resolved at dispatch time, - // hence defined here and assigned later - torch::Tensor alpha; - torch::Tensor beta; - torch::Tensor loss = torch::empty({batchSize}, tensorOpt); - const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); - const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock; - const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock; - const auto batchOffsetPtr = packedInput ? 
batchOffset.data_ptr() : nullptr; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, "transducer_loss_cuda_forward", ([&] { +std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen, + torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, + int blankIdx, int opt, bool packedInput) { + auto scalarType = x.scalar_type(); + auto tensorOpt = x.options(); + const int batchSize = label.size(0); + const int maxGLen = label.size(1) + 1; + const int dictSize = x.size(-1); + + TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, "Expected blank index to be in the range of 0 to ", dictSize - 1, + ", but got ", blankIdx); + TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, "Got an invalid optimization level ", opt); + + // The data type of alpha and beta will be resolved at dispatch time, + // hence defined here and assigned later + torch::Tensor alpha; + torch::Tensor beta; + torch::Tensor loss = torch::empty({batchSize}, tensorOpt); + const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); + const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock; + const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock; + const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr() : nullptr; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + scalarType, "transducer_loss_cuda_forward", ([&] { // resolve accumulation type using acc_t = at::acc_type; auto accType = c10::CppTypeToScalarType::value; @@ -601,167 +491,95 @@ std::vector transducer_loss_cuda_forward( // decide what kernel to launch based on the problem size // if the required SMEM size or number threads exceeds the limit, fall back to the vanilla // kernel. - const auto smemSize = 2*maxGLen*sizeof(acc_t); - const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0 - : (opt == -1) ? 1 : opt; + const auto smemSize = 2 * maxGLen * sizeof(acc_t); + const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0 + : (opt == -1) ? 
1 + : opt; const int threads = std::min(maxThreadPerBlock, maxGLen); - const dim3 blocks(2, batchSize, 1); + const dim3 blocks(2, batchSize, 1); if (optFallBack == 0) - transducer_loss_forward<<>>( - x.data_ptr(), - label.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - alpha.data_ptr(), - beta.data_ptr(), - loss.data_ptr()); + transducer_loss_forward<<>>( + x.data_ptr(), label.data_ptr(), audLen.data_ptr(), txtLen.data_ptr(), + batchOffsetPtr, dictSize, blankIdx, maxFLen, maxGLen, packedInput, alpha.data_ptr(), + beta.data_ptr(), loss.data_ptr()); else if (optFallBack == 1) - transducer_loss_batch_load_forward - <<>>( - x.data_ptr(), - label.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - alpha.data_ptr(), - beta.data_ptr(), - loss.data_ptr()); - - })); - C10_CUDA_CHECK(cudaGetLastError()); - - return {alpha, beta, loss}; + transducer_loss_batch_load_forward<<>>( + x.data_ptr(), label.data_ptr(), audLen.data_ptr(), txtLen.data_ptr(), + batchOffsetPtr, dictSize, blankIdx, maxFLen, maxGLen, packedInput, alpha.data_ptr(), + beta.data_ptr(), loss.data_ptr()); + })); + C10_CUDA_CHECK(cudaGetLastError()); + + return {alpha, beta, loss}; } - - - -torch::Tensor transducer_loss_cuda_backward( - torch::Tensor x, - torch::Tensor lossGrad, - torch::Tensor alpha, - torch::Tensor beta, - torch::Tensor audLen, - torch::Tensor txtLen, - torch::Tensor label, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool fuseSoftmaxBackward, - bool packedInput){ - - auto dtype = x.scalar_type(); - torch::Tensor xGrad; - const int batchSize = label.size(0); - const int maxGLen = label.size(1) + 1; - const int dictSize = x.size(-1); - const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); - const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock; - const int warpSize = deviceProperties->warpSize; - const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr() : nullptr; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (fuseSoftmaxBackward){ - // alloc empty tensors for performance, hence need to ensure zeros are writtern to - // don't-care region in the kernel. 
- xGrad = torch::empty_like(x); - - // Would like each thread to work on 4 hidden units - const int workPerThread = 4; - // Don't want to have more than 128 threads per thread block - const int maxThreadPerElmt = std::min(128, maxThreadPerBlock); - const int threads = std::min(maxThreadPerElmt, std::max(warpSize, - (dictSize+workPerThread-1)/workPerThread)); - const dim3 blocks(maxGLen, maxFLen, batchSize); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] { - using vec_t = uint64_t; - using acc_t = at::acc_type; - constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t); - constexpr int vecAlignment = std::alignment_of::value; - // if all input and output tensors meet the alignment requirement - bool memAlign = reinterpret_cast(x.data_ptr()) % vecAlignment == 0 - and reinterpret_cast(xGrad.data_ptr()) - % vecAlignment == 0; - - if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){ - transducer_loss_fused_vec_backward - <<>>( - x.data_ptr(), - lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - xGrad.data_ptr()); - } - else{ - transducer_loss_fused_backward<<>>( - x.data_ptr(), - lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - xGrad.data_ptr()); - - } +torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, + torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen, + torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, + int opt, bool fuseSoftmaxBackward, bool packedInput) { + auto dtype = x.scalar_type(); + torch::Tensor xGrad; + const int batchSize = label.size(0); + const int maxGLen = label.size(1) + 1; + const int dictSize = x.size(-1); + const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); + const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock; + const int warpSize = deviceProperties->warpSize; + const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr() : nullptr; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (fuseSoftmaxBackward) { + // alloc empty tensors for performance, hence need to ensure zeros are writtern to + // don't-care region in the kernel. 
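+    // The fused backward kernels below zero the padded (t, u) positions themselves when
+    // packedInput is false, so the uninitialized tensor from torch::empty_like is sufficient here.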
+ xGrad = torch::empty_like(x); + + // Would like each thread to work on 4 hidden units + const int workPerThread = 4; + // Don't want to have more than 128 threads per thread block + const int maxThreadPerElmt = std::min(128, maxThreadPerBlock); + const int threads = std::min(maxThreadPerElmt, std::max(warpSize, (dictSize + workPerThread - 1) / workPerThread)); + const dim3 blocks(maxGLen, maxFLen, batchSize); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + dtype, "transducer_loss_cuda_backward", ([&] { + using vec_t = uint64_t; + using acc_t = at::acc_type; + constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t); + constexpr int vecAlignment = std::alignment_of::value; + // if all input and output tensors meet the alignment requirement + bool memAlign = reinterpret_cast(x.data_ptr()) % vecAlignment == 0 and + reinterpret_cast(xGrad.data_ptr()) % vecAlignment == 0; + + if (vectFactor > 1 and dictSize % vectFactor == 0 and memAlign) { + transducer_loss_fused_vec_backward<<>>( + x.data_ptr(), lossGrad.data_ptr(), audLen.data_ptr(), txtLen.data_ptr(), + label.data_ptr(), alpha.data_ptr(), beta.data_ptr(), batchOffsetPtr, dictSize, + blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr()); + } else { + transducer_loss_fused_backward<<>>( + x.data_ptr(), lossGrad.data_ptr(), audLen.data_ptr(), txtLen.data_ptr(), + label.data_ptr(), alpha.data_ptr(), beta.data_ptr(), batchOffsetPtr, dictSize, + blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr()); + } })); - } - else{ - // for non-fused kernel, the gradients need to be writtern are very sparse, hence initialize - // the tensor with all zeros. - xGrad = torch::zeros_like(x); - // don't launch more threads than needed. - const int threads = std::min(maxThreadPerBlock, maxGLen); - const dim3 blocks(maxFLen, batchSize); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] { - using acc_t = at::acc_type; - transducer_loss_backward<<>>( - x.data_ptr(), - lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - xGrad.data_ptr()); - })); - } - C10_CUDA_CHECK(cudaGetLastError()); - - return xGrad; + } else { + // for non-fused kernel, the gradients need to be writtern are very sparse, hence initialize + // the tensor with all zeros. + xGrad = torch::zeros_like(x); + // don't launch more threads than needed. 
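+    // transducer_loss_backward assigns one thread block per (t, batch) pair and strides each
+    // thread over u, so launching more than maxGLen threads per block would leave threads idle.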
+ const int threads = std::min(maxThreadPerBlock, maxGLen); + const dim3 blocks(maxFLen, batchSize); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] { + using acc_t = at::acc_type; + transducer_loss_backward<<>>( + x.data_ptr(), lossGrad.data_ptr(), + audLen.data_ptr(), txtLen.data_ptr(), label.data_ptr(), + alpha.data_ptr(), beta.data_ptr(), batchOffsetPtr, dictSize, + blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr()); + })); + } + C10_CUDA_CHECK(cudaGetLastError()); + + return xGrad; } diff --git a/apex/contrib/csrc/xentropy/interface.cpp b/apex/contrib/csrc/xentropy/interface.cpp index cbf76c1f7..a0ba92097 100644 --- a/apex/contrib/csrc/xentropy/interface.cpp +++ b/apex/contrib/csrc/xentropy/interface.cpp @@ -4,60 +4,52 @@ // CUDA forward declarations -std::vector softmax_xentropy_cuda( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const bool half_to_float); - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - const at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing); +std::vector softmax_xentropy_cuda(const at::Tensor &input, const at::Tensor &labels, const float smoothing, + const bool half_to_float); + +at::Tensor softmax_xentropy_backward_cuda(const at::Tensor &grad_loss, const at::Tensor &logits, + const at::Tensor &max_log_sum_exp, const at::Tensor &labels, + const float smoothing); // C++ interface #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) -std::vector softmax_xentropy_forward( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const bool half_to_float) { - CHECK_CUDA(input); - CHECK_INPUT(labels); +std::vector softmax_xentropy_forward(const at::Tensor &input, const at::Tensor &labels, + const float smoothing, const bool half_to_float) { + CHECK_CUDA(input); + CHECK_INPUT(labels); - return softmax_xentropy_cuda(input, labels, smoothing, half_to_float); + return softmax_xentropy_cuda(input, labels, smoothing, half_to_float); } -at::Tensor softmax_xentropy_backward( - const at::Tensor &grad_loss, - const at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing) { - CHECK_CUDA(grad_loss); - CHECK_CUDA(logits); - CHECK_INPUT(max_log_sum_exp); - CHECK_INPUT(labels); - - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing); +at::Tensor softmax_xentropy_backward(const at::Tensor &grad_loss, const at::Tensor &logits, + const at::Tensor &max_log_sum_exp, const at::Tensor &labels, + const float smoothing) { + CHECK_CUDA(grad_loss); + CHECK_CUDA(logits); + CHECK_INPUT(max_log_sum_exp); + CHECK_INPUT(labels); + + return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::call_guard()); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::call_guard()); - // ref: https://pybind11.readthedocs.io/en/stable/basics.html#exporting-variables - py::object version = py::cast( + m.def("forward", &softmax_xentropy_forward, 
"Softmax cross entropy loss with label smoothing forward (CUDA)", + py::call_guard()); + m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", + py::call_guard()); + // ref: https://pybind11.readthedocs.io/en/stable/basics.html#exporting-variables + py::object version = py::cast( #ifdef XENTROPY_VER - XENTROPY_VER + XENTROPY_VER #else - std::string{} + std::string{} #endif - ); - m.attr("__version__") = version; + ); + m.attr("__version__") = version; } diff --git a/apex/contrib/csrc/xentropy/xentropy_kernel.cu b/apex/contrib/csrc/xentropy/xentropy_kernel.cu index 444828ac6..e5d60d923 100644 --- a/apex/contrib/csrc/xentropy/xentropy_kernel.cu +++ b/apex/contrib/csrc/xentropy/xentropy_kernel.cu @@ -71,9 +71,9 @@ * POSSIBILITY OF SUCH DAMAGE. */ #include +#include #include -#include #include #include "type_shim.h" @@ -85,25 +85,21 @@ using TensorList = at::TensorList; using ScalarType = at::ScalarType; using at::acc_type; -template +template struct LogSoftMaxForwardEpilogue { __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) - : logsum(max_input + std::log(sum)) {} + : logsum(max_input + std::log(sum)) {} - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) - : logsum(max_log_sum_exp) {} + __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) : logsum(max_log_sum_exp) {} - __device__ __forceinline__ OutT operator()(T input) const { - return static_cast(input - logsum); - } + __device__ __forceinline__ OutT operator()(T input) const { return static_cast(input - logsum); } const AccumT logsum; }; -template +template struct LogSoftMaxBackwardEpilogue { - __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) - : sum(sum) {} + __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) : sum(sum) {} __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { return static_cast(gradOutput - std::exp(static_cast(output)) * sum); @@ -112,74 +108,52 @@ struct LogSoftMaxBackwardEpilogue { const AccumT sum; }; - - const int max_threads = 1024; inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { uint64_t block_size = 1; uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); - while (block_size < (max_block_size/2)) block_size *= 2; + while (block_size < (max_block_size / 2)) block_size *= 2; // Launch at least a single warp - the kernel assumes that. block_size = std::max(block_size, static_cast(32)); return dim3(block_size); } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; }
 };

-
 ////////////////////////////////////////////////////////////////////////////////
 // Regular kernel (fast when dim_size is large; requires inner_size == 1)
 ////////////////////////////////////////////////////////////////////////////////
-
 template <typename T, typename AccumT>
-struct MaxFloat
-{
-  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
-    return ::max(max, (AccumT)v);
-  }
+struct MaxFloat {
+  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { return ::max(max, (AccumT)v); }
 };

-template <typename T, typename AccumT>
-struct AddFloat
-{
-  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
-    return sum + v;
-  }
+template <typename T, typename AccumT>
+struct AddFloat {
+  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { return sum + v; }
 };

-template <typename T, typename AccumT>
-struct SumExpFloat
-{
-  __device__ __forceinline__ SumExpFloat(AccumT v)
-    : max_k(v) {}
+template <typename T, typename AccumT>
+struct SumExpFloat {
+  __device__ __forceinline__ SumExpFloat(AccumT v) : max_k(v) {}

-  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
-    return sum + std::exp(v - max_k);
-  }
+  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { return sum + std::exp(v - max_k); }

   const AccumT max_k;
 };

-template <template<typename> class Reduction, typename AccumT>
-__device__ __forceinline__ AccumT
-blockReduce(AccumT* smem, AccumT val,
-            const Reduction<AccumT>& r,
-            AccumT defaultVal)
-{
+template