Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
#include <ck_tile/ops/gemm.hpp>
#include "ck_tile/ops/epilogue.hpp"
#include <ck_tile/ops/grouped_convolution.hpp>

namespace ck_tile::reflect::conv {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"

// Forward declaration to avoid circular dependency
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: New copyright format shouldn't contain the year.

// InstanceTraits specialization for GroupedConvolutionForwardKernel
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.

#pragma once

#include "instance_traits.hpp"
#include "instance_traits_util.hpp"

// Forward declaration to avoid circular dependency.
namespace ck_tile::device {

template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel;

} // namespace ck_tile::device

namespace ck_tile {
namespace reflect {

// Specialization for GroupedConvolutionForwardKernel
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct InstanceTraits<ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>>
{
// CK Tile Conv Traits
// Spatial dimension
static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial;
// Specialization
static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
// DataType types
using InLayout = typename GroupedConvTraitsType_::InLayout;
using WeiLayout = typename GroupedConvTraitsType_::WeiLayout;
using DsLayout = typename GroupedConvTraitsType_::DsLayout;
using OutLayout = typename GroupedConvTraitsType_::OutLayout;
// Vector size
static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA;
static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB;
static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC;
// Num Groups To Merge
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;

// TilePartitioner
// Block configuration
static constexpr int kMPerBlock = TilePartitioner_::MPerBlock;
static constexpr int kNPerBlock = TilePartitioner_::NPerBlock;
static constexpr int kKPerBlock = TilePartitioner_::KPerBlock;

static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{});

static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{});
static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{});
static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{});

// Data types
using ADataType = typename GemmPipeline_::ADataType;
using BDataType = typename GemmPipeline_::BDataType;
// Gemm Pipeline
using GemmPipeline = GemmPipeline_;
static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler;
static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer;
static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups;

// Epilogue Pipeline
using AccDataType = typename EpiloguePipeline_::AccDataType;
using EDataType = typename EpiloguePipeline_::ODataType;
using DsDataType = typename EpiloguePipeline_::DsDataType;
using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise;

// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;

// Kernel type name
oss << "GroupedConvolutionForwardKernel";

// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << ","
<< ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization
oss << "," << detail::layout_name<InLayout>(); // 3. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 4. WeiLayout
oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout
oss << "," << detail::layout_name<OutLayout>(); // 6. OutLayout
oss << "," << kVectorSizeA; // 7. VectorSizeA
oss << "," << kVectorSizeB; // 8. VectorSizeB
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
// CDEElementwiseOperation
oss << ">";

return oss.str();
}
};

} // namespace reflect
} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck_tile/ops/gemm.hpp>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"

namespace ck_tile::reflect::detail {

Expand All @@ -38,7 +42,7 @@ namespace impl {
template <typename T>
consteval std::string_view type_name_impl()
{
if constexpr(std::is_same_v<T, ck::half_t>)
if constexpr(std::is_same_v<T, ck::half_t> || std::is_same_v<T, ck_tile::half_t>)
return "fp16";
else if constexpr(std::is_same_v<T, float>)
return "fp32";
Expand All @@ -50,11 +54,11 @@ consteval std::string_view type_name_impl()
return "s8";
else if constexpr(std::is_same_v<T, int32_t>)
return "s32";
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
else if constexpr(std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck_tile::bf16_t>)
return "bf16";
else if constexpr(std::is_same_v<T, ck::f8_t>)
else if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck_tile::fp8_t>)
return "fp8";
else if constexpr(std::is_same_v<T, ck::bf8_t>)
else if constexpr(std::is_same_v<T, ck::bf8_t> || std::is_same_v<T, ck_tile::bf8_t>)
return "bf8";
else
return std::string_view{}; // Return empty for supported types
Expand Down Expand Up @@ -168,6 +172,17 @@ constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineSchedule
}
}

constexpr std::string_view pipeline_scheduler_name(ck_tile::GemmPipelineScheduler sched)
{
using enum ck_tile::GemmPipelineScheduler;
switch(sched)
{
case Default: return "Default";
case Intrawave: return "Intrawave";
case Interwave: return "Interwave";
}
}

// Convert BlockGemmPipelineVersion enum to string
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
{
Expand Down Expand Up @@ -206,6 +221,26 @@ constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
}
}

// Convert TailNumber enum to string
constexpr std::string_view tail_number_name(ck_tile::TailNumber tail_num)
{
using enum ck_tile::TailNumber;
switch(tail_num)
{
case Odd: return "Odd";
case Even: return "Even";
case One: return "One";
case Two: return "Two";
case Three: return "Three";
case Four: return "Four";
case Five: return "Five";
case Six: return "Six";
case Seven: return "Seven";
case Empty: return "Empty";
case Full: return "Full";
}
}

// Convert std::array to string
template <typename T, std::size_t N>
inline std::string array_to_string(const std::array<T, N>& arr)
Expand Down Expand Up @@ -356,17 +391,53 @@ constexpr std::string tuple_name()
}(static_cast<T*>(nullptr));
}

template <typename T>
requires requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
constexpr std::string tuple_name()
{
return []<typename... Ts>(ck_tile::tuple<Ts...>*) constexpr {
if constexpr(sizeof...(Ts) == 0)
{
return std::string("EmptyTuple");
}
else if constexpr((IsLayoutType<Ts> && ...))
{
// Lambda wrapper for layout_name
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
return detail::build_list_string<decltype(layout_name_fn), Ts...>("tuple",
layout_name_fn);
}
else if constexpr((IsDataType<Ts> && ...))
{
// Lambda wrapper for type_name
auto type_name_fn = []<typename U>() { return type_name<U>(); };
return detail::build_list_string<decltype(type_name_fn), Ts...>("tuple", type_name_fn);
}
else
{
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
"tuple elements must be all layouts or all data types, not mixed");
return std::string{}; // unreachable
}
}(static_cast<T*>(nullptr));
}

// Concept to check if a type is a ck::Tuple
template <typename T>
concept IsCkTuple =
requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };

// Concept to check if a type is a ck_tile::tuple
template <typename T>
concept IsCkTileTuple =
requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };

// Deduces whether to use tuple_name or type_name
// Handles both scalar data types and ck::Tuple types
template <typename T>
constexpr std::string type_or_type_tuple_name()
{
if constexpr(IsCkTuple<T>)
if constexpr(IsCkTuple<T> || IsCkTileTuple<T>)
{
return tuple_name<T>();
}
Expand Down
Loading