device_normalization_bwd_data_impl.hpp Source File
device_normalization_bwd_data_impl.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
__global__ void kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K mean_grid_desc_m_k, const GridDesc_M_K inv_std_grid_desc_m_k, const GridDesc_M_K dx_grid_desc_m_k, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DXDataType *const __restrict__ p_dx_global)
Definition device_normalization_bwd_data_impl.hpp:29
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_normalization_bwd_data.hpp:49
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_normalization_bwd_data.hpp:22
Definition device_normalization_bwd_data_impl.hpp:222
DXDataType * p_dx_
Definition device_normalization_bwd_data_impl.hpp:276
GridDesc_M_K mean_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:293
const XDataType * p_x_
Definition device_normalization_bwd_data_impl.hpp:272
index_t MRaw_
Definition device_normalization_bwd_data_impl.hpp:298
const MeanInvStdDataType * p_mean_
Definition device_normalization_bwd_data_impl.hpp:274
GridDesc_M_K gamma_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:292
const DYDataType * p_dy_
Definition device_normalization_bwd_data_impl.hpp:271
GridDesc_M_K inv_std_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:294
Argument(const std::vector< index_t > lengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > dxStrides, const std::vector< index_t > reduceDims, const DYDataType *p_dy, const XDataType *p_x, const GammaDataType *p_gamma, const MeanInvStdDataType *p_mean, const MeanInvStdDataType *p_invStd, DXDataType *p_dx)
Definition device_normalization_bwd_data_impl.hpp:223
int numBlockTileIteration_
Definition device_normalization_bwd_data_impl.hpp:286
GridDesc_M_K dy_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:290
const MeanInvStdDataType * p_invStd_
Definition device_normalization_bwd_data_impl.hpp:275
std::vector< index_t > gammaStrides_
Definition device_normalization_bwd_data_impl.hpp:281
bool isSweeponce_
Definition device_normalization_bwd_data_impl.hpp:297
std::vector< index_t > lengths_
Definition device_normalization_bwd_data_impl.hpp:278
std::vector< index_t > invStdStrides_
Definition device_normalization_bwd_data_impl.hpp:283
size_t gridSize_
Definition device_normalization_bwd_data_impl.hpp:287
GridDesc_M_K dx_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:295
std::vector< index_t > dxStrides_
Definition device_normalization_bwd_data_impl.hpp:284
std::vector< index_t > xStrides_
Definition device_normalization_bwd_data_impl.hpp:280
const GammaDataType * p_gamma_
Definition device_normalization_bwd_data_impl.hpp:273
index_t KRaw_
Definition device_normalization_bwd_data_impl.hpp:299
std::vector< index_t > meanStrides_
Definition device_normalization_bwd_data_impl.hpp:282
std::vector< index_t > dyStrides_
Definition device_normalization_bwd_data_impl.hpp:279
GridDesc_M_K x_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:291
Definition device_normalization_bwd_data_impl.hpp:303
auto KernelSelector(bool isSweepOnce)
Definition device_normalization_bwd_data_impl.hpp:304
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_normalization_bwd_data_impl.hpp:323
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_normalization_bwd_data_impl.hpp:347
Definition device_normalization_bwd_data_impl.hpp:88
static constexpr index_t MeanInvStdSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:92
static constexpr index_t K_BlockTileSize
Definition device_normalization_bwd_data_impl.hpp:122
static constexpr index_t NumInvariantDim
Definition device_normalization_bwd_data_impl.hpp:120
GridwiseNormalizationBwdData_mk_to_mk< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, DYSrcVectorDim, DYSrcVectorSize, XSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize, DXDstVectorDim, DXDstVectorSize, true > GridwiseNormalizationBwdDataSweepOnce
Definition device_normalization_bwd_data_impl.hpp:196
static constexpr index_t XSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:90
static constexpr index_t DXDstVectorDim
Definition device_normalization_bwd_data_impl.hpp:93
static constexpr index_t M_BlockTileSize
Definition device_normalization_bwd_data_impl.hpp:121
std::string GetTypeString() const override
Definition device_normalization_bwd_data_impl.hpp:447
static constexpr index_t DYSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:89
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_normalization_bwd_data_impl.hpp:442
static constexpr bool reduceAllDim
Definition device_normalization_bwd_data_impl.hpp:124
static auto Make2dDescriptor(const std::vector< index_t > &lengths, const std::vector< index_t > &strides, int numBlockTileIteration)
Definition device_normalization_bwd_data_impl.hpp:127
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > dxStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_gamma, const void *p_mean, const void *p_invStd, void *p_dx) override
Definition device_normalization_bwd_data_impl.hpp:406
decltype(Make2dDescriptor({1}, {1}, 1)) GridDesc_M_K
Definition device_normalization_bwd_data_impl.hpp:169
GridwiseNormalizationBwdData_mk_to_mk< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, DYSrcVectorDim, DYSrcVectorSize, XSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize, DXDstVectorDim, DXDstVectorSize, false > GridwiseNormalizationBwdDataGeneric
Definition device_normalization_bwd_data_impl.hpp:171
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_normalization_bwd_data_impl.hpp:385
static constexpr index_t GammaSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:91
bool IsVectorDimSizeValid(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_bwd_data_impl.hpp:355