#include <grouped_flatmm_kernel.hpp>
|
Public member functions:
CK_TILE_HOST MaskedGroupedFlatmmHostArgs() = default
CK_TILE_HOST MaskedGroupedFlatmmHostArgs(index_t* M_indices_, index_t group_count_, index_t Max_M_, index_t N_, index_t K_, const void* a_ptr_, index_t stride_A_, const void* b_shuffle_ptr_, index_t stride_B_, const std::array<const void*, NumDTensor>& ds_ptr_, const std::array<index_t, NumDTensor>& stride_Ds_, void* c_ptr_, index_t stride_C_, index_t k_batch_, ScaleM scale_m_ = nullptr, ScaleN scale_n_ = nullptr)
◆ MaskedGroupedFlatmmHostArgs() [1/2]
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ MaskedGroupedFlatmmHostArgs() [2/2]
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_HOST ck_tile::MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>::MaskedGroupedFlatmmHostArgs(
    index_t* M_indices_,
    index_t group_count_,
    index_t Max_M_,
    index_t N_,
    index_t K_,
    const void* a_ptr_,
    index_t stride_A_,
    const void* b_shuffle_ptr_,
    index_t stride_B_,
    const std::array<const void*, NumDTensor>& ds_ptr_,
    const std::array<index_t, NumDTensor>& stride_Ds_,
    void* c_ptr_,
    index_t stride_C_,
    index_t k_batch_,
    ScaleM scale_m_ = nullptr,
    ScaleN scale_n_ = nullptr)
[inline]
◆ [union]
◆ a_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ b_shuffle_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ c_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ ds_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ e_ptr
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ group_count
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ k_batch
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ M_indices
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ scale_m
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ scale_n
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ stride_A
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ stride_B
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ stride_C
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
◆ stride_Ds
template<class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
The documentation for this struct was generated from the following file: grouped_flatmm_kernel.hpp