Skip to content

Commit 92c1f49

Browse files
authored
[CK_BUILDER] Add grouped conv fwd ck tile traits (#3183)
* [CK BUILDER] Add grouped conv fwd ck tile traits * Update instance_traits_tile_grouped_convolution_forward.hpp * Update grouped_convolution_forward_kernel.hpp
1 parent b145a5f commit 92c1f49

18 files changed

+433
-15
lines changed

experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
1616
#include <ck/utility/loop_scheduler.hpp>
1717
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
18+
#include <ck_tile/ops/gemm.hpp>
19+
#include "ck_tile/ops/epilogue.hpp"
20+
#include <ck_tile/ops/grouped_convolution.hpp>
1821

1922
namespace ck_tile::reflect::conv {
2023

experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma once
55

66
#include "instance_traits.hpp"
7+
#include "instance_traits_util.hpp"
78
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
89

910
// Forward declaration to avoid circular dependency
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
// InstanceTraits specialization for GroupedConvolutionForwardKernel
4+
//
5+
// CRITICAL MAINTENANCE NOTE:
6+
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
7+
// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
8+
// "In sync" means that the template parameter order, names, and types in the declaration below
9+
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
10+
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
11+
// difficult to diagnose. Always update both files together and review changes carefully.
12+
13+
#pragma once
14+
15+
#include "instance_traits.hpp"
16+
#include "instance_traits_util.hpp"
17+
18+
// Forward declaration to avoid circular dependency.
19+
namespace ck_tile::device {
20+
21+
template <typename GroupedConvTraitsType_,
22+
typename TilePartitioner_,
23+
typename GemmPipeline_,
24+
typename EpiloguePipeline_>
25+
struct GroupedConvolutionForwardKernel;
26+
27+
} // namespace ck_tile::device
28+
29+
namespace ck_tile {
30+
namespace reflect {
31+
32+
// Specialization for GroupedConvolutionForwardKernel
33+
template <typename GroupedConvTraitsType_,
34+
typename TilePartitioner_,
35+
typename GemmPipeline_,
36+
typename EpiloguePipeline_>
37+
struct InstanceTraits<ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType_,
38+
TilePartitioner_,
39+
GemmPipeline_,
40+
EpiloguePipeline_>>
41+
{
42+
// CK Tile Conv Traits
43+
// Spatial dimension
44+
static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial;
45+
// Specialization
46+
static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization =
47+
GroupedConvTraitsType_::ConvSpecialization;
48+
// DataType types
49+
using InLayout = typename GroupedConvTraitsType_::InLayout;
50+
using WeiLayout = typename GroupedConvTraitsType_::WeiLayout;
51+
using DsLayout = typename GroupedConvTraitsType_::DsLayout;
52+
using OutLayout = typename GroupedConvTraitsType_::OutLayout;
53+
// Vector size
54+
static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA;
55+
static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB;
56+
static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC;
57+
// Num Groups To Merge
58+
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
59+
// Split image (large tensors)
60+
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
61+
62+
// TilePartitioner
63+
// Block configuration
64+
static constexpr int kMPerBlock = TilePartitioner_::MPerBlock;
65+
static constexpr int kNPerBlock = TilePartitioner_::NPerBlock;
66+
static constexpr int kKPerBlock = TilePartitioner_::KPerBlock;
67+
68+
static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{});
69+
static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{});
70+
static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{});
71+
72+
static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{});
73+
static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{});
74+
static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{});
75+
76+
// Data types
77+
using ADataType = typename GemmPipeline_::ADataType;
78+
using BDataType = typename GemmPipeline_::BDataType;
79+
// Gemm Pipeline
80+
using GemmPipeline = GemmPipeline_;
81+
static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler;
82+
static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer;
83+
static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups;
84+
85+
// Epilogue Pipeline
86+
using AccDataType = typename EpiloguePipeline_::AccDataType;
87+
using EDataType = typename EpiloguePipeline_::ODataType;
88+
using DsDataType = typename EpiloguePipeline_::DsDataType;
89+
using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise;
90+
91+
// Static member function to generate instance string
92+
static std::string instance_string()
93+
{
94+
std::ostringstream oss;
95+
96+
// Kernel type name
97+
oss << "GroupedConvolutionForwardKernel";
98+
99+
// Template parameters in exact order matching InstanceTraits member order
100+
oss << "<" << kSpatialDim; // 1. NDimSpatial
101+
oss << ","
102+
<< ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization
103+
oss << "," << detail::layout_name<InLayout>(); // 3. InLayout
104+
oss << "," << detail::layout_name<WeiLayout>(); // 4. WeiLayout
105+
oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout
106+
oss << "," << detail::layout_name<OutLayout>(); // 6. OutLayout
107+
oss << "," << kVectorSizeA; // 7. VectorSizeA
108+
oss << "," << kVectorSizeB; // 8. VectorSizeB
109+
oss << "," << kVectorSizeC; // 9. VectorSizeC
110+
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
111+
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
112+
oss << "," << kMPerBlock; // 12. MPerBlock
113+
oss << "," << kNPerBlock; // 13. NPerBlock
114+
oss << "," << kKPerBlock; // 14. KPerBlock
115+
oss << "," << kMWarp; // 15. MWarp
116+
oss << "," << kNWarp; // 16. NWarp
117+
oss << "," << kKWarp; // 17. KWarp
118+
oss << "," << kMWarpTile; // 18. MWarpTile
119+
oss << "," << kNWarpTile; // 19. NWarpTile
120+
oss << "," << kKWarpTile; // 20. KWarpTile
121+
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
122+
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
123+
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
124+
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
125+
oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups
126+
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
127+
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
128+
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
129+
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
130+
oss << ","
131+
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
132+
// CDEElementwiseOperation
133+
oss << ">";
134+
135+
return oss.str();
136+
}
137+
};
138+
139+
} // namespace reflect
140+
} // namespace ck_tile

experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
2929
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
3030
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
31+
#include <ck_tile/ops/gemm.hpp>
32+
#include "ck_tile/ops/epilogue.hpp"
33+
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
34+
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
3135

3236
namespace ck_tile::reflect::detail {
3337

@@ -38,7 +42,7 @@ namespace impl {
3842
template <typename T>
3943
consteval std::string_view type_name_impl()
4044
{
41-
if constexpr(std::is_same_v<T, ck::half_t>)
45+
if constexpr(std::is_same_v<T, ck::half_t> || std::is_same_v<T, ck_tile::half_t>)
4246
return "fp16";
4347
else if constexpr(std::is_same_v<T, float>)
4448
return "fp32";
@@ -50,11 +54,11 @@ consteval std::string_view type_name_impl()
5054
return "s8";
5155
else if constexpr(std::is_same_v<T, int32_t>)
5256
return "s32";
53-
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
57+
else if constexpr(std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck_tile::bf16_t>)
5458
return "bf16";
55-
else if constexpr(std::is_same_v<T, ck::f8_t>)
59+
else if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck_tile::fp8_t>)
5660
return "fp8";
57-
else if constexpr(std::is_same_v<T, ck::bf8_t>)
61+
else if constexpr(std::is_same_v<T, ck::bf8_t> || std::is_same_v<T, ck_tile::bf8_t>)
5862
return "bf8";
5963
else
6064
return std::string_view{}; // Return empty for unsupported types
@@ -168,6 +172,17 @@ constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineSchedule
168172
}
169173
}
170174

175+
// Convert ck_tile::GemmPipelineScheduler enum to string.
//
// The switch deliberately has no `default:` label so that the compiler's switch-exhaustiveness
// diagnostic keeps flagging enumerators that are added later without a case here. The return
// after the switch handles any value that matches none of the listed enumerators (for example
// an out-of-range value produced by a cast), so control can never flow off the end of this
// value-returning function.
constexpr std::string_view pipeline_scheduler_name(ck_tile::GemmPipelineScheduler sched)
{
    using enum ck_tile::GemmPipelineScheduler;
    switch(sched)
    {
    case Default: return "Default";
    case Intrawave: return "Intrawave";
    case Interwave: return "Interwave";
    }
    return "Unknown";
}
185+
171186
// Convert BlockGemmPipelineVersion enum to string
172187
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
173188
{
@@ -206,6 +221,26 @@ constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
206221
}
207222
}
208223

224+
// Convert TailNumber enum to string
225+
constexpr std::string_view tail_number_name(ck_tile::TailNumber tail_num)
226+
{
227+
using enum ck_tile::TailNumber;
228+
switch(tail_num)
229+
{
230+
case Odd: return "Odd";
231+
case Even: return "Even";
232+
case One: return "One";
233+
case Two: return "Two";
234+
case Three: return "Three";
235+
case Four: return "Four";
236+
case Five: return "Five";
237+
case Six: return "Six";
238+
case Seven: return "Seven";
239+
case Empty: return "Empty";
240+
case Full: return "Full";
241+
}
242+
}
243+
209244
// Convert std::array to string
210245
template <typename T, std::size_t N>
211246
inline std::string array_to_string(const std::array<T, N>& arr)
@@ -356,17 +391,53 @@ constexpr std::string tuple_name()
356391
}(static_cast<T*>(nullptr));
357392
}
358393

394+
template <typename T>
395+
requires requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
396+
constexpr std::string tuple_name()
397+
{
398+
return []<typename... Ts>(ck_tile::tuple<Ts...>*) constexpr {
399+
if constexpr(sizeof...(Ts) == 0)
400+
{
401+
return std::string("EmptyTuple");
402+
}
403+
else if constexpr((IsLayoutType<Ts> && ...))
404+
{
405+
// Lambda wrapper for layout_name
406+
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
407+
return detail::build_list_string<decltype(layout_name_fn), Ts...>("tuple",
408+
layout_name_fn);
409+
}
410+
else if constexpr((IsDataType<Ts> && ...))
411+
{
412+
// Lambda wrapper for type_name
413+
auto type_name_fn = []<typename U>() { return type_name<U>(); };
414+
return detail::build_list_string<decltype(type_name_fn), Ts...>("tuple", type_name_fn);
415+
}
416+
else
417+
{
418+
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
419+
"tuple elements must be all layouts or all data types, not mixed");
420+
return std::string{}; // unreachable
421+
}
422+
}(static_cast<T*>(nullptr));
423+
}
424+
359425
// Concept to check if a type is a ck::Tuple
360426
template <typename T>
361427
concept IsCkTuple =
362428
requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
363429

430+
// Concept to check if a type is a ck_tile::tuple
431+
template <typename T>
432+
concept IsCkTileTuple =
433+
requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
434+
364435
// Deduces whether to use tuple_name or type_name
365436
// Handles scalar data types as well as ck::Tuple and ck_tile::tuple types
366437
template <typename T>
367438
constexpr std::string type_or_type_tuple_name()
368439
{
369-
if constexpr(IsCkTuple<T>)
440+
if constexpr(IsCkTuple<T> || IsCkTileTuple<T>)
370441
{
371442
return tuple_name<T>();
372443
}

0 commit comments

Comments
 (0)