device_batchnorm_forward_impl.hpp Source File

device_batchnorm_forward_impl.hpp Source File#

Composable Kernel: device_batchnorm_forward_impl.hpp Source File
device_batchnorm_forward_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
26template <typename XDataType,
27 typename YDataType,
28 typename AccDataType,
29 typename ScaleDataType,
30 typename BiasDataType,
31 typename MeanVarDataType,
32 typename YElementwiseOp,
33 index_t Rank,
34 index_t NumBatchNormReduceDim,
35 bool UseMultiblockInK,
36 index_t BlockSize,
37 index_t MThreadClusterSize,
38 index_t KThreadClusterSize,
39 index_t MThreadSliceSize,
40 index_t KThreadSliceSize,
41 index_t XSrcYDstVectorDim,
42 index_t XSrcVectorSize,
43 index_t YDstVectorSize,
44 index_t ScaleSrcVectorSize,
45 index_t BiasSrcVectorSize,
46 index_t MeanVarSrcDstVectorSize>
48 YDataType,
49 AccDataType,
50 ScaleDataType,
51 BiasDataType,
52 MeanVarDataType,
53 YElementwiseOp,
54 Rank,
55 NumBatchNormReduceDim>
56{
57 static_assert(Rank <= 6, "Bigger Rank size is not supported!");
58 static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
59 "Invalid thread cluster size assignments!");
60
61 static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
62 (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
63 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
64
65 static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
66
67 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
68 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
69
70 static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
71 const std::array<index_t, Rank>& xyStrides,
72 int blkGroupSize,
73 int numBlockTileIteration)
74 {
75 const auto tupleXYLengths =
76 generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
77 const auto tupleXYStrides =
78 generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
79
80 const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
81
82 const auto grid_desc_m_k = [&]() {
83 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
85
86 const auto reduceDimLengths =
87 generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
89 const auto invariantDimLengths =
90 generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
91
92 return transform_tensor_descriptor(raw_grid_desc,
93 make_tuple(make_merge_transform(invariantDimLengths),
94 make_merge_transform(reduceDimLengths)),
95 make_tuple(InvariantDims{}, ReduceDims{}),
97 }();
98
99 const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
100 const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
101
102 const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
103 const auto mPad =
104 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
105 const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
106
107 auto grid_desc_m_k_padded =
108 transform_tensor_descriptor(grid_desc_m_k,
109 make_tuple(make_right_pad_transform(invariantLength, mPad),
110 make_right_pad_transform(reduceLength, kPad)),
113
114 return (grid_desc_m_k_padded);
115 };
116
117 static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
118 {
119 const auto grid_desc_m_g = make_naive_tensor_descriptor(
120 make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
121
122 const auto mPad =
123 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
124
125 auto grid_desc_m_g_padded =
126 transform_tensor_descriptor(grid_desc_m_g,
127 make_tuple(make_right_pad_transform(invariantLength, mPad),
128 make_pass_through_transform(blkGroupSize)),
131
132 return (grid_desc_m_g_padded);
133 };
134
135 static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
136 {
137 const auto reduceLength = blkGroupSize;
138 const auto grid_desc_m_k = make_naive_tensor_descriptor(
139 make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
140
141 const auto mPad =
142 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
143 const auto kPad =
144 math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
145
146 auto grid_desc_m_k_padded =
147 transform_tensor_descriptor(grid_desc_m_k,
148 make_tuple(make_right_pad_transform(invariantLength, mPad),
149 make_right_pad_transform(reduceLength, kPad)),
152
153 return (grid_desc_m_k_padded);
154 };
155
156 static auto
157 MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
158 const std::array<index_t, NumInvariantDim>& strides)
159 {
160 const auto tupleLengths =
161 generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
162 const auto tupleStrides =
163 generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
164
165 auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
166
167 auto grid_desc_m = transform_tensor_descriptor(
168 raw_grid_desc,
169 make_tuple(make_merge_transform(tupleLengths)),
172
173 const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
174
175 const auto mPad =
176 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
177
178 auto grid_desc_m_padded =
179 transform_tensor_descriptor(grid_desc_m,
180 make_tuple(make_right_pad_transform(invariantLength, mPad)),
183 return (grid_desc_m_padded);
184 };
185
186 using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
188
189 struct Argument : public BaseArgument
190 {
191 Argument(const std::array<index_t, Rank> xyLengths,
192 const std::array<index_t, Rank> xStrides,
193 const std::array<index_t, Rank> yStrides,
194 const std::array<int, NumBatchNormReduceDim> reduceDims,
195 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
196 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
197 const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
198 const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
199 const XDataType* p_x,
200 const ScaleDataType* p_scale,
201 const BiasDataType* p_bias,
202 const YElementwiseOp y_elementwise_op,
203 double epsilon,
204 YDataType* p_y,
205 MeanVarDataType* resultSaveMean,
206 MeanVarDataType* resultSaveInvVariance,
207 double averageFactor,
208 MeanVarDataType* resultRunningMean,
209 MeanVarDataType* resultRunningVariance)
210 : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
211 bnScaleStrides_(bnScaleStrides),
212 bnBiasStrides_(bnBiasStrides),
213 bnMeanVarStrides_(bnMeanVarStrides),
214 p_x_(p_x),
215 p_scale_(p_scale),
216 p_bias_(p_bias),
217 y_elementwise_op_(y_elementwise_op),
218 p_y_(p_y),
219 resultSaveMean_(resultSaveMean),
220 resultSaveInvVariance_(resultSaveInvVariance),
221 resultRunningMean_(resultRunningMean),
222 resultRunningVariance_(resultRunningVariance)
223 {
224 xyLengths_ =
226 xStrides_ =
228 yStrides_ =
230
233
236
238 (resultRunningMean != nullptr && resultRunningVariance != nullptr);
239 saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);
240
241 if(UseMultiblockInK)
242 {
243 int iterations = 1;
244 while(true)
245 {
246 int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
247 (K_BlockTileSize * iterations);
248
249 // we want blkGroupSize to be no more than 16
250 if(testBlkGroupSize <= 16)
251 break;
252
253 iterations++;
254 };
255
256 blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
257 (K_BlockTileSize * iterations);
258
259 numBlockTileIteration_ = iterations;
260 }
261 else
262 {
263 blkGroupSize_ = 1;
265 };
266
268
274 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
276 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
278 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
279 }
280
281 AccDataType epsilon_;
282 AccDataType averageFactor_;
283
286
287 std::array<index_t, Rank> xyLengths_;
288 std::array<index_t, Rank> xStrides_;
289 std::array<index_t, Rank> yStrides_;
290
291 std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
292 std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
293 std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
294 std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
295
296 const XDataType* p_x_;
297 const ScaleDataType* p_scale_;
298 const BiasDataType* p_bias_;
299 const YElementwiseOp y_elementwise_op_;
300 YDataType* p_y_;
301
302 MeanVarDataType* resultSaveMean_;
303 MeanVarDataType* resultSaveInvVariance_;
304
305 MeanVarDataType* resultRunningMean_;
306 MeanVarDataType* resultRunningVariance_;
307
310
313 size_t gridSize_;
314
320
324
325 void* control_;
326 };
327
// Reports the number of workspace bytes needed for the given argument.
// pArg is down-cast to Argument with dynamic_cast and dereferenced without a null
// check, so it must really point to an Argument of this class.
// Returns 0 unless UseMultiblockInK is set and blkGroupSize_ > 1; otherwise returns
// the sum of the three welford intermediate buffers plus the barrier area below.
328 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
329 {
330 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
331
332 size_t workspace_size = 0;
333
334 if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
335 {
// Each of the three buffers below holds invariant_length_ * blkGroupSize_ elements
// and is over-allocated by 64 bytes (presumably slack for 64-byte rounding of the
// sub-buffer sizes when the workspace is laid out -- TODO confirm).
336 // workspace for welford intermediate mean
337 workspace_size +=
338 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
339
340 // workspace for welford intermediate variance
341 workspace_size +=
342 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
343
344 // workspace for welford intermediate count
345 workspace_size +=
346 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
347
348 // workspace for barrier objects, each barrier object consists of two integers
349 // TODO: allocate barrier object memory globally to reuse it by other operators
// The count is invariant_length_ divided by M_BlockTileSize, rounded up.
350 workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
351 sizeof(int) * 2;
352 }
353
354 return (workspace_size);
355 };
356
358 void* p_workspace,
359 const StreamConfig& = StreamConfig{}) const override
360 {
361 Argument* pArg_ = dynamic_cast<Argument*>(pArg);
362
363 pArg_->p_workspace_ = p_workspace;
364
365 if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
366 {
367 // setup buffer used for intermediate welford mean
368 pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
369
370 index_t mean_space_sz =
371 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
372
373 mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
374
375 // setup buffer used for intermediate welford variance
376 pArg_->workspace_variance_ =
377 reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
378
379 index_t variance_space_sz =
380 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
381
382 variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
383
384 // setup buffer used for intermediate welford count
385 pArg_->workspace_count_ =
386 reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
387
388 index_t count_space_sz =
389 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
390
391 count_space_sz = math::integer_least_multiple(count_space_sz, 64);
392
393 pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
394
395 index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
396 M_BlockTileSize * sizeof(int) * 2;
397
398 hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
399 };
400 };
401
402 struct Invoker : public BaseInvoker
403 {
404 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
405 {
406 float avg_time = 0;
407
408 if(UseMultiblockInK && arg.blkGroupSize_ > 1)
409 {
410 using GetReduceCountPerThreadFunctor =
412
413 GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
415
416 const auto mean_var_count_grid_desc_m_g =
419
420 const auto mean_var_count_grid_desc_m_k =
423
424 using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
425 using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
426
427 using GridwiseMultiblockBatchNormForward_ =
429 YDataType,
430 AccDataType,
431 ScaleDataType,
432 BiasDataType,
433 MeanVarDataType,
434 YElementwiseOp,
436 MeanVarCountGridDesc_M_G,
437 MeanVarCountGridDesc_M_K,
440 GetReduceCountPerThreadFunctor,
441 BlockSize,
442 MThreadClusterSize,
443 KThreadClusterSize,
444 MThreadSliceSize,
445 KThreadSliceSize,
446 XSrcYDstVectorDim,
447 XSrcVectorSize,
448 YDstVectorSize,
449 ScaleSrcVectorSize,
450 BiasSrcVectorSize,
451 MeanVarSrcDstVectorSize>;
452
453 using GridwiseMultiblockWelfordFirstHalf_ =
455 AccDataType,
456 MeanVarDataType,
458 MeanVarCountGridDesc_M_G,
459 GetReduceCountPerThreadFunctor,
460 BlockSize,
461 MThreadClusterSize,
462 KThreadClusterSize,
463 MThreadSliceSize,
464 KThreadSliceSize,
465 XSrcYDstVectorDim,
466 XSrcVectorSize>;
467
468 using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
470 YDataType,
471 AccDataType,
472 ScaleDataType,
473 BiasDataType,
474 MeanVarDataType,
475 YElementwiseOp,
477 MeanVarCountGridDesc_M_K,
480 BlockSize,
481 MThreadClusterSize,
482 KThreadClusterSize,
483 MThreadSliceSize,
484 KThreadSliceSize,
485 XSrcYDstVectorDim,
486 XSrcVectorSize,
487 YDstVectorSize,
488 ScaleSrcVectorSize,
489 BiasSrcVectorSize,
490 MeanVarSrcDstVectorSize>;
491
492 // It is found that:
493 // 1) gfx1030 does not support the GLC enabled vector load/store, so using the
494 // two-kernel method for gfx1030
495 // 2) Profiler on gfx908 could hang even though it works when running examples
496 // 3) Single-kernel method works on gfx1100, but the performance is not better
497 // than two-kernel method (due to more warps participating in the barrier)
498 if(ck::get_device_name() == "gfx90a")
499 {
500 const auto kern_multiblock_batchnorm_fwd_ =
501 kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
502 XDataType,
503 YDataType,
504 AccDataType,
505 ScaleDataType,
506 BiasDataType,
507 MeanVarDataType,
508 YElementwiseOp,
510 MeanVarCountGridDesc_M_G,
511 MeanVarCountGridDesc_M_K,
514 GetReduceCountPerThreadFunctor>;
515
516 avg_time += launch_and_time_kernel(
517 stream_config,
518 kern_multiblock_batchnorm_fwd_,
519 dim3(arg.gridSize_),
520 dim3(BlockSize),
521 0,
524 mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
525 // workspace by multiple workgroups
526 mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
527 // workspace by each workgroup
531 get_reduce_count_per_thread,
533 arg.epsilon_,
534 arg.p_x_,
535 static_cast<MeanVarDataType*>(arg.workspace_mean_),
536 static_cast<MeanVarDataType*>(arg.workspace_variance_),
537 static_cast<int32_t*>(arg.workspace_count_),
538 static_cast<int*>(arg.control_),
539 arg.p_scale_,
540 arg.p_bias_,
542 arg.p_y_,
543 arg.updateMovingAverage_, // true or false
544 arg.averageFactor_,
547 arg.saveMeanInvVariance_, // true or false
548 arg.resultSaveMean_,
550 }
551 else
552 {
553 const auto kern_multiblock_welford_first_half =
554 kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
555 XDataType,
556 MeanVarDataType,
558 MeanVarCountGridDesc_M_G,
559 GetReduceCountPerThreadFunctor>;
560
561 const auto kern_welford_second_half_batchnorm_forward_final =
563 GridwiseWelfordSecondHalfBatchNormForwardFinal_,
564 XDataType,
565 YDataType,
566 AccDataType,
567 ScaleDataType,
568 BiasDataType,
569 MeanVarDataType,
570 YElementwiseOp,
572 MeanVarCountGridDesc_M_K,
575
576 avg_time += launch_and_time_kernel(
577 stream_config,
578 kern_multiblock_welford_first_half,
579 dim3(arg.gridSize_),
580 dim3(BlockSize),
581 0,
583 mean_var_count_grid_desc_m_g,
584 get_reduce_count_per_thread,
586 arg.p_x_,
587 static_cast<MeanVarDataType*>(arg.workspace_mean_),
588 static_cast<MeanVarDataType*>(arg.workspace_variance_),
589 static_cast<int32_t*>(arg.workspace_count_));
590
591 avg_time += launch_and_time_kernel(
592 stream_config,
593 kern_welford_second_half_batchnorm_forward_final,
594 dim3(arg.gridSize_),
595 dim3(BlockSize),
596 0,
599 mean_var_count_grid_desc_m_k,
603 arg.blkGroupSize_,
605 arg.epsilon_,
606 static_cast<MeanVarDataType*>(arg.workspace_mean_),
607 static_cast<MeanVarDataType*>(arg.workspace_variance_),
608 static_cast<int32_t*>(arg.workspace_count_),
609 arg.p_x_,
610 arg.p_scale_,
611 arg.p_bias_,
613 arg.p_y_,
615 arg.averageFactor_,
619 arg.resultSaveMean_,
621 };
622 }
623 else
624 {
625 using GetReduceCountPerThreadFunctor =
626 GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
627
628 GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
630
631 using GridwiseBatchNormForwardWithBlockwiseWelford_ =
633 YDataType,
634 AccDataType,
635 ScaleDataType,
636 BiasDataType,
637 MeanVarDataType,
638 YElementwiseOp,
642 GetReduceCountPerThreadFunctor,
643 BlockSize,
644 MThreadClusterSize,
645 KThreadClusterSize,
646 MThreadSliceSize,
647 KThreadSliceSize,
648 XSrcYDstVectorDim,
649 XSrcVectorSize,
650 YDstVectorSize,
651 ScaleSrcVectorSize,
652 BiasSrcVectorSize,
653 MeanVarSrcDstVectorSize>;
654
655 const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
656 GridwiseBatchNormForwardWithBlockwiseWelford_,
657 XDataType,
658 YDataType,
659 AccDataType,
660 ScaleDataType,
661 BiasDataType,
662 MeanVarDataType,
663 YElementwiseOp,
667 GetReduceCountPerThreadFunctor>;
668
669 avg_time += launch_and_time_kernel(stream_config,
670 kern_batchnorm_fwd,
671 dim3(arg.gridSize_),
672 dim3(BlockSize),
673 0,
679 get_reduce_count_per_thread,
681 arg.epsilon_,
682 arg.p_x_,
683 arg.p_scale_,
684 arg.p_bias_,
686 arg.p_y_,
687 arg.updateMovingAverage_, // true or false
688 arg.averageFactor_,
691 arg.saveMeanInvVariance_, // true or false
692 arg.resultSaveMean_,
694 };
695
696 return (avg_time);
697 };
698
// Type-erased entry point: down-casts pArg to Argument and forwards to the
// Run(const Argument&, const StreamConfig&) overload, returning its result.
// The dynamic_cast result is dereferenced without a null check, so pArg must
// really point to an Argument of this class.
699 float Run(const BaseArgument* pArg,
700 const StreamConfig& stream_config = StreamConfig{}) override
701 {
702 return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
703 };
704 };
705
706 bool IsSupportedArgument(const BaseArgument* pArg) override
707 {
708 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
709
710 if constexpr(XSrcYDstVectorDim == 0)
711 {
712 if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
713 pArg_->yStrides_[NumInvariantDim - 1] != 1)
714 return false;
715
716 if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
717 pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
718 return false;
719 }
720 else
721 {
722 if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
723 return false;
724
725 if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
726 pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
727 return false;
728 };
729
730 if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
731 return false;
732 if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
733 return false;
734
735 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
736 return false;
737 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
738 return false;
739
740 if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
741 return false;
742
743 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
744 return false;
745
746 bool is_valid = true;
747
749 if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
750 is_valid = false;
751 });
752
753 if(!is_valid)
754 return false;
755
756 return true;
757 };
758
// Builds a heap-allocated Argument from type-erased inputs and returns it as a
// std::unique_ptr<BaseArgument>.
// The lengths/strides arrays, reduceDims, epsilon, averageFactor and
// y_elementwise_op are passed through unchanged. The void pointers are
// static_cast to the element types named by the template parameters:
// p_x -> XDataType, p_scale -> ScaleDataType, p_bias -> BiasDataType,
// p_y -> YDataType, and the four result pointers -> MeanVarDataType.
// No validation is performed here.
759 std::unique_ptr<BaseArgument> MakeArgumentPointer(
760 const std::array<index_t, Rank> xyLengths,
761 const std::array<index_t, Rank> xStrides,
762 const std::array<index_t, Rank> yStrides,
763 const std::array<int, NumBatchNormReduceDim> reduceDims,
764 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
765 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
766 const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
767 const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
768 const void* p_x,
769 const void* p_scale,
770 const void* p_bias,
771 double epsilon,
772 const YElementwiseOp y_elementwise_op,
773 void* p_y,
774 void* resultSaveMean,
775 void* resultSaveInvVariance,
776 double averageFactor,
777 void* resultRunningMean,
778 void* resultRunningVariance) override
779 {
// Note the ordering: this function receives (epsilon, y_elementwise_op), but the
// Argument constructor takes (y_elementwise_op, epsilon), so the two are swapped
// in the call below.
780 return std::make_unique<Argument>(xyLengths,
781 xStrides,
782 yStrides,
783 reduceDims,
784 bnScaleBiasMeanVarLengths,
785 bnScaleStrides,
786 bnBiasStrides,
787 bnMeanVarStrides,
788 static_cast<const XDataType*>(p_x),
789 static_cast<const ScaleDataType*>(p_scale),
790 static_cast<const BiasDataType*>(p_bias),
791 y_elementwise_op,
792 epsilon,
793 static_cast<YDataType*>(p_y),
794 static_cast<MeanVarDataType*>(resultSaveMean),
795 static_cast<MeanVarDataType*>(resultSaveInvVariance),
796 averageFactor,
797 static_cast<MeanVarDataType*>(resultRunningMean),
798 static_cast<MeanVarDataType*>(resultRunningVariance));
799 };
800
// Returns a newly allocated, default-constructed Invoker as a
// std::unique_ptr<BaseInvoker>.
801 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
802 {
803 return std::make_unique<Invoker>();
804 };
805
// Returns a string describing this instantiation, built from its template
// parameters in this order: BlockSize, M thread cluster/slice sizes,
// K thread cluster/slice sizes, XSrcYDstVectorDim, then the vector sizes for
// X, scale, bias, mean/var and Y. The text starts with "DeviceBatchNormFwdImpl<"
// and ends with ">".
806 std::string GetTypeString() const override
807 {
808 auto str = std::stringstream();
809
810 // clang-format off
811 str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
812 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
813 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
814 str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
815 str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
816 // clang-format on
817
818 return str.str();
819 }
820};
821
822} // namespace device
823} // namespace tensor_operation
824} // namespace ck
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__global__ void kernel_multiblock_welford_first_half(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x, MeanVarDataType *const p_welford_mean, MeanVarDataType *const p_welford_variance, int32_t *const p_welford_count)
Definition gridwise_multiblock_welford_first_half.hpp:21
int32_t index_t
Definition ck.hpp:299
__global__ void kernel_welford_second_half_batchnorm_forward_final(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:27
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_multiblock_batchnorm_forward(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, MeanVarDataType *const __restrict__ p_welford_mean, MeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count, int32_t *const __restrict__ p_control, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_multiblock_batchnorm_forward.hpp:31
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__global__ void kernel_batchnorm_forward_with_blockwise_welford(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:27
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:94
Definition gridwise_multiblock_batchnorm_forward.hpp:112
Definition gridwise_multiblock_welford_first_half.hpp:55
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:102
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batchnorm_forward.hpp:26
Definition device_batchnorm_forward_impl.hpp:190
MeanVarDataType * resultRunningMean_
Definition device_batchnorm_forward_impl.hpp:305
long_index_t reduce_length_
Definition device_batchnorm_forward_impl.hpp:309
const ScaleDataType * p_scale_
Definition device_batchnorm_forward_impl.hpp:297
bool updateMovingAverage_
Definition device_batchnorm_forward_impl.hpp:284
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_
Definition device_batchnorm_forward_impl.hpp:317
std::array< index_t, Rank > xStrides_
Definition device_batchnorm_forward_impl.hpp:288
const XDataType * p_x_
Definition device_batchnorm_forward_impl.hpp:296
std::array< index_t, Rank > xyLengths_
Definition device_batchnorm_forward_impl.hpp:287
int blkGroupSize_
Definition device_batchnorm_forward_impl.hpp:311
XYGridDesc_M_K x_grid_desc_m_k_
Definition device_batchnorm_forward_impl.hpp:315
XYGridDesc_M_K y_grid_desc_m_k_
Definition device_batchnorm_forward_impl.hpp:316
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_
Definition device_batchnorm_forward_impl.hpp:318
bool saveMeanInvVariance_
Definition device_batchnorm_forward_impl.hpp:285
long_index_t invariant_length_
Definition device_batchnorm_forward_impl.hpp:308
MeanVarDataType * resultRunningVariance_
Definition device_batchnorm_forward_impl.hpp:306
Argument(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > yStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides, const XDataType *p_x, const ScaleDataType *p_scale, const BiasDataType *p_bias, const YElementwiseOp y_elementwise_op, double epsilon, YDataType *p_y, MeanVarDataType *resultSaveMean, MeanVarDataType *resultSaveInvVariance, double averageFactor, MeanVarDataType *resultRunningMean, MeanVarDataType *resultRunningVariance)
Definition device_batchnorm_forward_impl.hpp:191
AccDataType averageFactor_
Definition device_batchnorm_forward_impl.hpp:282
const BiasDataType * p_bias_
Definition device_batchnorm_forward_impl.hpp:298
AccDataType epsilon_
Definition device_batchnorm_forward_impl.hpp:281
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_
Definition device_batchnorm_forward_impl.hpp:319
std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides_
Definition device_batchnorm_forward_impl.hpp:293
int numBlockTileIteration_
Definition device_batchnorm_forward_impl.hpp:312
void * workspace_count_
Definition device_batchnorm_forward_impl.hpp:323
const YElementwiseOp y_elementwise_op_
Definition device_batchnorm_forward_impl.hpp:299
void * workspace_mean_
Definition device_batchnorm_forward_impl.hpp:321
std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides_
Definition device_batchnorm_forward_impl.hpp:294
void * control_
Definition device_batchnorm_forward_impl.hpp:325
MeanVarDataType * resultSaveMean_
Definition device_batchnorm_forward_impl.hpp:302
YDataType * p_y_
Definition device_batchnorm_forward_impl.hpp:300
size_t gridSize_
Definition device_batchnorm_forward_impl.hpp:313
MeanVarDataType * resultSaveInvVariance_
Definition device_batchnorm_forward_impl.hpp:303
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths_
Definition device_batchnorm_forward_impl.hpp:291
void * workspace_variance_
Definition device_batchnorm_forward_impl.hpp:322
std::array< index_t, Rank > yStrides_
Definition device_batchnorm_forward_impl.hpp:289
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides_
Definition device_batchnorm_forward_impl.hpp:292
Definition device_batchnorm_forward_impl.hpp:403
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batchnorm_forward_impl.hpp:404
float Run(const BaseArgument *pArg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batchnorm_forward_impl.hpp:699
Definition device_batchnorm_forward_impl.hpp:56
static auto MakeXY2dDescriptor(const std::array< index_t, Rank > &xyLengths, const std::array< index_t, Rank > &xyStrides, int blkGroupSize, int numBlockTileIteration)
Definition device_batchnorm_forward_impl.hpp:70
bool IsSupportedArgument(const BaseArgument *pArg) override
Definition device_batchnorm_forward_impl.hpp:706
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batchnorm_forward_impl.hpp:801
static constexpr index_t K_BlockTileSize
Definition device_batchnorm_forward_impl.hpp:68
std::string GetTypeString() const override
Definition device_batchnorm_forward_impl.hpp:806
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_batchnorm_forward_impl.hpp:357
decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)) XYGridDesc_M_K
Definition device_batchnorm_forward_impl.hpp:186
static constexpr index_t M_BlockTileSize
Definition device_batchnorm_forward_impl.hpp:67
decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})) ScaleBiasMeanVarGridDesc_M
Definition device_batchnorm_forward_impl.hpp:187
static auto MakeScaleBiasMeanVar1dDescriptor(const std::array< index_t, NumInvariantDim > &lengths, const std::array< index_t, NumInvariantDim > &strides)
Definition device_batchnorm_forward_impl.hpp:157
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > yStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides, const void *p_x, const void *p_scale, const void *p_bias, double epsilon, const YElementwiseOp y_elementwise_op, void *p_y, void *resultSaveMean, void *resultSaveInvVariance, double averageFactor, void *resultRunningMean, void *resultRunningVariance) override
Definition device_batchnorm_forward_impl.hpp:759
static constexpr index_t NumInvariantDim
Definition device_batchnorm_forward_impl.hpp:65
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_batchnorm_forward_impl.hpp:328
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
Definition device_batchnorm_forward_impl.hpp:135
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
Definition device_batchnorm_forward_impl.hpp:117