batched_gemm_kernel.hpp Source File
batched_gemm_kernel.hpp
Go to the documentation of this file.
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
The Batched GEMM kernel host arguments.
Definition batched_gemm_kernel.hpp:20
ck_tile::index_t batch_stride_B
Definition batched_gemm_kernel.hpp:55
ck_tile::index_t batch_stride_A
Definition batched_gemm_kernel.hpp:54
ck_tile::index_t batch_stride_E
Definition batched_gemm_kernel.hpp:56
CK_TILE_HOST BatchedGemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, ck_tile::index_t k_batch_, ck_tile::index_t M_, ck_tile::index_t N_, ck_tile::index_t K_, ck_tile::index_t stride_A_, ck_tile::index_t stride_B_, ck_tile::index_t stride_C_, ck_tile::index_t batch_stride_A_, ck_tile::index_t batch_stride_B_, ck_tile::index_t batch_stride_C_, ck_tile::index_t batch_count_)
Definition batched_gemm_kernel.hpp:21
ck_tile::index_t batch_count
Definition batched_gemm_kernel.hpp:57
ALayout and ADataType are expected to be scalars, not a tuple.
Definition batched_gemm_kernel.hpp:99
index_t batch_stride_E
Definition batched_gemm_kernel.hpp:102
index_t batch_count
Definition batched_gemm_kernel.hpp:103
index_t batch_stride_A
Definition batched_gemm_kernel.hpp:100
index_t batch_stride_B
Definition batched_gemm_kernel.hpp:101
Definition batched_gemm_kernel.hpp:62
static constexpr index_t kBlockSize
Definition batched_gemm_kernel.hpp:67
static CK_TILE_HOST auto IsSupportedArgument(const typename BatchedGemmKernel::KernelArgs &kargs) -> bool
Definition batched_gemm_kernel.hpp:164
static CK_TILE_HOST constexpr BatchedGemmKernelArgs MakeKernelArgs(const BatchedGemmHostArgs &hostArgs)
Definition batched_gemm_kernel.hpp:138
BatchedGemmKernelArgs KernelArgs
Definition batched_gemm_kernel.hpp:106
static CK_TILE_HOST constexpr auto GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3
Definition batched_gemm_kernel.hpp:120
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition batched_gemm_kernel.hpp:70
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition batched_gemm_kernel.hpp:75
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition batched_gemm_kernel.hpp:69
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition batched_gemm_kernel.hpp:71
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition batched_gemm_kernel.hpp:76
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition batched_gemm_kernel.hpp:65
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition batched_gemm_kernel.hpp:158
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition batched_gemm_kernel.hpp:80
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
Definition batched_gemm_kernel.hpp:204
static CK_TILE_HOST auto BlockSize() -> dim3
Definition batched_gemm_kernel.hpp:125
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, E and D.
Definition batched_gemm_kernel.hpp:79
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, E and D.
Definition batched_gemm_kernel.hpp:74
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition batched_gemm_kernel.hpp:81
static CK_TILE_HOST auto GetName() -> const std::string
Definition batched_gemm_kernel.hpp:108
The Universal GEMM kernel host arguments.
Definition universal_gemm_kernel.hpp:32
const std::array< index_t, NumDTensor > stride_Ds
Definition universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition universal_gemm_kernel.hpp:72
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition universal_gemm_kernel.hpp:33
const std::array< const void *, NumDTensor > ds_ptr
Definition universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition universal_gemm_kernel.hpp:71
const std::array< const void *, NumBTensor > bs_ptr
Definition universal_gemm_kernel.hpp:61
Definition universal_gemm_kernel.hpp:325
std::array< index_t, NumATensor > as_k_split_offset
Definition universal_gemm_kernel.hpp:368
std::array< index_t, NumBTensor > bs_k_split_offset
Definition universal_gemm_kernel.hpp:369
The GEMM kernel device arguments.
Definition universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:94
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:88
index_t N
GEMM's N dimension size.
Definition universal_gemm_kernel.hpp:98
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:90
index_t M
GEMM's M dimension size.
Definition universal_gemm_kernel.hpp:96
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition universal_gemm_kernel.hpp:955
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
static constexpr index_t kBlockSize
Definition universal_gemm_kernel.hpp:202