smfmac_xdlops_gemm.hpp Source File
smfmac_xdlops_gemm.hpp
Go to the documentation of this file.
249 MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
337 __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; }
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
@ smfmac_f32_16x16x32bf16
Definition smfmac_xdlops_gemm.hpp:16
@ smfmac_f32_32x32x16bf16
Definition smfmac_xdlops_gemm.hpp:17
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition utility/sequence.hpp:43
Definition smfmac_xdlops_gemm.hpp:140
__host__ __device__ constexpr SmfmacSelector()
Definition smfmac_xdlops_gemm.hpp:174
static constexpr index_t GetKPerXdlops()
Definition smfmac_xdlops_gemm.hpp:200
static constexpr auto GetSmfmac()
static constexpr auto selected_smfmac
Definition smfmac_xdlops_gemm.hpp:171
static constexpr index_t GetK1PerXdlops()
Definition smfmac_xdlops_gemm.hpp:206
static __device__ auto GetLaneId()
Definition smfmac_xdlops_gemm.hpp:337
static __device__ constexpr index_t GetNumBlks()
Definition smfmac_xdlops_gemm.hpp:226
static constexpr auto K0PerXdlops
Definition smfmac_xdlops_gemm.hpp:424
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition smfmac_xdlops_gemm.hpp:407
static __device__ constexpr index_t GetWaveSize()
Definition smfmac_xdlops_gemm.hpp:322
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, const Idx &idx, FloatC &p_c_thread) const
Definition smfmac_xdlops_gemm.hpp:326
static constexpr auto K1PerXdlops
Definition smfmac_xdlops_gemm.hpp:423
__host__ static __device__ constexpr auto GetCM0M1M2NThreadBlkLengths()
Definition smfmac_xdlops_gemm.hpp:426
static __device__ auto GetBlkIdx()
Definition smfmac_xdlops_gemm.hpp:339
__host__ __device__ constexpr SparseXdlopsGemm()
Definition smfmac_xdlops_gemm.hpp:234
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition smfmac_xdlops_gemm.hpp:394
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition smfmac_xdlops_gemm.hpp:358
static constexpr auto smfmac_instr
Definition smfmac_xdlops_gemm.hpp:420
static constexpr auto KPerXdlops
Definition smfmac_xdlops_gemm.hpp:422
static __device__ constexpr index_t GetNumXdlops()
Definition smfmac_xdlops_gemm.hpp:228
__host__ static __device__ constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition smfmac_xdlops_gemm.hpp:249
static __device__ constexpr index_t GetRegSizePerXdlops()
Definition smfmac_xdlops_gemm.hpp:317
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition smfmac_xdlops_gemm.hpp:376
__host__ static __device__ constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition smfmac_xdlops_gemm.hpp:281
Definition amd_smfmac.hpp:34
Definition amd_smfmac.hpp:10
Definition amd_smfmac.hpp:78
Definition amd_smfmac.hpp:56
static constexpr index_t m_per_blk
Definition smfmac_xdlops_gemm.hpp:89
static constexpr index_t group_size
Definition smfmac_xdlops_gemm.hpp:82
static constexpr index_t num_threads_per_blk
Definition smfmac_xdlops_gemm.hpp:85
static constexpr index_t num_input_blks
Definition smfmac_xdlops_gemm.hpp:87
static constexpr bool is_k_reduction
Definition smfmac_xdlops_gemm.hpp:92
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition smfmac_xdlops_gemm.hpp:100
static constexpr index_t num_groups_per_blk
Definition smfmac_xdlops_gemm.hpp:83
static constexpr index_t num_regs_per_blk
Definition smfmac_xdlops_gemm.hpp:84
static constexpr index_t k_per_blk
Definition smfmac_xdlops_gemm.hpp:91
static constexpr index_t wave_size
Definition smfmac_xdlops_gemm.hpp:86
static constexpr index_t num_output_blks
Definition smfmac_xdlops_gemm.hpp:88
static constexpr index_t n_per_blk
Definition smfmac_xdlops_gemm.hpp:90
static constexpr index_t m_per_blk
Definition smfmac_xdlops_gemm.hpp:33
static constexpr index_t num_threads_per_blk
Definition smfmac_xdlops_gemm.hpp:29
static constexpr index_t group_size
Definition smfmac_xdlops_gemm.hpp:26
static constexpr index_t num_output_blks
Definition smfmac_xdlops_gemm.hpp:32
static constexpr bool is_k_reduction
Definition smfmac_xdlops_gemm.hpp:36
static constexpr index_t wave_size
Definition smfmac_xdlops_gemm.hpp:30
static constexpr index_t num_groups_per_blk
Definition smfmac_xdlops_gemm.hpp:27
static constexpr index_t k_per_blk
Definition smfmac_xdlops_gemm.hpp:35
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition smfmac_xdlops_gemm.hpp:44
static constexpr index_t n_per_blk
Definition smfmac_xdlops_gemm.hpp:34
static constexpr index_t num_input_blks
Definition smfmac_xdlops_gemm.hpp:31
static constexpr index_t num_regs_per_blk
Definition smfmac_xdlops_gemm.hpp:28
static constexpr index_t n_per_blk
Definition smfmac_xdlops_gemm.hpp:118
static constexpr index_t group_size
Definition smfmac_xdlops_gemm.hpp:110
static constexpr index_t num_groups_per_blk
Definition smfmac_xdlops_gemm.hpp:111
static constexpr bool is_k_reduction
Definition smfmac_xdlops_gemm.hpp:120
static constexpr index_t k_per_blk
Definition smfmac_xdlops_gemm.hpp:119
static constexpr index_t num_regs_per_blk
Definition smfmac_xdlops_gemm.hpp:112
static constexpr index_t num_output_blks
Definition smfmac_xdlops_gemm.hpp:116
static constexpr index_t num_threads_per_blk
Definition smfmac_xdlops_gemm.hpp:113
static constexpr index_t num_input_blks
Definition smfmac_xdlops_gemm.hpp:115
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition smfmac_xdlops_gemm.hpp:128
static constexpr index_t wave_size
Definition smfmac_xdlops_gemm.hpp:114
static constexpr index_t m_per_blk
Definition smfmac_xdlops_gemm.hpp:117
static constexpr bool is_k_reduction
Definition smfmac_xdlops_gemm.hpp:64
static constexpr index_t m_per_blk
Definition smfmac_xdlops_gemm.hpp:61
static constexpr index_t group_size
Definition smfmac_xdlops_gemm.hpp:54
static constexpr index_t n_per_blk
Definition smfmac_xdlops_gemm.hpp:62
static constexpr index_t wave_size
Definition smfmac_xdlops_gemm.hpp:58
static constexpr index_t num_threads_per_blk
Definition smfmac_xdlops_gemm.hpp:57
static constexpr index_t num_output_blks
Definition smfmac_xdlops_gemm.hpp:60
static constexpr index_t num_input_blks
Definition smfmac_xdlops_gemm.hpp:59
static constexpr index_t num_groups_per_blk
Definition smfmac_xdlops_gemm.hpp:55
static constexpr index_t num_regs_per_blk
Definition smfmac_xdlops_gemm.hpp:56
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition smfmac_xdlops_gemm.hpp:72
static constexpr index_t k_per_blk
Definition smfmac_xdlops_gemm.hpp:63
Definition smfmac_xdlops_gemm.hpp:21
Definition functional2.hpp:33