|
| 1 | +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. |
| 2 | +// SPDX-License-Identifier: MIT |
| 3 | +// InstanceTraits specialization for GroupedConvolutionForwardKernel |
| 4 | +// |
| 5 | +// CRITICAL MAINTENANCE NOTE: |
| 6 | +// This InstanceTraits file MUST be kept strictly in sync with the device implementation header: |
| 7 | +// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp |
| 8 | +// "In sync" means that the template parameter order, names, and types in the declaration below |
| 9 | +// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter |
| 10 | +// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are |
| 11 | +// difficult to diagnose. Always update both files together and review changes carefully. |
| 12 | + |
| 13 | +#pragma once |
| 14 | + |
| 15 | +#include "instance_traits.hpp" |
| 16 | +#include "instance_traits_util.hpp" |
| 17 | + |
// Forward declaration only: pulling in the full kernel header here would create a
// circular include, and a specialization of InstanceTraits needs nothing more than
// the template's name and parameter list.
//
// NOTE(review): the InstanceTraits specialization below can only ever be selected if
// the real kernel template is defined in this exact namespace with this exact
// parameter list -- confirm against grouped_convolution_forward_kernel.hpp.
namespace ck_tile {
namespace device {

template <typename GroupedConvTraitsType_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel;

} // namespace device
} // namespace ck_tile
| 28 | + |
| 29 | +namespace ck_tile { |
| 30 | +namespace reflect { |
| 31 | + |
| 32 | +// Specialization for GroupedConvolutionForwardKernel |
| 33 | +template <typename GroupedConvTraitsType_, |
| 34 | + typename TilePartitioner_, |
| 35 | + typename GemmPipeline_, |
| 36 | + typename EpiloguePipeline_> |
| 37 | +struct InstanceTraits<ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType_, |
| 38 | + TilePartitioner_, |
| 39 | + GemmPipeline_, |
| 40 | + EpiloguePipeline_>> |
| 41 | +{ |
| 42 | + // CK Tile Conv Traits |
| 43 | + // Spatial dimension |
| 44 | + static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial; |
| 45 | + // Specialization |
| 46 | + static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization = |
| 47 | + GroupedConvTraitsType_::ConvSpecialization; |
| 48 | + // DataType types |
| 49 | + using InLayout = typename GroupedConvTraitsType_::InLayout; |
| 50 | + using WeiLayout = typename GroupedConvTraitsType_::WeiLayout; |
| 51 | + using DsLayout = typename GroupedConvTraitsType_::DsLayout; |
| 52 | + using OutLayout = typename GroupedConvTraitsType_::OutLayout; |
| 53 | + // Vector size |
| 54 | + static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA; |
| 55 | + static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB; |
| 56 | + static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC; |
| 57 | + // Num Groups To Merge |
| 58 | + static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; |
| 59 | + // Split image (large tensors) |
| 60 | + static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; |
| 61 | + |
| 62 | + // TilePartitioner |
| 63 | + // Block configuration |
| 64 | + static constexpr int kMPerBlock = TilePartitioner_::MPerBlock; |
| 65 | + static constexpr int kNPerBlock = TilePartitioner_::NPerBlock; |
| 66 | + static constexpr int kKPerBlock = TilePartitioner_::KPerBlock; |
| 67 | + |
| 68 | + static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{}); |
| 69 | + static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{}); |
| 70 | + static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{}); |
| 71 | + |
| 72 | + static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{}); |
| 73 | + static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{}); |
| 74 | + static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{}); |
| 75 | + |
| 76 | + // Data types |
| 77 | + using ADataType = typename GemmPipeline_::ADataType; |
| 78 | + using BDataType = typename GemmPipeline_::BDataType; |
| 79 | + // Gemm Pipeline |
| 80 | + using GemmPipeline = GemmPipeline_; |
| 81 | + static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler; |
| 82 | + static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer; |
| 83 | + static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups; |
| 84 | + |
| 85 | + // Epilogue Pipeline |
| 86 | + using AccDataType = typename EpiloguePipeline_::AccDataType; |
| 87 | + using EDataType = typename EpiloguePipeline_::ODataType; |
| 88 | + using DsDataType = typename EpiloguePipeline_::DsDataType; |
| 89 | + using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise; |
| 90 | + |
| 91 | + // Static member function to generate instance string |
| 92 | + static std::string instance_string() |
| 93 | + { |
| 94 | + std::ostringstream oss; |
| 95 | + |
| 96 | + // Kernel type name |
| 97 | + oss << "GroupedConvolutionForwardKernel"; |
| 98 | + |
| 99 | + // Template parameters in exact order matching InstanceTraits member order |
| 100 | + oss << "<" << kSpatialDim; // 1. NDimSpatial |
| 101 | + oss << "," |
| 102 | + << ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization |
| 103 | + oss << "," << detail::layout_name<InLayout>(); // 3. InLayout |
| 104 | + oss << "," << detail::layout_name<WeiLayout>(); // 4. WeiLayout |
| 105 | + oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout |
| 106 | + oss << "," << detail::layout_name<OutLayout>(); // 6. OutLayout |
| 107 | + oss << "," << kVectorSizeA; // 7. VectorSizeA |
| 108 | + oss << "," << kVectorSizeB; // 8. VectorSizeB |
| 109 | + oss << "," << kVectorSizeC; // 9. VectorSizeC |
| 110 | + oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge |
| 111 | + oss << "," << kEnableSplitImage; // 11. EnableSplitImage |
| 112 | + oss << "," << kMPerBlock; // 12. MPerBlock |
| 113 | + oss << "," << kNPerBlock; // 13. NPerBlock |
| 114 | + oss << "," << kKPerBlock; // 14. KPerBlock |
| 115 | + oss << "," << kMWarp; // 15. MWarp |
| 116 | + oss << "," << kNWarp; // 16. NWarp |
| 117 | + oss << "," << kKWarp; // 17. KWarp |
| 118 | + oss << "," << kMWarpTile; // 18. MWarpTile |
| 119 | + oss << "," << kNWarpTile; // 19. NWarpTile |
| 120 | + oss << "," << kKWarpTile; // 20. KWarpTile |
| 121 | + oss << "," << detail::type_name<ADataType>(); // 21. ADataType |
| 122 | + oss << "," << detail::type_name<BDataType>(); // 22. BDataType |
| 123 | + oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer |
| 124 | + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched |
| 125 | + oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups |
| 126 | + oss << "," << kNumWaveGroups; // 26. NumWaveGroups |
| 127 | + oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType |
| 128 | + oss << "," << detail::type_name<EDataType>(); // 28. EDataType |
| 129 | + oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType |
| 130 | + oss << "," |
| 131 | + << detail::elementwise_op_name<CDEElementwiseOperation>(); // 30. |
| 132 | + // CDEElementwiseOperation |
| 133 | + oss << ">"; |
| 134 | + |
| 135 | + return oss.str(); |
| 136 | + } |
| 137 | +}; |
| 138 | + |
| 139 | +} // namespace reflect |
| 140 | +} // namespace ck_tile |
0 commit comments