21#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX
22#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX
29template <
bool Use2LDS,
30 typename GridwiseGemm,
31 bool HasMainKBlockLoop,
36#if CK_USE_LAUNCH_BOUNDS
42#if defined(__gfx950__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
45 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
47 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
49 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
50 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
51 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
52 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
53 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
54 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
63template <
bool Use2LDS,
64 typename GridwiseGemm,
65 bool HasMainKBlockLoop,
70#if CK_USE_LAUNCH_BOUNDS
76#if defined(__gfx950__)
77 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
81 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
82 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
84 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
86 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
87 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
88 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
89 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
90 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
91 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
102template <
typename ALayout,
106 typename AScaleDataType,
108 typename BScaleDataType,
109 typename AccDataType,
110 typename CShuffleDataType,
112 typename AElementwiseOperation,
113 typename BElementwiseOperation,
114 typename CElementwiseOperation,
127 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
128 typename ABlockTransferThreadClusterArrangeOrder,
129 typename ABlockTransferSrcAccessOrder,
130 index_t ABlockTransferSrcVectorDim,
131 index_t ABlockTransferSrcScalarPerVector,
132 index_t ABlockTransferDstScalarPerVector_AK1,
133 bool AThreadTransferSrcResetCoordinateAfterRun,
135 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
136 typename BBlockTransferThreadClusterArrangeOrder,
137 typename BBlockTransferSrcAccessOrder,
138 index_t BBlockTransferSrcVectorDim,
139 index_t BBlockTransferSrcScalarPerVector,
140 index_t BBlockTransferDstScalarPerVector_BK1,
141 bool BThreadTransferSrcResetCoordinateAfterRun,
143 index_t CShuffleMXdlPerWavePerShuffle,
144 index_t CShuffleNXdlPerWavePerShuffle,
145 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
146 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
149 typename ComputeTypeA =
151 typename ComputeTypeB =
153 bool PermuteA =
false,
154 bool PermuteB =
false>
226 auto K_t = K_Batch * KPerBlock;
227 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
232 auto K_t = K_Batch * KPerBlock;
233 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
238 auto K_t = K_Batch * KPerBlock;
239 return (K + K_t - 1) / K_t * KPerBlock;
245 auto K_t = K_Batch * KReadVec;
246 return (K + K_t - 1) / K_t * KReadVec;
259 template <
index_t MNXdlPerWave,
263 typename TileDesc_K0_MN_K1>
291 const auto a_grid_desc_mraw_kraw = [&]() {
304 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
305 GemmSpec == GemmSpecialization::MNKPadding)
308 const auto a_grid_desc_m_k =
322 return a_grid_desc_ak0_m_ak1;
324 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
325 GemmSpec == GemmSpecialization::MNPadding)
329 a_grid_desc_mraw_kraw,
336 a_grid_desc_ak0_m_ak1,
344 a_grid_desc_permuted,
353 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
354 GemmSpec == GemmSpecialization::NKPadding)
358 a_grid_desc_mraw_kraw,
370 return a_grid_desc_ak0_m_ak1;
376 a_grid_desc_mraw_kraw,
383 a_grid_desc_ak0_m_ak1,
391 a_grid_desc_permuted,
406 const auto b_grid_desc_nraw_kraw = [&]() {
420 GemmSpec != GemmSpecialization::Default),
421 "pk_i4_t does not support padding");
423 (GemmSpec != GemmSpecialization::Default &&
424 GemmSpec != GemmSpecialization::MPadding)),
425 "f4x2_pk_t does not support K padding");
430 GemmSpec != GemmSpecialization::Default),
431 "Packed F6 types do not support padding");
433 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
434 GemmSpec == GemmSpecialization::MNKPadding)
437 const auto b_grid_desc_n_k =
451 return b_grid_desc_bk0_n_bk1;
453 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
454 GemmSpec == GemmSpecialization::MNPadding)
458 b_grid_desc_nraw_kraw,
464 return b_grid_desc_bk0_n_bk1;
466 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
467 GemmSpec == GemmSpecialization::MKPadding)
471 b_grid_desc_nraw_kraw,
483 return b_grid_desc_bk0_n_bk1;
487 if constexpr(!PermuteB)
491 b_grid_desc_nraw_kraw,
499 b_grid_desc_bk0_n_bk1,
507 b_grid_desc_permuted,
520 constexpr index_t BK01 = KPerBlock / BK1Value;
522 const index_t BK0_ = StrideB / BK1Value;
523 const index_t BK00 = BK0_ / BK01;
525 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
529 b_grid_desc_bk00_n_bk01_bk1_permute,
536 return b_grid_desc_bk0_n_bk1_permute;
541 template <
typename ABlockDesc_AK0_M_AK1>
542 __host__ __device__
static constexpr auto
545 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
548 ABlockDesc_AK0_M_AK1{});
551 template <
typename BBlockDesc_BK0_N_BK1>
552 __host__ __device__
static constexpr auto
555 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
558 BBlockDesc_BK0_N_BK1{});
561 __host__ __device__
static auto
564 const auto c_grid_desc_mraw_nraw = [&]() {
584 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
585 GemmSpec == GemmSpecialization::MNKPadding)
594 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
595 GemmSpec == GemmSpecialization::MKPadding)
599 c_grid_desc_mraw_nraw,
604 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
605 GemmSpec == GemmSpecialization::NKPadding)
609 c_grid_desc_mraw_nraw,
617 return c_grid_desc_mraw_nraw;
655 std::cout <<
"problem {" <<
"M:" <<
M <<
", " <<
"N:" <<
N <<
", " <<
"K:" <<
K <<
", "
659 <<
", " <<
"KRead:" <<
KRead <<
", " <<
"KP:" <<
KPadded <<
", "
660 <<
"AK0:" <<
AK0 <<
", " <<
"BK0:" <<
BK0 <<
", " <<
"MBlock: " <<
MBlock
661 <<
", " <<
"NBlock: " <<
NBlock <<
"}" << std::endl;
687 const AScaleDataType* p_a_scale_grid_,
688 const BDataType* p_b_grid_,
689 const BScaleDataType* p_b_scale_grid_,
690 CDataType* p_c_grid_,
700 AElementwiseOperation a_element_op_,
701 BElementwiseOperation b_element_op_,
702 CElementwiseOperation c_element_op_,
703 bool is_reduce_ =
false)
767 if constexpr(!PermuteB)
773 const int k0_offset = karg.
KRead * karg.
N;
786 if(k_id < (karg.
KBatch - 1))
814 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
815 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
816 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
830 constexpr auto a_lds_block_desc =
842 return a_lds_block_desc_permuted;
849 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
850 constexpr auto M1 = MPerBlock / M0;
852 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
853 constexpr auto K0PerThreadWrite =
AK0Number / KThreadWrite;
854 constexpr auto KThreadRead = WaveSize / MPerXdl;
855 constexpr auto K0PerThreadRead =
AK0Number / KThreadRead;
857 constexpr auto kfold = (
AK1Number * M0 *
sizeof(ADataType) > 128)
859 : 128 / (
AK1Number * M0 *
sizeof(ADataType));
860 constexpr auto KThreadReadPerm =
861 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
862 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
866 constexpr auto mpair = (
AK1Number * MPerXdl *
sizeof(ADataType) > 128)
868 : ((128 / (
AK1Number * MPerXdl *
sizeof(ADataType))) > M0
870 : 128 / (
AK1Number * MPerXdl *
sizeof(ADataType)));
876 Number<kfold * M0 / mpair>{},
895 a_lds_block_desc_permuted,
917 a_lds_block_desc_unmerged,
920 Number<KThreadWrite / kfold / KThreadReadPerm>{},
929 return a_lds_block_desc_ak0_m_ak1;
935 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
936 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
937 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
949 constexpr auto b_lds_block_desc =
961 return b_lds_block_desc_permuted;
965 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
966 constexpr auto N1 = NPerBlock / N0;
968 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
969 constexpr auto K0PerThreadWrite =
BK0Number / KThreadWrite;
970 constexpr auto KThreadRead = WaveSize / NPerXdl;
971 constexpr auto K0PerThreadRead =
BK0Number / KThreadRead;
973 constexpr auto kfold = (
BK1Number * N0 *
sizeof(BDataType) > 128)
975 : 128 / (
BK1Number * N0 *
sizeof(BDataType));
976 constexpr auto KThreadReadPerm =
977 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
978 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
982 constexpr auto npair = (
BK1Number * NPerXdl *
sizeof(BDataType) > 128)
984 : ((128 / (
BK1Number * NPerXdl *
sizeof(BDataType))) > N0
986 : 128 / (
BK1Number * NPerXdl *
sizeof(BDataType)));
992 Number<kfold * N0 / npair>{},
1011 b_lds_block_desc_permuted,
1033 b_lds_block_desc_unmerged,
1036 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1045 return b_lds_block_desc_bk0_n_bk1;
1051 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1052 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1054 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1061 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1082 ABlockTransferSrcScalarPerVector,
1083 BBlockTransferSrcScalarPerVector,
1103 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1106 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1109 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1112 constexpr auto c_block_size =
1113 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1115 return math::max((a_block_space_size_aligned *
sizeof(ADataType) +
1116 b_block_space_size_aligned *
sizeof(BDataType)),
1117 c_block_size *
sizeof(CShuffleDataType));
1125 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1126 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1127 "Invalid tuning param!");
1129 static_assert(KPerBlock % (ScaleBlockSize /
BPackedSize) == 0,
1130 "KPerBlock should be multiple of ScaleBlockSize");
1138 if(!(karg.M % MPerBlock == 0))
1142 std::cout <<
"Arg M value is not a multiple of MPerBlock! M: " << karg.M <<
" "
1143 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1156 if(!(karg.N % NPerBlock == 0))
1160 std::cout <<
"Arg N value is not a multiple of NPerBlock! N: " << karg.N <<
" "
1161 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1173 auto K_t = karg.KBatch * KPerBlock;
1174 if(!(karg.K % K_t == 0))
1178 std::cout <<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1179 << karg.K <<
" " << __FILE__ <<
":" << __LINE__
1180 <<
", in function: " << __func__ << std::endl;
1188 auto K_t = karg.KBatch * KReadVec;
1190 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1198 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1202 std::cout <<
"Arg K (" << karg.K
1203 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1204 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1205 << __LINE__ <<
", in function: " << __func__ << std::endl;
1212 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1216 std::cout <<
"Arg M (" << karg.M
1217 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1218 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1219 << __LINE__ <<
", in function: " << __func__ << std::endl;
1227 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1231 std::cout <<
"Arg N (" << karg.N
1232 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1233 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1234 << __LINE__ <<
", in function: " << __func__ << std::endl;
1241 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1245 std::cout <<
"Arg K (" << karg.K
1246 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1247 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1248 << __LINE__ <<
", in function: " << __func__ << std::endl;
1256 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1260 std::cout <<
"Arg N (" << karg.N
1261 <<
") value is not a multiple of "
1262 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1263 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1264 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1272 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1276 std::cout <<
"Arg M (" << karg.M
1277 <<
") value is not a multiple of "
1278 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1279 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1280 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1292 if(!karg.IsReduceAdd())
1296 std::cout <<
" KBatch: " << karg.KBatch <<
" > 1 is not support yet" << __FILE__
1297 <<
":" << __LINE__ <<
", in function: " << __func__ << std::endl;
1307 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1311 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1323 const index_t num_loop = K / KPerBlock;
1325 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1330 const index_t num_loop = K / KPerBlock;
1332 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1335 template <
typename CGr
idDesc>
1337 const CGridDesc& c_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1346 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1358 "A scale pack data type too large!");
1360 "B scale pack data type too large!");
1364 "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
1366 template <
typename AGridDesc_AK0_M_K1,
1367 typename AScaleGridDesc_AM_AK,
1368 typename BGridDesc_BK0_N_K1,
1369 typename BScaleGridDesc_BN_AK,
1370 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1371 bool HasMainKBlockLoop,
1374 __device__
static void Run(
const ADataType* p_a_grid,
1375 const AScaleDataType* p_a_scale_grid,
1376 const BDataType* p_b_grid,
1377 const BScaleDataType* p_b_scale_grid,
1378 CDataType* p_c_grid,
1380 const Problem& problem,
1381 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1382 const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak,
1383 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1384 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1385 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1386 c_grid_desc_mblock_mperblock_nblock_nperblock)
1389 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1391 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1393 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1397 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1401 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1403 const CElementwiseOperation c_element_op{};
1406 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1408 const auto block_work_idx =
1411 if(!block_2_ctile_map.ValidCTileIndex(
1413 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1414 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1419 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1420 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1423 const index_t m_block_data_idx_on_grid =
1424 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1426 const index_t n_block_data_idx_on_grid =
1427 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1438 auto a_blockwise_copy =
1441 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1442 ABlockTransferThreadClusterArrangeOrder,
1445 decltype(a_grid_desc_ak0_m_ak1),
1446 decltype(a_block_desc_ak0_m_ak1),
1447 ABlockTransferSrcAccessOrder,
1448 ABlockTransferSrcVectorDim,
1450 ABlockTransferSrcScalarPerVector>(
1451 a_grid_desc_ak0_m_ak1,
1453 a_block_desc_ak0_m_ak1,
1457 auto b_blockwise_copy =
1460 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1461 BBlockTransferThreadClusterArrangeOrder,
1464 decltype(b_grid_desc_bk0_n_bk1),
1465 decltype(b_block_desc_bk0_n_bk1),
1466 BBlockTransferSrcAccessOrder,
1467 BBlockTransferSrcVectorDim,
1469 BBlockTransferSrcScalarPerVector>(
1470 b_grid_desc_bk0_n_bk1,
1472 b_block_desc_bk0_n_bk1,
1477 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1481 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1484 reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) +
1485 a_block_space_size_aligned *
sizeof(ADataType)),
1486 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1492 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1494 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1496 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1497 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1517 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1518 const auto waveId_m = wave_idx[
I0];
1519 const auto waveId_n = wave_idx[
I1];
1527 auto thread_offset_shuffled =
1530 auto a_thread_offset_m = waveId_m;
1535 decltype(a_scale_grid_desc_am_ak),
1536 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1542 true>(a_scale_grid_desc_am_ak,
1547 auto b_thread_offset_n = waveId_n;
1552 decltype(b_scale_grid_desc_bn_ak),
1553 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1559 true>(b_scale_grid_desc_bn_ak,
1565 a_block_desc_ak0_m_ak1,
1569 a_block_slice_copy_step,
1570 b_grid_desc_bk0_n_bk1,
1571 b_block_desc_bk0_n_bk1,
1575 b_block_slice_copy_step,
1577 a_scale_grid_desc_am_ak,
1578 a_scale_thread_copy,
1580 b_scale_grid_desc_bn_ak,
1581 b_scale_thread_copy,
1583 num_k_block_main_loop);
1587 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1588 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1590 static_assert(CShuffleMXdlPerWavePerShuffle %
MXdlPack == 0 &&
1591 CShuffleNXdlPerWavePerShuffle %
NXdlPack == 0,
1594 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1595 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1598 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1599 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1603 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1604 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1606 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1607 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1608 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1609 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1610 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1611 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1612 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1613 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1614 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I8);
1615 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I9);
1617 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1621 static_cast<CShuffleDataType*
>(p_shared),
1622 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1625 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1651 const auto c_thread_mtx_on_block =
1652 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1654 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1655 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1657 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1663 const auto m_thread_data_on_block_idx =
1664 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1667 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1673 const auto n_thread_data_on_block_idx =
1674 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1681 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1682 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1685 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
1694 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1699 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1702 m_thread_data_on_block_idx[
I1],
1703 n_thread_data_on_block_idx[
I1],
1704 m_thread_data_on_block_idx[
I2],
1705 n_thread_data_on_block_idx[
I2],
1706 m_thread_data_on_block_idx[
I3],
1707 m_thread_data_on_block_idx[
I4],
1708 m_thread_data_on_block_idx[
I5],
1709 n_thread_data_on_block_idx[
I3]),
1715 CElementwiseOperation,
1716 CGlobalMemoryDataOperation,
1718 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1720 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1721 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1725 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1726 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1729 CShuffleBlockTransferScalarPerVector_NPerBlock,
1732 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1734 c_grid_desc_mblock_mperblock_nblock_nperblock,
1739 constexpr auto sfc_c_vgpr =
1750 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1752 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
1763 constexpr auto sfc_c_global =
1767 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1769 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1771 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1773 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
1780 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1781 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1783 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1784 c_shuffle_block_buf);
1790 c_shuffle_block_copy_lds_to_global.Run(
1791 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1792 c_shuffle_block_buf,
1793 c_grid_desc_mblock_mperblock_nblock_nperblock,
1796 if constexpr(access_id < num_access - 1)
1798 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1801 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1802 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1808 template <
bool HasMainKBlockLoop,
1811 __device__
static void Run(
const ADataType* p_a_grid,
1812 const AScaleDataType* p_a_scale_grid,
1813 const BDataType* p_b_grid,
1814 const BScaleDataType* p_b_scale_grid,
1815 CDataType* p_c_grid,
1817 const Problem& problem)
1820 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1822 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1824 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1825 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1827 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1831 const auto Padded_Scale_M =
1855 Run<
decltype(a_grid_desc_ak0_m_ak1),
1856 decltype(a_scale_grid_desc_am_ak),
1857 decltype(b_grid_desc_bk0_n_bk1),
1858 decltype(b_scale_grid_desc_bn_ak),
1859 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1861 CGlobalMemoryDataOperation,
1869 a_grid_desc_ak0_m_ak1,
1870 a_scale_grid_desc_am_ak,
1871 b_grid_desc_bk0_n_bk1,
1872 b_scale_grid_desc_bn_ak,
1873 c_grid_desc_mblock_mperblock_nblock_nperblock);
1876 template <
typename AGridDesc_AK0_M_K1,
1877 typename AScaleGridDesc_AM_AK,
1878 typename BGridDesc_BK0_N_K1,
1879 typename BScaleGridDesc_BN_AK,
1880 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1881 bool HasMainKBlockLoop,
1884 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
1885 const AScaleDataType* p_a_scale_grid,
1886 const BDataType* p_b_grid,
1887 const BScaleDataType* p_b_scale_grid,
1888 CDataType* p_c_grid,
1891 const Problem& problem,
1892 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1893 const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak,
1894 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1895 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1896 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1897 c_grid_desc_mblock_mperblock_nblock_nperblock)
1900 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1902 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1904 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1908 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1912 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1914 const CElementwiseOperation c_element_op{};
1917 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1919 const auto block_work_idx =
1922 if(!block_2_ctile_map.ValidCTileIndex(
1924 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1925 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1930 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1931 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1934 const index_t m_block_data_idx_on_grid =
1935 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1937 const index_t n_block_data_idx_on_grid =
1938 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1949 auto a_blockwise_copy =
1952 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1953 ABlockTransferThreadClusterArrangeOrder,
1956 decltype(a_grid_desc_ak0_m_ak1),
1957 decltype(a_block_desc_ak0_m_ak1),
1958 ABlockTransferSrcAccessOrder,
1959 ABlockTransferSrcVectorDim,
1961 ABlockTransferSrcScalarPerVector>(
1962 a_grid_desc_ak0_m_ak1,
1964 a_block_desc_ak0_m_ak1,
1968 auto b_blockwise_copy =
1971 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1972 BBlockTransferThreadClusterArrangeOrder,
1975 decltype(b_grid_desc_bk0_n_bk1),
1976 decltype(b_block_desc_bk0_n_bk1),
1977 BBlockTransferSrcAccessOrder,
1978 BBlockTransferSrcVectorDim,
1980 BBlockTransferSrcScalarPerVector>(
1981 b_grid_desc_bk0_n_bk1,
1983 b_block_desc_bk0_n_bk1,
1988 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1991 static_cast<ADataType*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1995 a_block_space_size_aligned *
sizeof(ADataType)),
1996 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1999 static_cast<ADataType*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2003 a_block_space_size_aligned *
sizeof(ADataType)),
2004 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2006 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
2007 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
2013 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2015 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2017 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2018 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
2038 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2039 const auto waveId_m = wave_idx[
I0];
2040 const auto waveId_n = wave_idx[
I1];
2048 auto thread_offset_shuffled =
2051 auto a_thread_offset_m = waveId_m;
2056 decltype(a_scale_grid_desc_am_ak),
2057 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2063 true>(a_scale_grid_desc_am_ak,
2068 auto b_thread_offset_n = waveId_n;
2073 decltype(b_scale_grid_desc_bn_ak),
2074 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2080 true>(b_scale_grid_desc_bn_ak,
2086 a_block_desc_ak0_m_ak1,
2090 a_block_slice_copy_step,
2091 b_grid_desc_bk0_n_bk1,
2092 b_block_desc_bk0_n_bk1,
2096 b_block_slice_copy_step,
2098 a_scale_grid_desc_am_ak,
2099 a_scale_thread_copy,
2101 b_scale_grid_desc_bn_ak,
2102 b_scale_thread_copy,
2104 num_k_block_main_loop);
2108 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2109 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2111 static_assert(CShuffleMXdlPerWavePerShuffle %
MXdlPack == 0 &&
2112 CShuffleNXdlPerWavePerShuffle %
NXdlPack == 0,
2115 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2116 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2119 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2120 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2124 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2125 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2127 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2128 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2129 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2130 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2131 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2132 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2133 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2134 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2135 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I8);
2136 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I9);
2138 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2142 static_cast<CShuffleDataType*
>(p_shared_0),
2143 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2146 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2172 const auto c_thread_mtx_on_block =
2173 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2175 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2176 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2178 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2184 const auto m_thread_data_on_block_idx =
2185 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2188 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2194 const auto n_thread_data_on_block_idx =
2195 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2202 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2203 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2206 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
2215 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2220 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2223 m_thread_data_on_block_idx[
I1],
2224 n_thread_data_on_block_idx[
I1],
2225 m_thread_data_on_block_idx[
I2],
2226 n_thread_data_on_block_idx[
I2],
2227 m_thread_data_on_block_idx[
I3],
2228 m_thread_data_on_block_idx[
I4],
2229 m_thread_data_on_block_idx[
I5],
2230 n_thread_data_on_block_idx[
I3]),
2236 CElementwiseOperation,
2237 CGlobalMemoryDataOperation,
2239 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2241 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2242 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2246 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2247 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2250 CShuffleBlockTransferScalarPerVector_NPerBlock,
2253 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2255 c_grid_desc_mblock_mperblock_nblock_nperblock,
2260 constexpr auto sfc_c_vgpr =
2271 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2273 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
2284 constexpr auto sfc_c_global =
2288 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2290 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2292 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2294 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
2301 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2302 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2304 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2305 c_shuffle_block_buf);
2311 c_shuffle_block_copy_lds_to_global.Run(
2312 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2313 c_shuffle_block_buf,
2314 c_grid_desc_mblock_mperblock_nblock_nperblock,
2317 if constexpr(access_id < num_access - 1)
2319 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2322 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2323 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2329 template <
bool HasMainKBlockLoop,
2332 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
2333 const AScaleDataType* p_a_scale_grid,
2334 const BDataType* p_b_grid,
2335 const BScaleDataType* p_b_scale_grid,
2336 CDataType* p_c_grid,
2339 const Problem& problem)
2342 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2344 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2346 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2347 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2349 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2353 const auto Padded_Scale_M =
2377 Run_2Lds<
decltype(a_grid_desc_ak0_m_ak1),
2378 decltype(a_scale_grid_desc_am_ak),
2379 decltype(b_grid_desc_bk0_n_bk1),
2380 decltype(b_scale_grid_desc_bn_ak),
2381 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2383 CGlobalMemoryDataOperation,
2392 a_grid_desc_ak0_m_ak1,
2393 a_scale_grid_desc_am_ak,
2394 b_grid_desc_bk0_n_bk1,
2395 b_scale_grid_desc_bn_ak,
2396 c_grid_desc_mblock_mperblock_nblock_nperblock);
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
f6_pk_t< f6_t, 16 > f6x16_pk_t
Definition data_type.hpp:180
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
constexpr auto BlockGemmMXPipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp:36
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
f6_pk_t< bf6_t, 32 > bf6x32_pk_t
Definition data_type.hpp:183
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
f6_pk_t< f6_t, 32 > f6x32_pk_t
Definition data_type.hpp:181
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__global__ enable_if_t<!Use2LDS, void > kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:40
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
f6_pk_t< bf6_t, 16 > bf6x16_pk_t
Definition data_type.hpp:182
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:685
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:741
__host__ Argument(const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:686
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:730
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:725
const AScaleDataType * p_a_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:736
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:739
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:743
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:744
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:735
const BScaleDataType * p_b_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:738
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:737
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:742
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:653
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:669
index_t StrideScaleA
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:668
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:676
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:667
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:664
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:673
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:680
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:665
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:674
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:624
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:678
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:666
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:675
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:679
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:671
index_t StrideScaleB
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:670
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:672
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:677
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:806
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:750
index_t a_scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:807
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:805
index_t b_scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:808
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:809
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:156
static __device__ void Run(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1811
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1321
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::ThisThreadBlock ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:202
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::BPackedSize static constexpr index_t BPackedSize
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:190
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:403
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:230
static __device__ void Run(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const AScaleGridDesc_AM_AK &a_scale_grid_desc_am_ak, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1374
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1123
static __device__ void Run_2Lds(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:2332
static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:812
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::BK1Number static constexpr auto BK1Number
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:173
__host__ static __device__ constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:264
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:224
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:204
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::is_single_rate_mfma static constexpr bool is_single_rate_mfma
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:176
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:254
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:214
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I0 static constexpr auto I0
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:158
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I8 static constexpr auto I8
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:166
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I3 static constexpr auto I3
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:161
static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1049
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::MXdlPack static constexpr auto MXdlPack
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:179
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1336
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::KXdlPack static constexpr auto KXdlPack
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:181
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:219
static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1093
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:242
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::BK0Number static constexpr auto BK0Number
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:171
static __device__ void Run_2Lds(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const AScaleGridDesc_AM_AK &a_scale_grid_desc_am_ak, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1884
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::AK0Number static constexpr auto AK0Number
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:170
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::scale_pack_size_b static constexpr index_t scale_pack_size_b
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1356
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I7 static constexpr auto I7
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:165
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::mx_scale_t e8m0_bexp_t mx_scale_t
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1354
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I2 static constexpr auto I2
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:160
__host__ static __device__ constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:543
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::lcm_AK1_BK1 static constexpr auto lcm_AK1_BK1
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:175
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::BlockwiseGemmPipe remove_cvref_t< decltype(BlockGemmMXPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1064
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:562
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I1 static constexpr auto I1
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:159
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:288
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I6 static constexpr auto I6
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:164
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::scale_pack_size_a static constexpr index_t scale_pack_size_a
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1355
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:249
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::APackedSize static constexpr index_t APackedSize
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:189
__host__ static __device__ constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:553
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I9 static constexpr auto I9
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:167
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I4 static constexpr auto I4
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:162
static __host__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1328
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::NXdlPack static constexpr auto NXdlPack
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:180
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::AK1Number static constexpr auto AK1Number
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:172
static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:933
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:209
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I5 static constexpr auto I5
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:163
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:236
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::Block2CTileMap BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1351
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::KPack static constexpr index_t KPack
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:192
ck::GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::is_scale_mfma static constexpr auto is_scale_mfma
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:177
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition data_type.hpp:42
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129