XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma > Struct Template Reference

XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma > Struct Template Reference

Composable Kernel: ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma > Struct Template Reference
ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma > Struct Template Reference

#include <xdlops_gemm.hpp>

Public Types

using CIndex = MultiIndex<2>
using CIndex4D = MultiIndex<4>

Public Member Functions

__host__ __device__ constexpr XdlopsGemm ()
template<class FloatA, class FloatB, class FloatC>
__device__ void Run (const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
template<index_t OpselA, index_t OpselB, class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
__device__ void Run (const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const

Static Public Member Functions

static __device__ constexpr index_t GetNumBlks ()
static __device__ constexpr index_t GetNumXdlops ()
template<typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ static __device__ constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2 (const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
template<typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ static __device__ constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 (const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
template<typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ static __device__ constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 (const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
template<typename CDesc_G_M0_N0_M1_N1_M2_N2>
__host__ static __device__ constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2 (const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
__device__ static __host__ constexpr index_t GetRegSizePerXdlops ()
static __device__ constexpr index_t GetWaveSize ()
static __device__ auto GetLaneId ()
static __device__ auto GetBlkIdx ()
template<bool SwizzleA>
static __device__ auto GetGfx11InputBlkIdx ()
__host__ static __device__ auto CalculateAThreadOriginDataIndex ()
__host__ static __device__ auto CalculateBThreadOriginDataIndex ()
static __device__ CIndex GetBeginOfThreadBlk (index_t xdlops_i, index_t blk_i)
static __device__ CIndex4D GetBeginOfThreadBlk4D (index_t, index_t)
__host__ static __device__ constexpr auto GetCM0M1M2NThreadBlkLengths ()

Static Public Attributes

static constexpr auto I0 = Number<0>{}
static constexpr auto I1 = Number<1>{}
static constexpr auto I2 = Number<2>{}
static constexpr auto I3 = Number<3>{}
static constexpr auto I4 = Number<4>{}
static constexpr auto I5 = Number<5>{}
static constexpr bool is_single_rate_mfma
static constexpr auto mfma
static constexpr auto mfma_instr = mfma.selected_mfma
static constexpr auto KPerXdlops = mfma.GetKPerXdlops()
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops()
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops

Member Typedef Documentation

◆ CIndex

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
using ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::CIndex = MultiIndex<2>

◆ CIndex4D

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
using ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::CIndex4D = MultiIndex<4>

Constructor & Destructor Documentation

◆ XdlopsGemm()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__host__ __device__ constexpr ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::XdlopsGemm ( )
inline constexpr

Member Function Documentation

◆ CalculateAThreadOriginDataIndex()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__host__ static __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::CalculateAThreadOriginDataIndex ( )
inline static

◆ CalculateBThreadOriginDataIndex()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__host__ static __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::CalculateBThreadOriginDataIndex ( )
inline static

◆ GetBeginOfThreadBlk()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ CIndex ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetBeginOfThreadBlk ( index_t xdlops_i,
index_t blk_i )
inline static

◆ GetBeginOfThreadBlk4D()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ CIndex4D ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetBeginOfThreadBlk4D ( index_t ,
index_t  )
inline static

◆ GetBlkIdx()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetBlkIdx ( )
inline static

◆ GetCM0M1M2NThreadBlkLengths()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__host__ static __device__ constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetCM0M1M2NThreadBlkLengths ( )
inline static constexpr

◆ GetGfx11InputBlkIdx()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<bool SwizzleA>
__device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetGfx11InputBlkIdx ( )
inline static

◆ GetLaneId()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetLaneId ( )
inline static

◆ GetNumBlks()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ constexpr index_t ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetNumBlks ( )
inline static constexpr

◆ GetNumXdlops()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ constexpr index_t ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetNumXdlops ( )
inline static constexpr

◆ GetRegSizePerXdlops()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ static __host__ constexpr index_t ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetRegSizePerXdlops ( )
inline static constexpr

◆ GetWaveSize()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ constexpr index_t ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetWaveSize ( )
inline static constexpr

◆ MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<typename CDesc_G_M0_N0_M1_N1_M2_N2>
__host__ static __device__ constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2 ( const CDesc_G_M0_N0_M1_N1_M2_N2 & c_desc_g_m0_n0_m1_n1_m2_n2)
inline static constexpr

◆ MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ static __device__ constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2 ( const CDesc_M0_N0_M1_N1_M2_N2 & c_desc_m0_n0_m1_n1_m2_n2)
inline static constexpr

◆ MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ static __device__ constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 ( const CDesc_M0_N0_M1_N1_M2_N2 & c_desc_m0_n0_m1_n1_m2_n2)
inline static constexpr

◆ MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ static __device__ constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 ( const CDesc_M0_N0_M1_N1_M2_N2 & c_desc_m0_n0_m1_n1_m2_n2)
inline static constexpr

◆ Run() [1/2]

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<class FloatA, class FloatB, class FloatC>
__device__ void ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::Run ( const FloatA & p_a_wave,
const FloatB & p_b_wave,
FloatC & p_c_thread ) const
inline

◆ Run() [2/2]

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<index_t OpselA, index_t OpselB, class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
__device__ void ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::Run ( const FloatA & p_a_wave,
const ScaleA & a_scale_thread,
const FloatB & p_b_wave,
const ScaleB & b_scale_thread,
FloatC & p_c_thread ) const
inline

Member Data Documentation

◆ I0

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I0 = Number<0>{}
static constexpr

◆ I1

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I1 = Number<1>{}
static constexpr

◆ I2

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I2 = Number<2>{}
static constexpr

◆ I3

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I3 = Number<3>{}
static constexpr

◆ I4

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I4 = Number<4>{}
static constexpr

◆ I5

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I5 = Number<5>{}
static constexpr

◆ is_single_rate_mfma

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
bool ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::is_single_rate_mfma
static constexpr
Initial value (the expression appears truncated by extraction; missing parts are marked "…"):
=
… KPack <= 4) ||
…
? true
: false
References:
static constexpr value_type value — Definition: utility/integral_constant.hpp:13

◆ K0PerXdlops

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::K0PerXdlops = KPerXdlops / K1PerXdlops
static constexpr

◆ K1PerXdlops

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::K1PerXdlops = mfma.GetK1PerXdlops()
static constexpr

◆ KPerXdlops

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::KPerXdlops = mfma.GetKPerXdlops()
static constexpr

◆ mfma

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::mfma
static constexpr
Initial value:
= MfmaSelector<base_type,
               MPerXdlops,
               NPerXdlops,
               additional_type,
               is_single_rate_mfma,
               is_scale_mfma>{}
References:
static constexpr bool is_single_rate_mfma — Definition: gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:177
static constexpr auto is_scale_mfma — Definition: gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:178
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o... — Definition: xdlops_gemm.hpp:1208

◆ mfma_instr

template<typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::mfma_instr = mfma.selected_mfma
static constexpr

The documentation for this struct was generated from the following file: