device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp Source File

device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp Source File#

Composable Kernel: device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp Source File
device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <iostream>
#include <sstream>
#include <stdexcept>
8
22
23namespace ck {
24namespace tensor_operation {
25namespace device {
26
27// Conv backward data multiple D:
28// input : output image A: [G, N, K, Ho, Wo]
29// input : weight B: [G, K, C, Y, X],
30// input : D0, D1, ... : [G, N, C, Hi, Wi] (same layout as E)
31// output : input image E: [G, N, C, Hi, Wi]
32// C = a_op(A) * b_op(B)
33// E = cde_op(C, D0, D1, ...)
34template <index_t NDimSpatial,
35 typename ALayout, // output image
36 typename BLayout, // weight
37 typename DsLayout, // bias
38 typename ELayout, // input image
39 typename ADataType, // output image
40 typename BDataType, // weight
41 typename AccDataType,
42 typename CShuffleDataType,
43 typename DsDataType, // bias
44 typename EDataType, // input image
45 typename AElementwiseOp, // output image
46 typename BElementwiseOp, // weight
47 typename CDEElementwiseOp, // C, bias, and input image
48 ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
49 ck::index_t BlockSize,
50 ck::index_t MPerBlock,
51 ck::index_t NPerBlock,
52 ck::index_t K0PerBlock,
53 ck::index_t K1,
54 ck::index_t MPerWMMA,
55 ck::index_t NPerWMMA,
56 ck::index_t MRepeat,
57 ck::index_t NRepeat,
58 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
59 typename ABlockTransferThreadClusterArrangeOrder,
60 typename ABlockTransferSrcAccessOrder,
61 index_t ABlockTransferSrcVectorDim,
62 index_t ABlockTransferSrcScalarPerVector,
63 index_t ABlockTransferDstScalarPerVector_AK1,
64 bool ABlockLdsExtraM,
65 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
66 typename BBlockTransferThreadClusterArrangeOrder,
67 typename BBlockTransferSrcAccessOrder,
68 index_t BBlockTransferSrcVectorDim,
69 index_t BBlockTransferSrcScalarPerVector,
70 index_t BBlockTransferDstScalarPerVector_BK1,
71 bool BBlockLdsExtraN,
72 index_t CShuffleMRepeatPerShuffle,
73 index_t CShuffleNRepeatPerShuffle,
74 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
75 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
76 index_t NumGemmKPrefetchStage = 1,
80 : public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
81 ALayout, // output image
82 BLayout, // weight
83 DsLayout, // bias
84 ELayout, // input image
85 ADataType, // output image
86 BDataType, // weight
87 DsDataType, // bias
88 EDataType, // input image
89 AElementwiseOp,
90 BElementwiseOp,
91 CDEElementwiseOp>
92{
93 // TODO: Extend support for more spatial dimensions.
94 static_assert(NDimSpatial == 2 || NDimSpatial == 3,
95 "wrong! only implemented for 2D and 3D now");
96
98
99 static constexpr index_t NumDTensor = DsDataType::Size();
100
101 // TODO: Add support for different A and B data types.
102 using ABDataType = ADataType;
103
104 static constexpr auto I0 = Number<0>{};
105 static constexpr auto I1 = Number<1>{};
106 static constexpr auto I2 = Number<2>{};
107 static constexpr auto I3 = Number<3>{};
108 static constexpr index_t KPerBlock = K0PerBlock * K1;
109
111 ConvBackwardDataSpecialization,
112 K1,
113 K1,
114 MPerBlock,
115 NPerBlock,
116 KPerBlock,
117 true /* DoPadGemmM */,
118 true /* DoPadGemmN */,
119 ALayout,
120 BLayout,
121 ELayout>;
122
123 static auto
125 {
126 const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
127 const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
128 const auto ds_grid_desc_m_n =
129 generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
131 const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
132 return make_tuple(
133 a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
134 }
135
136 // desc
139
144
145 // GridwiseGemm
147 // DataType Family
148 ADataType,
149 BDataType,
150 AccDataType,
151 CShuffleDataType,
152 DsDataType,
153 EDataType,
154 // InMemory Data Descriptor
159 // ElementwiseOp Family
160 AElementwiseOp,
161 BElementwiseOp,
162 CDEElementwiseOp,
164 // Tiling Family
165 MPerBlock,
166 NPerBlock,
167 KPerBlock,
168 MPerWMMA,
169 NPerWMMA,
170 K1,
171 MRepeat,
172 NRepeat,
173 // ThreadCluster Family
174 BlockSize,
175 ABlockTransferThreadClusterLengths_AK0_M_AK1,
176 ABlockTransferThreadClusterArrangeOrder,
177 ABlockTransferSrcAccessOrder,
178 ABlockTransferSrcVectorDim,
179 ABlockTransferSrcScalarPerVector,
180 ABlockTransferDstScalarPerVector_AK1,
181 false,
182 true,
183 ABlockLdsExtraM,
184 BBlockTransferThreadClusterLengths_BK0_N_BK1,
185 BBlockTransferThreadClusterArrangeOrder,
186 BBlockTransferSrcAccessOrder,
187 BBlockTransferSrcVectorDim,
188 BBlockTransferSrcScalarPerVector,
189 BBlockTransferDstScalarPerVector_BK1,
190 false,
191 true,
192 BBlockLdsExtraN,
193 CShuffleMRepeatPerShuffle,
194 CShuffleNRepeatPerShuffle,
195 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
196 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
197 NumGemmKPrefetchStage,
198 LoopSched,
199 PipelineVer>;
200
203 DsGridDesc_M_N{}));
206 EGridDesc_M_N{}));
207
208 // Argument
209 struct Argument : public BaseArgument
210 {
211 Argument(const void* p_a, // output image
212 const void* p_b, // weight
213 const std::array<const void*, NumDTensor>& p_ds, // bias
214 void* p_e, // input image
215 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
216 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
217 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
218 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
219 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
220 /*ds_g_n_c_wis_lengths*/,
221 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
222 ds_g_n_c_wis_strides,
223 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
224 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides,
225 const std::array<index_t, NDimSpatial>& conv_filter_strides,
226 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
227 const std::array<index_t, NDimSpatial>& input_left_pads,
228 const std::array<index_t, NDimSpatial>& input_right_pads,
229 const AElementwiseOp& a_element_op,
230 const BElementwiseOp& b_element_op,
231 const CDEElementwiseOp& cde_element_op,
232 const ck::index_t split_k = 1)
233 : p_a_grid_{static_cast<const ADataType*>(p_a)},
234 p_b_grid_{static_cast<const BDataType*>(p_b)},
235 p_ds_grid_{},
236 p_e_grid_{static_cast<EDataType*>(p_e)},
237 num_group_{a_g_n_k_wos_lengths[0]},
238 a_element_op_{a_element_op},
239 b_element_op_{b_element_op},
240 cde_element_op_{cde_element_op},
241 a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
242 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
243 conv_filter_strides_{conv_filter_strides},
244 input_left_pads_{input_left_pads},
245 input_right_pads_{input_right_pads},
246 k_batch_{split_k}
247 {
248 bool image_covered_dilation = true;
249 bool image_covered_strides = true;
250 for(index_t d = 0; d < NDimSpatial; d++)
251 {
252 // If dilation and stride is not equal to the we will have some empty places
253 image_covered_dilation &=
254 conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
255 // If stride is larger than windows size then we will have some empty places
256 image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
257 }
258 bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
261 e_g_n_c_wis_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
262 sizeof(EDataType);
263
264 // populate Ds pointer
265 static_for<0, NumDTensor, 1>{}([&](auto i) {
266 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
267
268 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
269 });
270
271 // A/B/Ds/E Batch Stride
272 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
273 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
274 compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
275
276 static_for<0, NumDTensor, 1>{}([&](auto i) {
277 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
278 });
279
280 static constexpr auto NonSpatialDimsNum = Number<3>{};
281
282 static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
283 static constexpr auto HIdx =
285 static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
287
288 static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
289 static constexpr auto YIdx =
291 static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
293
294 // problem definition
295 const index_t Z = b_g_k_c_xs_lengths[ZIdx];
296 const index_t Y = b_g_k_c_xs_lengths[YIdx];
297 const index_t X = b_g_k_c_xs_lengths[XIdx];
298
299 const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
300 const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
301 const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
302
303 const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
304 const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
305 const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
306
307 const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
308 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
309 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
310
311 const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
312 const auto YTilde = ConvStrideH / GcdStrideDilationH;
313 const auto XTilde = ConvStrideW / GcdStrideDilationW;
314
315 for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
316 {
317
318 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
319 {
320 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
321 {
322 // check slice is valid
323 const auto ZDotSlice =
324 NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
325 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
326 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
327
328 if(YDotSlice * XDotSlice * ZDotSlice <= 0)
329 {
330 continue;
331 }
332
333 std::array<index_t, NDimSpatial> tildes;
334 if constexpr(NDimSpatial == 2)
335 {
336 tildes = {i_ytilde, i_xtilde};
337 }
338 else if constexpr(NDimSpatial == 3)
339 {
340 tildes = {i_ztilde, i_ytilde, i_xtilde};
341 }
342
343 ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
344 a_g_n_k_wos_strides,
345 b_g_k_c_xs_lengths,
346 b_g_k_c_xs_strides,
347 e_g_n_c_wis_lengths,
348 e_g_n_c_wis_strides,
349 conv_filter_strides,
350 conv_filter_dilations,
351 input_left_pads,
352 input_right_pads,
353 tildes};
354
355 const auto a_grid_desc_ak0_m_ak1 =
356 conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
357
358 const auto b_grid_desc_bk0_n_bk1 =
359 conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
360
361 DsGridDesc_M_N ds_grid_desc_m_n;
362
363 // populate Ds desc
364 static_for<0, NumDTensor, 1>{}([&](auto i) {
365 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
366 static_assert(is_same_v<DLayout, ELayout>);
367 ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
368 a_g_n_k_wos_lengths,
369 a_g_n_k_wos_strides,
370 b_g_k_c_xs_lengths,
371 b_g_k_c_xs_strides,
372 e_g_n_c_wis_lengths,
373 ds_g_n_c_wis_strides[i],
374 conv_filter_strides,
375 conv_filter_dilations,
376 input_left_pads,
377 input_right_pads,
378 tildes};
379
380 ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
381 });
382
383 const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
384
385 // for check validity
386 ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
387 e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
388
389 // desc for blockwise copy
390 a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
391 b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
392
393 // block-to-e-tile-map
394 auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(
395 e_grid_desc_m_n, 1 /* M01 */, 1 /* N01 */);
396
397 block_2_ctile_map_container_.push_back(block_2_ctile_map);
398
401 ds_grid_desc_m_n));
404 e_grid_desc_m_n));
405 }
406 }
407 }
408 }
409
410 void Print() const
411 {
412 for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
413 {
414 std::cout << "a_grid_desc_ak0_m_ak1_container_"
415 << a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
416
417 std::cout << "b_grid_desc_bk0_n_bk1_container_"
418 << b_grid_desc_bk0_n_bk1_container_[i] << std::endl;
419
420 static_for<0, NumDTensor, 1>{}([&](auto j) {
421 std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
423 << std::endl;
424 });
425
426 std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
428 << std::endl;
429 }
430 }
431
432 // pointers
433 const ADataType* p_a_grid_;
434 const BDataType* p_b_grid_;
436 EDataType* p_e_grid_;
437
438 // tensor descriptor for problem definition
440 std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
441 std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;
442
443 // tensor descriptor for block-wise copy
444 std::vector<AGridDesc_AK0_M_AK1> a_grid_desc_ak0_m_ak1_container_;
445 std::vector<BGridDesc_BK0_N_BK1> b_grid_desc_bk0_n_bk1_container_;
446 std::vector<DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
448 std::vector<EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
450
451 // block-to-e-tile map
452 std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
453
454 // for computing batch offset
455 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
456
457 // element-wise op
458 AElementwiseOp a_element_op_;
459 BElementwiseOp b_element_op_;
460 CDEElementwiseOp cde_element_op_;
461
462 std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
463 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
464 std::array<index_t, NDimSpatial> conv_filter_strides_;
465 std::array<index_t, NDimSpatial> input_left_pads_;
466 std::array<index_t, NDimSpatial> input_right_pads_;
467
471 };
472
473 // Invoker
474 struct Invoker : public BaseInvoker
475 {
477
478 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
479 {
480 if(stream_config.log_level_ > 0)
481 {
482 arg.Print();
483 }
484
485 float ave_time = 0;
486
487 for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
488 {
489 const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
491 arg.num_group_;
492
493 const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
494 arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
495
496 const auto clear_workspace = [&]() {
497 if(arg.bwd_needs_zero_out && i == 0)
498 {
499 hip_check_error(hipMemsetAsync(
500 arg.p_e_grid_, 0, arg.e_space_size_bytes, stream_config.stream_id_));
501 }
502 };
503
504 auto launch_kernel = [&](auto has_main_k_block_loop) {
505 constexpr bool has_main_loop = has_main_k_block_loop.value;
506
508 GridwiseGemm,
509 ADataType,
510 BDataType,
512 EDataType,
513 AElementwiseOp,
514 BElementwiseOp,
515 CDEElementwiseOp,
521 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
522 has_main_loop>;
523
525 stream_config,
526 clear_workspace,
527 kernel,
528 dim3(grid_size),
529 dim3(BlockSize),
530 0,
531 arg.p_a_grid_,
532 arg.p_b_grid_,
533 arg.p_ds_grid_,
534 arg.p_e_grid_,
535 arg.a_element_op_,
536 arg.b_element_op_,
537 arg.cde_element_op_,
538 arg.a_g_n_k_wos_lengths_[0], // Group count
545 };
546
548 {
549 ave_time += launch_kernel(integral_constant<bool, true>{});
550 }
551 else
552 {
553 ave_time += launch_kernel(integral_constant<bool, false>{});
554 }
555 }
556
557 return ave_time;
558 }
559
560 float Run(const BaseArgument* p_arg,
561 const StreamConfig& stream_config = StreamConfig{}) override
562 {
563 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
564 }
565 };
566
567 static bool IsSupportedArgument(const Argument& arg)
568 {
569 if(arg.k_batch_ != 1)
570 {
571 return false;
572 }
573
574 // check device
576 {
578 {
579 return false;
580 }
581 }
582 else
583 {
584 return false;
585 }
586
587 const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
588 const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
589
590 // Specialization
591 if constexpr(ConvBackwardDataSpecialization ==
593 {
594 // check if it's a 1x1 convolution with stride=1 and no padding
595 for(int i = 0; i < NDimSpatial; i++)
596 {
597 if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 &&
598 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
599 {
600 return false;
601 }
602 }
603 }
604
605 // vector load for A matrix from global memory to LDS
610 {
611 if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
612 {
613 return false;
614 }
615 }
616 else
617 {
618 return false;
619 }
620
621 // vector load for B matrix from global memory to LDS
624 {
625 if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
626 {
627 return false;
628 }
629 }
630 else
631 {
632 return false;
633 }
634
635 // vector store for Ds
636 bool ds_valid = true;
637
638 static_for<0, NumDTensor, 1>{}([&](auto i) {
639 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
640
648 {
649 // vector load D matrix from global memory
650 if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
651 {
652 ds_valid = false;
653 }
654 }
655 else
656 {
657 ds_valid = false;
658 }
659 });
660
661 if(!ds_valid)
662 {
663 return false;
664 }
665
666 // vector store for E
671 {
672 // vector store C matrix into global memory
673 if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
674 {
675 return false;
676 }
677 }
678 else
679 {
680 return false;
681 }
682
683 // Gridwise GEMM size
684 for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
685 {
691 {
692 return false;
693 }
694 }
695
696 // check number of dimension, only implemented for 2D and 3D now
697 if(NDimSpatial != 2 && NDimSpatial != 3)
698 {
699 return false;
700 }
701
702 return true;
703 }
704
705 bool IsSupportedArgument(const BaseArgument* p_arg) override
706 {
707 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
708 }
709
710 static auto
711 MakeArgument(const void* p_a, // output image
712 const void* p_b, // weight
713 const std::array<const void*, NumDTensor>& p_ds, // bias
714 void* p_e, // input image
715 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
716 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
717 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
718 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
719 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
720 ds_g_n_c_wis_lengths, // bias
721 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
722 ds_g_n_c_wis_strides, // bias
723 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
724 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
725 const std::array<index_t, NDimSpatial>& conv_filter_strides,
726 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
727 const std::array<index_t, NDimSpatial>& input_left_pads,
728 const std::array<index_t, NDimSpatial>& input_right_pads,
729 const AElementwiseOp& a_element_op,
730 const BElementwiseOp& b_element_op,
731 const CDEElementwiseOp& cde_element_op,
732 const ck::index_t split_k = 1)
733 {
734 return Argument{p_a,
735 p_b,
736 p_ds,
737 p_e,
738 a_g_n_k_wos_lengths,
739 a_g_n_k_wos_strides,
740 b_g_k_c_xs_lengths,
741 b_g_k_c_xs_strides,
742 ds_g_n_c_wis_lengths,
743 ds_g_n_c_wis_strides,
744 e_g_n_c_wis_lengths,
745 e_g_n_c_wis_strides,
746 conv_filter_strides,
747 conv_filter_dilations,
748 input_left_pads,
749 input_right_pads,
750 a_element_op,
751 b_element_op,
752 cde_element_op,
753 split_k};
754 }
755
756 static auto MakeInvoker() { return Invoker{}; }
757
758 std::unique_ptr<BaseArgument> MakeArgumentPointer(
759 const void* p_a, // output image
760 const void* p_b, // weight
761 const std::array<const void*, NumDTensor>& p_ds, // bias
762 void* p_e, // input image
763 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
764 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
765 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
766 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
767 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
768 ds_g_n_c_wis_lengths, // bias
769 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
770 ds_g_n_c_wis_strides, // bias
771 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
772 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
773 const std::array<index_t, NDimSpatial>& conv_filter_strides,
774 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
775 const std::array<index_t, NDimSpatial>& input_left_pads,
776 const std::array<index_t, NDimSpatial>& input_right_pads,
777 const AElementwiseOp& a_element_op,
778 const BElementwiseOp& b_element_op,
779 const CDEElementwiseOp& cde_element_op,
780 const ck::index_t split_k = 1) override
781 {
782 return std::make_unique<Argument>(p_a,
783 p_b,
784 p_ds,
785 p_e,
786 a_g_n_k_wos_lengths,
787 a_g_n_k_wos_strides,
788 b_g_k_c_xs_lengths,
789 b_g_k_c_xs_strides,
790 ds_g_n_c_wis_lengths,
791 ds_g_n_c_wis_strides,
792 e_g_n_c_wis_lengths,
793 e_g_n_c_wis_strides,
794 conv_filter_strides,
795 conv_filter_dilations,
796 input_left_pads,
797 input_right_pads,
798 a_element_op,
799 b_element_op,
800 cde_element_op,
801 split_k);
802 }
803
804 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
805 {
806 return std::make_unique<Invoker>(Invoker{});
807 }
808
809 std::string GetTypeString() const override
810 {
811 auto str = std::stringstream();
812
813 // clang-format off
814 str << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"
815 << "<"
816 << BlockSize << ", "
817 << MPerBlock << ", "
818 << NPerBlock << ", "
819 << KPerBlock << ", "
820 << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
821 << K1 << ", "
822 << ABlockTransferSrcScalarPerVector << ", "
823 << BBlockTransferSrcScalarPerVector
824 << ">";
825 // clang-format on
826
827 return str.str();
828 }
829};
830
831} // namespace device
832} // namespace tensor_operation
833} // namespace ck
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecialization &s)
Definition convolution_backward_data_specialization.hpp:17
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_backward_data_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_grouped_conv_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc, const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:40
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
Definition functional2.hpp:33
Definition transform_conv_bwd_data_to_gemm_v1.hpp:44
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:659
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:943
__host__ __device__ auto MakeCDescriptor_M_N() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:1150
Definition device_base.hpp:197
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:210
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:465
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:466
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:464
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1)
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:211
std::vector< EGridDesc_M_N > e_grid_desc_m_n_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:441
void Print() const
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:410
std::vector< AGridDesc_AK0_M_AK1 > a_grid_desc_ak0_m_ak1_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:444
long_index_t e_space_size_bytes
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:470
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:455
bool bwd_needs_zero_out
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:469
index_t num_group_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:439
std::vector< DsGridDesc_M_N > ds_grid_desc_m_n_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:440
BElementwiseOp b_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:459
std::vector< BGridDesc_BK0_N_BK1 > b_grid_desc_bk0_n_bk1_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:445
std::vector< DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock > ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:447
std::vector< typename GridwiseGemm::DefaultBlock2CTileMap > block_2_ctile_map_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:452
CDEElementwiseOp cde_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:460
const index_t k_batch_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:468
std::array< index_t, NDimSpatial+3 > a_g_n_k_wos_lengths_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:462
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:433
AElementwiseOp a_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:458
EDataType * p_e_grid_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:436
GridwiseGemm::DsGridPointer p_ds_grid_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:435
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:434
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:463
std::vector< EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock > e_grid_desc_mblock_mperblock_nblock_nperblock_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:449
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:475
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:478
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:560
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:476
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:92
remove_cvref_t< tuple_element_t< 3, ABDsEGridDesc > > EGridDesc_M_N
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:143
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:705
static auto MakeInvoker()
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:756
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle DeviceOp
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:97
remove_cvref_t< tuple_element_t< 1, ABDsEGridDesc > > BGridDesc_BK0_N_BK1
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:141
static constexpr index_t NumDTensor
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:99
static constexpr auto I1
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:105
static constexpr auto I0
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:104
decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)) ABDsEGridDesc
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:138
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1)
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:711
static constexpr auto I3
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:107
remove_cvref_t< tuple_element_t< 2, ABDsEGridDesc > > DsGridDesc_M_N
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:142
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:809
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:804
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWMMA, NPerWMMA, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, true, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumGemmKPrefetchStage, LoopSched, PipelineVer > GridwiseGemm
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:146
ADataType ABDataType
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:102
static constexpr ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:137
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{})) EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:204
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1) override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:758
remove_cvref_t< tuple_element_t< 0, ABDsEGridDesc > > AGridDesc_AK0_M_AK1
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:140
TransformConvBwdDataToGemm_v1< NDimSpatial, ConvBackwardDataSpecialization, K1, K1, MPerBlock, NPerBlock, KPerBlock, true, true, ALayout, BLayout, ELayout > ConvToGemmBwdDataTransform
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:110
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:567
static constexpr auto I2
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:106
static auto GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform &conv_to_gemm_transform)
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:124
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{})) DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:201
static constexpr index_t KPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:108
Definition device_grouped_conv_bwd_data_multiple_d.hpp:36