diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index a20171d8b6..aa534d0dd9 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -533,7 +533,7 @@ std::vector get_candidate_configs_sm110( #ifdef FAST_BUILD // Fast build disables all configs except this return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, - MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + MainloopScheduleType::AUTO, EpilogueScheduleType::TMA, ClusterShape::ClusterShape_1x1x1}}; #else std::vector candidate_configs; @@ -574,7 +574,7 @@ std::vector get_candidate_configs_sm110( std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}}; auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1]; for (auto tile : base) { - CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::TMA, cluster}; candidate_configs.push_back(config); } diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 708406aab6..dd0607e897 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -796,7 +796,9 @@ void MoeGemmRunner::dispatchToArch( // cases with small numbers of tokens SM80 is faster. We check here to see which is selected if (inputs.gemm_config.sm_version >= 90) { // Check the major version of the SM matches - TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10, + TLLM_CHECK_WITH_INFO((inputs.gemm_config.sm_version / 10 == sm_ / 10) || + // allow sm100 configs to run on sm110 as well + (inputs.gemm_config.sm_version / 10 == 10 && sm_ / 10 == 11), "Using SM %d configuration for SM %d device", inputs.gemm_config.sm_version, sm_); TLLM_CHECK_WITH_INFO(inputs.biases != nullptr || hopper_inputs.ptr_c == nullptr,