GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize > Struct Template Reference

GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize > Struct Template Reference

Composable Kernel: ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize > Struct Template Reference
ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize > Struct Template Reference

#include <gridwise_multiblock_batchnorm_forward.hpp>

Public Types

using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>
using ThreadBufferDimAccessOrder
using ThreadClusterArrangeOrder
using ThreadReduceSrcDesc_M_K
using ThreadReduceDstDesc_M
using ThreadReduceSrcDesc_M_1
using ThreadwiseWelford1
using ThreadwiseWelford2
using BlockwiseWelford1
using BlockwiseWelford2
using PassThroughOp = tensor_operation::element_wise::PassThrough

Static Public Member Functions

static __device__ void Run (const XYGridDesc_M_K &x_grid_desc_m_k, const XYGridDesc_M_K &y_grid_desc_m_k, const MeanVarCountGridDesc_M_G &mean_var_count_grid_desc_m_g, const MeanVarCountGridDesc_M_K &mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M &scale_grid_desc_m, const ScaleBiasGridDesc_M &bias_grid_desc_m, const MeanVarGridDesc_M &mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor &get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, MeanVarDataType *const __restrict__ p_welford_mean, MeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count, int32_t *const __restrict__ p_control, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)

Static Public Attributes

static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0)
static constexpr auto thread_cluster_desc
static constexpr auto I0 = Number<0>{}
static constexpr auto I1 = Number<1>{}
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize

Member Typedef Documentation

◆ BlockwiseWelford1

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::BlockwiseWelford1
Initial value:
BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>
Definition blockwise_welford.hpp:25
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_multiblock_batchnorm_forward.hpp:123
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_multiblock_batchnorm_forward.hpp:128

◆ BlockwiseWelford2

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::BlockwiseWelford2
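Initial value: (not captured in this rendering of the page; see the definition in gridwise_multiblock_batchnorm_forward.hpp)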

◆ PassThroughOp

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::PassThroughOp = tensor_operation::element_wise::PassThrough

◆ ThreadBufferDimAccessOrder

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::ThreadBufferDimAccessOrder
Initial value:
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type
Definition utility/sequence.hpp:43
Definition utility/functional.hpp:100

◆ ThreadClusterArrangeOrder

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::ThreadClusterArrangeOrder = typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type

◆ ThreadClusterLengths_M_K

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>

◆ ThreadReduceDstDesc_M

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::ThreadReduceDstDesc_M
Initial value:
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211

◆ ThreadReduceSrcDesc_M_1

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::ThreadReduceSrcDesc_M_1

◆ ThreadReduceSrcDesc_M_K

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::ThreadReduceSrcDesc_M_K

◆ ThreadwiseWelford1

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::ThreadwiseWelford1

◆ ThreadwiseWelford2

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
using ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::ThreadwiseWelford2

Member Function Documentation

◆ Run()

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
__device__ void ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::Run ( const XYGridDesc_M_K & x_grid_desc_m_k,
const XYGridDesc_M_K & y_grid_desc_m_k,
const MeanVarCountGridDesc_M_G & mean_var_count_grid_desc_m_g,
const MeanVarCountGridDesc_M_K & mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M & scale_grid_desc_m,
const ScaleBiasGridDesc_M & bias_grid_desc_m,
const MeanVarGridDesc_M & mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor & get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType *const __restrict__ p_x,
MeanVarDataType *const __restrict__ p_welford_mean,
MeanVarDataType *const __restrict__ p_welford_variance,
int32_t *const __restrict__ p_welford_count,
int32_t *const __restrict__ p_control,
const ScaleDataType *const __restrict__ p_scale,
const BiasDataType *const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType *const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType *const __restrict__ resultRunningMean,
MeanVarDataType *const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType *const __restrict__ resultSaveMean,
MeanVarDataType *const __restrict__ resultSaveInvVariance )
inline static

Member Data Documentation

◆ I0

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
auto ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::I0 = Number<0>{}
static constexpr

◆ I1

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
auto ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::I1 = Number<1>{}
static constexpr

◆ K_BlockTileSize

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
index_t ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::K_BlockTileSize = KThreadClusterSize * KThreadSliceSize
static constexpr

◆ M_BlockTileSize

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
index_t ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::M_BlockTileSize = MThreadClusterSize * MThreadSliceSize
static constexpr

◆ reorder_thread_cluster

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
bool ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::reorder_thread_cluster = (XSrcYDstVectorDim == 0)
static constexpr

◆ thread_cluster_desc

template<typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_G, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XSrcYDstVectorDim, index_t XSrcVectorSize, index_t YDstVectorSize, index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize, index_t MeanVarSrcDstVectorSize>
auto ck::GridwiseMultiblockBatchNormForward< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_G, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, GetReduceCountPerThreadFunctor, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >::thread_cluster_desc
static constexpr
Initial value:
=
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13

The documentation for this struct was generated from the following file: