FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
Classes |
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
#include <fmha_fwd_pagedkv_kernel.hpp>
Classes | |
| struct | t2s |
| struct | t2s< float > |
| struct | t2s< ck_tile::fp16_t > |
| struct | t2s< ck_tile::bf16_t > |
| struct | t2s< ck_tile::fp8_t > |
| struct | t2s< ck_tile::bf8_t > |
| struct | FmhaFwdEmptyKargs |
| struct | FmhaFwdCommonKargs |
| struct | FmhaFwdLogitsSoftCapKargs |
| struct | FmhaFwdCommonBiasKargs |
| struct | FmhaFwdBatchModeBiasKargs |
| struct | FmhaFwdAlibiKargs |
| struct | FmhaFwdMaskKargs |
| struct | FmhaFwdFp8StaticQuantKargs |
| struct | FmhaFwdCommonLSEKargs |
| struct | FmhaFwdSkipMinSeqlenQKargs |
| struct | CommonPageBlockTableKargs |
| struct | GroupModePageBlockTableKargs |
| struct | CacheBatchIdxKargs |
| struct | FmhaFwdBatchModeKargs |
| struct | FmhaFwdGroupModeKargs |
| struct | BlockIndices |
Public Types | |
| using | FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_> |
| using | EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_> |
| using | QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType> |
| using | KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType> |
| using | VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType> |
| using | BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType> |
| using | LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType> |
| using | ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType> |
| using | SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType> |
| using | VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout> |
| using | AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant> |
| using | FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask> |
| using | Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs> |
Public Member Functions | |
| CK_TILE_DEVICE void | operator() (Kargs kargs) const |
Static Public Member Functions | |
| static CK_TILE_HOST std::string | GetName () |
| template<bool Cond = !kIsGroupMode> | |
| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > | MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type) |
| template<bool Cond = !kIsGroupMode> | |
| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > | MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type) |
| template<bool Cond = kIsGroupMode> | |
| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > | MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q) |
| template<bool Cond = kIsGroupMode> | |
| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > | MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q) |
| static CK_TILE_HOST void | PrintParameters (const Kargs &kargs, int num_batches) |
| static CK_TILE_HOST constexpr auto | GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k) |
| static CK_TILE_DEVICE constexpr auto | GetTileIndex (const Kargs &kargs) |
| static CK_TILE_HOST dim3 | BlockSize () |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr ck_tile::index_t | kBlockSize = FmhaPipeline::kBlockSize |
| static constexpr ck_tile::index_t | kBlockPerCu = FmhaPipeline::kBlockPerCu |
| static constexpr ck_tile::index_t | kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu |
| static constexpr bool | kIsGroupMode = FmhaPipeline::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ |
| static constexpr bool | kPadSeqLenK = FmhaPipeline::kPadSeqLenK |
| static constexpr bool | kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ |
| static constexpr bool | kPadHeadDimV = FmhaPipeline::kPadHeadDimV |
| static constexpr bool | kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap |
| static constexpr auto | BiasEnum = FmhaPipeline::BiasEnum |
| static constexpr bool | kStoreLSE = FmhaPipeline::kStoreLSE |
| static constexpr bool | kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant |
| static constexpr bool | kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ |
| static constexpr bool | kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV |
| static constexpr bool | kHasMask = FmhaMask::IsMasking |
| static constexpr bool | kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy |
Member Typedef Documentation
◆ AttentionVariant
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant> |
◆ BiasDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType> |
◆ EpiloguePipeline
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_> |
◆ FmhaMask
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask> |
◆ FmhaPipeline
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_> |
◆ Kargs
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs> |
◆ KDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType> |
◆ LSEDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType> |
◆ ODataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType> |
◆ QDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType> |
◆ SaccDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType> |
◆ VDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType> |
◆ VLayout
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdPagedKVKernel< FmhaPipeline_, EpiloguePipeline_ >::VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout> |
Member Function Documentation
◆ BlockSize()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
inline static |
◆ GetName()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
inline static |
◆ GetSmemSize()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
inline static constexpr |
◆ GetTileIndex()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
inline static constexpr |
◆ GridSize()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
inline static constexpr |
◆ MakeKargs() [1/2]
template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
|
inline static constexpr |
◆ MakeKargs() [2/2]
template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
|
inline static constexpr |
◆ MakeKargsImpl() [1/2]
template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
|
inline static constexpr |
◆ MakeKargsImpl() [2/2]
template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
|
inline static constexpr |
◆ operator()()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
inline |
FIXME: Before C++20, lambdas cannot capture structured binding variables. Remove the following copy capture of 'i_nhead' once the code is built as C++20.
◆ PrintParameters()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
inline static |
Member Data Documentation
◆ BiasEnum
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kBlockPerCu
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kBlockPerCuInput
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kBlockSize
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kDoFp8StaticQuant
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kHasLogitsSoftCap
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kHasMask
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kIsGroupMode
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kIsPagedKV
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kPadHeadDimQ
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kPadHeadDimV
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kPadSeqLenK
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kPadSeqLenQ
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kSkipMinSeqlenQ
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kStoreLSE
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
◆ kUseAsyncCopy
template<typename FmhaPipeline_, typename EpiloguePipeline_>
|
static constexpr |
The documentation for this struct was generated from the following file: fmha_fwd_pagedkv_kernel.hpp