40 MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
42 NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
45 MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
47 NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
50 WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
52 WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
55 MPerBlock * NPerBlock * KPerBlock / (BlockSize /
WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
59 printf(
" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
69 printf(
" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
70 "%d, %d\n C MFMA inst: %d\n",
87 typename AMmaTileDesc,
88 typename BMmaTileDesc,
97 bool TransposeC =
false,
99 KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
143 "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
163 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
170 const auto waveId_m = wave_idx[
I0];
172 const auto xdlops_a_idx =
xdlops_gemm.CalculateAThreadOriginDataIndex();
174 return make_tuple(0, waveId_m, xdlops_a_idx[
I1], KPack * xdlops_a_idx[
I0]);
181 const auto waveId_n = wave_idx[
I1];
183 const auto xdlops_b_idx =
xdlops_gemm.CalculateBThreadOriginDataIndex();
185 return make_tuple(0, waveId_n, xdlops_b_idx[
I1], KPack * xdlops_b_idx[
I0]);
188 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
189 __device__
static auto
194 const auto waveId_m = wave_idx[
I0];
195 const auto waveId_n = wave_idx[
I1];
197 const auto blk_idx =
xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
209 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
211 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
217 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
218 __device__
static auto
223 const auto waveId_m = wave_idx[
I0];
224 const auto waveId_n = wave_idx[
I1];
226 const auto blk_idx =
xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
229 m0, n0, waveId_m, waveId_n, blk_idx[
I0], blk_idx[
I1], blk_idx[
I2], blk_idx[
I3]);
239#if defined(__HIP_DEVICE_COMPILE__)
240 static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
241 "wrong! Desc should be known at compile-time");
244 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
246 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
255 constexpr auto c_m0_m1_m2_n_tblk_lens =
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
257 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
258 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
259 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
260 constexpr auto N = c_m0_m1_m2_n_tblk_lens[
I3];
269 constexpr auto c_m0_m1_m2_n_tblk_lens =
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
271 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
272 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
273 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
274 constexpr auto N = c_m0_m1_m2_n_tblk_lens[
I3];
282 constexpr auto c_m0_m1_m2_n_tblk_lens =
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
284 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
285 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
286 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
287 constexpr auto N = c_m0_m1_m2_n_tblk_lens[
I3];
296 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
304 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
310 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
318 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
323 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
332 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
333 c_block_desc_g_m0_n0_m1_n1_m2_n2);
336 template <
typename CGr
idDesc_M_N>
337 __host__ __device__
static constexpr auto
340 const auto M = c_grid_desc_m_n.GetLength(
I0);
341 const auto N = c_grid_desc_m_n.GetLength(
I1);
350 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
353 template <
typename CGr
idDesc_G_M_N>
354 __host__ __device__
static constexpr auto
357 const auto G = c_grid_desc_g_m_n.GetLength(
I0);
358 const auto M = c_grid_desc_g_m_n.GetLength(
I1);
359 const auto N = c_grid_desc_g_m_n.GetLength(
I2);
369 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
370 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
376 constexpr auto num_ds_read_inst =
378 constexpr auto num_ds_write_inst =
381 constexpr auto num_buffer_load_inst =
386 constexpr auto num_issue = num_buffer_load_inst;
390 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
391 __builtin_amdgcn_sched_group_barrier(
392 0x100, num_ds_read_inst / num_buffer_load_inst, 0);
393 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
394 __builtin_amdgcn_sched_group_barrier(
395 0x200, num_ds_write_inst / num_buffer_load_inst, 0);
396 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
397 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
398 __builtin_amdgcn_sched_group_barrier(
399 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0);
403 template <index_t stage>
412 constexpr auto num_ds_read_inst =
414 constexpr auto num_ds_write_inst =
419 constexpr auto num_issue = num_ds_write_inst;
423 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
424 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
425 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
426 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
427 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
428 __builtin_amdgcn_sched_group_barrier(
429 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0);
430 __builtin_amdgcn_sched_group_barrier(
431 0x008, num_mfma_inst / num_ds_write_inst - 3, 0);
439 constexpr auto num_ds_read_inst =
443 constexpr auto num_issue = num_ds_read_inst;
447 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
448 __builtin_amdgcn_sched_group_barrier(
449 0x008, num_mfma_inst / num_ds_read_inst, 0);
456 template <
bool HasMainLoop,
460 typename ABlockTransfer,
461 typename AGridBuffer,
462 typename ABlockBuffer,
463 typename ABlockTransferStep,
466 typename BBlockTransfer,
467 typename BGridBuffer,
468 typename BBlockBuffer,
469 typename BBlockTransferStep,
470 typename CThreadBuffer>
471 __device__
void Run(
const AGridDesc& a_grid_desc,
472 const ABlockDesc& a_block_desc,
473 ABlockTransfer& a_blockwise_copy,
474 const AGridBuffer& a_grid_buf,
475 ABlockBuffer& a_block_buf,
476 const ABlockTransferStep& a_block_copy_step,
477 const BGridDesc& b_grid_desc,
478 const BBlockDesc& b_block_desc,
479 BBlockTransfer& b_blockwise_copy,
480 const BGridBuffer& b_grid_buf,
481 BBlockBuffer& b_block_buf,
482 const BBlockTransferStep& b_block_copy_step,
483 CThreadBuffer& c_thread_buf,
486 __builtin_amdgcn_sched_barrier(0);
502 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
503 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
505 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
506 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
508 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I0));
509 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I0));
533 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
534 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
536 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
537 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
539 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I1));
540 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I1));
543 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
544 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
546 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
547 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
550 c_thread_buf.Clear();
553 if constexpr(HasMainLoop)
571 a_block_buf.At(PongP1{}),
574 a_thread_bufs(PongP1{}));
578 b_block_buf.At(PongP1{}),
581 b_thread_bufs(PongP1{}));
586 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
587 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
589 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
590 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
592 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
593 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
602 a_thread_vec.template AsType<FloatAB>()(ik) =
605 b_thread_vec.template AsType<FloatAB>()(ik) =
610 using mfma_input_type =
617 a_thread_vec.template AsType<mfma_input_type>(),
618 b_thread_vec.template AsType<mfma_input_type>(),
625 __builtin_amdgcn_sched_barrier(0);
639 a_block_buf.At(PongP2{}),
642 a_thread_bufs(PongP2{}));
646 b_block_buf.At(PongP2{}),
649 b_thread_bufs(PongP2{}));
654 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
655 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
657 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
658 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
660 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
661 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
670 a_thread_vec.template AsType<FloatAB>()(ik) =
673 b_thread_vec.template AsType<FloatAB>()(ik) =
678 using mfma_input_type =
685 a_thread_vec.template AsType<mfma_input_type>(),
686 b_thread_vec.template AsType<mfma_input_type>(),
693 __builtin_amdgcn_sched_barrier(0);
696 }
while(i < (num_loop - 3));
700 if constexpr(TailNum == 3)
713 a_block_buf.At(PongP1{}),
716 a_thread_bufs(PongP1{}));
720 b_block_buf.At(PongP1{}),
723 b_thread_bufs(PongP1{}));
728 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
729 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
738 a_thread_vec.template AsType<FloatAB>()(ik) =
741 b_thread_vec.template AsType<FloatAB>()(ik) =
746 using mfma_input_type =
752 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
753 b_thread_vec.template AsType<mfma_input_type>(),
760 __builtin_amdgcn_sched_barrier(0);
774 a_block_buf.At(PongP2{}),
777 a_thread_bufs(PongP2{}));
781 b_block_buf.At(PongP2{}),
784 b_thread_bufs(PongP2{}));
796 a_thread_vec.template AsType<FloatAB>()(ik) =
799 b_thread_vec.template AsType<FloatAB>()(ik) =
804 using mfma_input_type =
810 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
811 b_thread_vec.template AsType<mfma_input_type>(),
818 __builtin_amdgcn_sched_barrier(0);
827 a_thread_vec.template AsType<FloatAB>()(ik) =
830 b_thread_vec.template AsType<FloatAB>()(ik) =
835 using mfma_input_type =
841 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
842 b_thread_vec.template AsType<mfma_input_type>(),
849 __builtin_amdgcn_sched_group_barrier(0x008, 64, 0);
850 __builtin_amdgcn_sched_barrier(0);
852 else if constexpr(TailNum == 2)
865 a_block_buf.At(PongP1{}),
868 a_thread_bufs(PongP1{}));
872 b_block_buf.At(PongP1{}),
875 b_thread_bufs(PongP1{}));
887 a_thread_vec.template AsType<FloatAB>()(ik) =
890 b_thread_vec.template AsType<FloatAB>()(ik) =
895 using mfma_input_type =
901 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
902 b_thread_vec.template AsType<mfma_input_type>(),
909 __builtin_amdgcn_sched_barrier(0);
924 a_thread_vec.template AsType<FloatAB>()(ik) =
927 b_thread_vec.template AsType<FloatAB>()(ik) =
932 using mfma_input_type =
938 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
939 b_thread_vec.template AsType<mfma_input_type>(),
946 __builtin_amdgcn_sched_group_barrier(0x008, 64, 0);
947 __builtin_amdgcn_sched_barrier(0);
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition blockwise_gemm_pipeline_xdlops.hpp:34
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::B_LDS_Write_Inst_Num static constexpr index_t B_LDS_Write_Inst_Num
Definition blockwise_gemm_pipeline_xdlops.hpp:46
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::A_LDS_Read_Inst_Num static constexpr index_t A_LDS_Read_Inst_Num
Definition blockwise_gemm_pipeline_xdlops.hpp:49
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::B_LDS_Read_Inst_Num static constexpr index_t B_LDS_Read_Inst_Num
Definition blockwise_gemm_pipeline_xdlops.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::A_LDS_Write_Inst_Num static constexpr index_t A_LDS_Write_Inst_Num
Definition blockwise_gemm_pipeline_xdlops.hpp:44
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::C_MFMA_Inst_Num static constexpr index_t C_MFMA_Inst_Num
Definition blockwise_gemm_pipeline_xdlops.hpp:54
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::A_Buffer_Load_Inst_Num static constexpr index_t A_Buffer_Load_Inst_Num
Definition blockwise_gemm_pipeline_xdlops.hpp:39
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::WaveSize static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops.hpp:37
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::B_Buffer_Load_Inst_Num static constexpr index_t B_Buffer_Load_Inst_Num
Definition blockwise_gemm_pipeline_xdlops.hpp:41
static constexpr auto Print()
Definition blockwise_gemm_pipeline_xdlops.hpp:57
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::WaveNumN static constexpr index_t WaveNumN
Definition blockwise_gemm_pipeline_xdlops.hpp:36
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)>::WaveNumM static constexpr index_t WaveNumM
Definition blockwise_gemm_pipeline_xdlops.hpp:35
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops.hpp:105
ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_xdlops.hpp:980
static constexpr index_t MWaves
Definition blockwise_gemm_pipeline_xdlops.hpp:111
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops.hpp:338
static __device__ constexpr auto HotLoopScheduler()
Definition blockwise_gemm_pipeline_xdlops.hpp:373
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops.hpp:117
static constexpr index_t A_K0
Definition blockwise_gemm_pipeline_xdlops.hpp:115
static constexpr auto b_thread_desc_
Definition blockwise_gemm_pipeline_xdlops.hpp:961
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops.hpp:321
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops.hpp:267
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops.hpp:152
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops.hpp:253
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops.hpp:113
static __device__ constexpr auto TailScheduler()
Definition blockwise_gemm_pipeline_xdlops.hpp:404
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops.hpp:967
BThreadCopy b_thread_copy_
Definition blockwise_gemm_pipeline_xdlops.hpp:991
ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops.hpp:970
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition blockwise_gemm_pipeline_xdlops.hpp:232
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops.hpp:104
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops.hpp:453
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops.hpp:355
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops.hpp:177
AThreadCopy a_thread_copy_
Definition blockwise_gemm_pipeline_xdlops.hpp:990
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops.hpp:294
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops.hpp:308
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_pipeline_xdlops.hpp:109
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops.hpp:454
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops.hpp:124
static constexpr auto I3
Definition blockwise_gemm_pipeline_xdlops.hpp:107
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops.hpp:118
static constexpr auto I2
Definition blockwise_gemm_pipeline_xdlops.hpp:106
static constexpr index_t B_K0
Definition blockwise_gemm_pipeline_xdlops.hpp:116
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops.hpp:166
static constexpr index_t KPerThread
Definition blockwise_gemm_pipeline_xdlops.hpp:123
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_pipeline_xdlops.hpp:154
static constexpr auto a_thread_desc_
Definition blockwise_gemm_pipeline_xdlops.hpp:955
static constexpr index_t NWaves
Definition blockwise_gemm_pipeline_xdlops.hpp:112
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition blockwise_gemm_pipeline_xdlops.hpp:150
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops.hpp:120
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops.hpp:219
BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, A_K1, B_K1, A_K1, B_K1, KPack, KPack, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops.hpp:126
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops.hpp:471
__host__ __device__ BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_pipeline_xdlops.hpp:235
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops.hpp:190
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition xdlops_gemm.hpp:1821
static constexpr auto K0PerXdlops
Definition xdlops_gemm.hpp:2201
Definition functional2.hpp:33
Definition dtype_vector.hpp:10