device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp Source File

device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp Source File

Composable Kernel: device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp Source File
device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
25template <ck::index_t NDimSpatial,
26 typename InDataType,
27 typename WeiDataType,
28 typename OutDataType,
29 typename AccDataType,
30 typename InElementwiseOperation,
31 typename WeiElementwiseOperation,
32 typename OutElementwiseOperation,
33 ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
34 ck::index_t BlockSize,
35 ck::index_t MPerBlock,
36 ck::index_t NPerBlock,
37 ck::index_t K0PerBlock,
38 ck::index_t K1,
39 ck::index_t MPerXDL,
40 ck::index_t NPerXDL,
41 ck::index_t MXdlPerWave,
42 ck::index_t NXdlPerWave,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 ck::index_t ABlockTransferSrcVectorDim,
47 ck::index_t ABlockTransferSrcScalarPerVector,
48 ck::index_t ABlockTransferDstScalarPerVector_K1,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
53 ck::index_t BBlockTransferSrcVectorDim,
54 ck::index_t BBlockTransferSrcScalarPerVector,
55 ck::index_t BBlockTransferDstScalarPerVector_K1,
56 bool BBlockLdsAddExtraN,
57 ck::index_t CThreadTransferSrcDstVectorDim,
58 ck::index_t CThreadTransferDstScalarPerVector>
60 : public DeviceConvBwdData<
61 NDimSpatial,
62 ck::tuple_element_t<NDimSpatial - 1,
63 ck::Tuple<ck::tensor_layout::convolution::NWC,
64 ck::tensor_layout::convolution::NHWC,
65 ck::tensor_layout::convolution::NDHWC>>,
66 ck::tuple_element_t<NDimSpatial - 1,
67 ck::Tuple<ck::tensor_layout::convolution::KXC,
68 ck::tensor_layout::convolution::KYXC,
69 ck::tensor_layout::convolution::KZYXC>>,
70 ck::tuple_element_t<NDimSpatial - 1,
71 ck::Tuple<ck::tensor_layout::convolution::NWK,
72 ck::tensor_layout::convolution::NHWK,
73 ck::tensor_layout::convolution::NDHWK>>,
74 InDataType,
75 WeiDataType,
76 OutDataType,
77 InElementwiseOperation,
78 WeiElementwiseOperation,
79 OutElementwiseOperation>
80{
82
// Two values obtained from GetNXdlPerWave with a true / false template argument.
// NOTE(review): GetNXdlPerWave is not declared in the visible listing; judging by the
// 64 / 32 suffixes these are presumably the 64-lane and 32-lane wave variants — confirm.
84 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
85 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
86
// GEMM operand roles: A is the output tensor, B the weight tensor, C the input tensor.
87 using ADataType = OutDataType;
88 using BDataType = WeiDataType;
89 using CDataType = InDataType;
90
91 // TODO make A/B datatype different
// NOTE(review): the shared A/B element type is taken from InDataType, not from
// OutDataType / WeiDataType (the ADataType / BDataType aliases above) — confirm this
// is intended whenever the three data types differ.
92 using ABDataType = InDataType;
93
// Compile-time index constants used to address tuple / sequence elements.
94 static constexpr auto I0 = Number<0>{};
95 static constexpr auto I1 = Number<1>{};
96 static constexpr auto I2 = Number<2>{};
97 static constexpr auto I3 = Number<3>{};
98 static constexpr auto I4 = Number<4>{};
99 static constexpr auto I5 = Number<5>{};
100 static constexpr auto I6 = Number<6>{};
101 static constexpr auto I7 = Number<7>{};
102
// Compile-time checks relating the source vector width to the thread-cluster lengths.
// NOTE(review): the A-side check combines K1 and the cluster length with `%`, whereas
// the B-side check below combines NPerBlock and the cluster length with `/` — confirm
// that the asymmetry is deliberate.
103 static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) %
104 ABlockTransferSrcScalarPerVector ==
105 0);
106 static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) %
107 BBlockTransferSrcScalarPerVector ==
108 0);
109
// K1 wrapped as a compile-time Number; GemmK1Number is an alias of the same value.
110 static constexpr auto K1Number = Number<K1>{};
111 static constexpr auto GemmK1Number = K1Number;
112
// Builds the GEMM-view tensor descriptors for the 1-D (single spatial axis) case; this
// overload participates only when NDim == 1.
//
// NOTE(review): this text comes from a rendered source listing whose embedded line
// numbers skip (e.g. 115, 144, 148-152). The skipped lines — including the one that
// carries the function name and its first parameter — are not visible here, so the
// comments below describe only what the remaining lines show.
//
// Visible parameters:
//   K, C                   - K is split as K0 = K / K1; C is the last length of the packed
//                            input descriptor
//   *_spatial_lengths      - only element [0] (the W axis) is read
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads
//                          - only element [0] is read
//   tildes                 - tildes[0] is the XTilde index (i_xtilde) this GEMM is built for
//
// Returns make_tuple(A, B, C): A is the output-tensor descriptor, B the weight-tensor
// descriptor and C the input-tensor descriptor (see the "A:/B/C:" comments in the body).
113 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
114 static auto
116 ck::index_t K,
117 ck::index_t C,
118 std::vector<ck::index_t> input_spatial_lengths,
119 std::vector<ck::index_t> filter_spatial_lengths,
120 std::vector<ck::index_t> output_spatial_lengths,
121 std::vector<ck::index_t> conv_filter_strides,
122 std::vector<ck::index_t> conv_filter_dilations,
123 std::vector<ck::index_t> input_left_pads,
124 std::vector<ck::index_t> input_right_pads,
125 std::vector<ck::index_t> tildes)
126 {
127 using namespace ck;
128
// Which XTilde sub-problem this call describes.
129 index_t i_xtilde = tildes[0];
130
131 const index_t Wi = input_spatial_lengths[0];
132 const index_t Wo = output_spatial_lengths[0];
133 const index_t X = filter_spatial_lengths[0];
134 const index_t InLeftPadW = input_left_pads[0];
135 const index_t InRightPadW = input_right_pads[0];
136 const index_t ConvStrideW = conv_filter_strides[0];
137 const index_t ConvDilationW = conv_filter_dilations[0];
138
// Integer division: any remainder of K / K1 is discarded here.
139 const auto K0 = K / K1;
140
141 const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
142
// Compile-time branch on the backward-data specialization.
// NOTE(review): the enumerator compared against is on listing line 144, which is missing;
// the (I1, Wo) embed below suggests a filter length of 1 — confirm.
143 if constexpr(ConvBackwardDataSpecialization ==
145 {
146 // A: output tensor
147 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
153
154 // B: weight tensor
155 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
161
162 // C: input tensor
163 const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
164 in_n_wi_c_grid_desc,
// Embeds (I1, Wo) into the W axis with coefficients (I1, ConvStrideW).
166 make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
170
171 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
172 in_n_x_wo_c_grid_desc,
178
179 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
180 wei_gemmk0_gemmn_gemmk1_grid_desc,
181 in_gemmm_gemmn_grid_desc);
182 }
183 else
184 {
185 const auto out_n_wo_k_grid_desc =
187 const auto wei_k_x_c_grid_desc =
189
190 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
191
// Number of XTilde sub-problems: stride divided by gcd(stride, dilation).
192 const auto XTilde = ConvStrideW / GcdStrideDilationW;
193
// Upper bound on filter taps per sub-problem: ceil(X / XTilde).
194 const auto XDot = math::integer_divide_ceil(X, XTilde);
195
196 const auto WTilde =
197 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
198
199 // only work on the WTilde range that contributes to non-padding area of input tensor
200 const auto IWTildeSliceBegin = math::integer_divide_floor(
201 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
202
203 const auto IWTildeSliceEnd = math::min(
204 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
205
206 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
207
208 // GemmK is different for each GEMM
// Filter taps that actually belong to sub-problem i_xtilde.
209 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
210
211 // A: output tensor
212 const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
213 out_n_wo_k_grid_desc,
219
220 const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
221 out_n_wop_k_grid_desc,
// The XDot coefficient is negative (-dilation / gcd); the WTilde coefficient is I1.
224 make_embed_transform(make_tuple(XDot, WTilde),
225 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
229
230 const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
231 out_n_xdot_wtilde_k_grid_desc,
// Keep the first XDotSlice taps and the non-padding WTilde window.
233 make_slice_transform(XDot, I0, XDotSlice),
234 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
238
239 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
240 out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
242 make_merge_transform(make_tuple(N, WTildeSlice)),
246
247 // B weight tensor
248 const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
249 wei_k_x_c_grid_desc,
// Splits the filter axis into (XDot, XTilde) with coefficients (stride / gcd, I1).
251 make_embed_transform(make_tuple(XDot, XTilde),
252 make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
256
257 const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
258 wei_k_xdot_xtilde_c_grid_desc,
// XTilde is pinned to i_xtilde, so it is not a free dimension of the result.
260 make_slice_transform(XDot, I0, XDotSlice),
261 make_freeze_transform(i_xtilde),
265
266 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
267 wei_k0_k1_xdotslice_c_grid_desc,
273
274 // C: input tensor
275 const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
276 in_n_wi_c_grid_desc,
278 make_pad_transform(Wi, InLeftPadW, InRightPadW),
282
283 const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
284 in_n_wip_c_grid_desc,
// Splits the padded W axis into (XTilde, WTilde) with coefficients (dilation, stride).
286 make_embed_transform(make_tuple(XTilde, WTilde),
287 make_tuple(ConvDilationW, ConvStrideW)),
291
292 const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
293 in_n_xtilde_wtilde_c_grid_desc,
// Same i_xtilde and the same WTilde window as used for the output tensor above.
295 make_freeze_transform(i_xtilde),
296 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
300
301 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
302 in_n_wtildeslice_c_grid_desc,
307
308 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
309 wei_gemmk0_gemmn_gemmk1_grid_desc,
310 in_gemmm_gemmn_grid_desc);
311 }
312
313 } // function end
// Builds the GEMM-view tensor descriptors for the 2-D (H, W) case; this overload
// participates only when NDim == 2.
//
// NOTE(review): this text comes from a rendered source listing whose embedded line
// numbers skip (e.g. 316, 357, 364). The skipped lines — including the one that carries
// the function name and its first parameter — are not visible here, so the comments
// below describe only what the remaining lines show.
//
// Every spatial vector is read as [0] = H axis, [1] = W axis; tildes is read as
// [0] = i_ytilde, [1] = i_xtilde.
//
// Returns make_tuple(A, B, C): A is the output-tensor descriptor, B the weight-tensor
// descriptor and C the input-tensor descriptor.
314 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
315 static auto
317 ck::index_t K,
318 ck::index_t C,
319 std::vector<ck::index_t> input_spatial_lengths,
320 std::vector<ck::index_t> filter_spatial_lengths,
321 std::vector<ck::index_t> output_spatial_lengths,
322 std::vector<ck::index_t> conv_filter_strides,
323 std::vector<ck::index_t> conv_filter_dilations,
324 std::vector<ck::index_t> input_left_pads,
325 std::vector<ck::index_t> input_right_pads,
326 std::vector<ck::index_t> tildes)
327 {
328 using namespace ck;
329
// Which (YTilde, XTilde) sub-problem this call describes.
330 index_t i_ytilde = tildes[0];
331 index_t i_xtilde = tildes[1];
332
333 const index_t Hi = input_spatial_lengths[0];
334 const index_t Wi = input_spatial_lengths[1];
335
336 const index_t Ho = output_spatial_lengths[0];
337 const index_t Wo = output_spatial_lengths[1];
338
339 const index_t Y = filter_spatial_lengths[0];
340 const index_t X = filter_spatial_lengths[1];
341
342 const index_t InLeftPadH = input_left_pads[0];
343 const index_t InLeftPadW = input_left_pads[1];
344
345 const index_t InRightPadH = input_right_pads[0];
346 const index_t InRightPadW = input_right_pads[1];
347
348 const index_t ConvStrideH = conv_filter_strides[0];
349 const index_t ConvStrideW = conv_filter_strides[1];
350
351 const index_t ConvDilationH = conv_filter_dilations[0];
352 const index_t ConvDilationW = conv_filter_dilations[1];
353
// Integer division: any remainder of K / K1 is discarded here.
354 const auto K0 = K / K1;
355
356 const auto out_n_ho_wo_k_grid_desc =
358 const auto wei_k_y_x_c_grid_desc =
360 const auto in_n_hi_wi_c_grid_desc =
362
// Compile-time branch on the backward-data specialization.
// NOTE(review): the enumerator compared against is on listing line 364, which is missing;
// the (I1, Ho) / (I1, Wo) embeds below suggest a 1x1 filter — confirm.
363 if constexpr(ConvBackwardDataSpecialization ==
365 {
366 // A: output tensor
367 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
373
374 // B: weight tensor
375 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
381
382 // C: input tensor
383 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
384 in_n_hi_wi_c_grid_desc,
386 make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
387 make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
391
392 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
393 in_n_y_ho_x_wo_c_grid_desc,
400
401 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
402 wei_gemmk0_gemmn_gemmk1_grid_desc,
403 in_gemmm_gemmn_grid_desc);
404 }
405 else
406 {
407 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
408 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
409
// Sub-problem counts per axis: stride divided by gcd(stride, dilation).
410 const auto YTilde = ConvStrideH / GcdStrideDilationH;
411 const auto XTilde = ConvStrideW / GcdStrideDilationW;
412
// Upper bound on filter taps per sub-problem along each axis.
413 const auto YDot = math::integer_divide_ceil(Y, YTilde);
414 const auto XDot = math::integer_divide_ceil(X, XTilde);
415
416 const auto HTilde =
417 Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
418 const auto WTilde =
419 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
420
421 // only work on HTilde and WTilde that contribute to non-padding area of input tensor
422 const auto IHTildeSliceBegin = math::integer_divide_floor(
423 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
424 const auto IWTildeSliceBegin = math::integer_divide_floor(
425 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
426
427 const auto IHTildeSliceEnd = math::min(
428 HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
429 const auto IWTildeSliceEnd = math::min(
430 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
431
432 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
433 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
434
435 // GemmK is different for each GEMM
// Filter taps that actually belong to sub-problem (i_ytilde, i_xtilde).
436 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
437 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
438
439 // A: output tensor
440 const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
441 out_n_ho_wo_k_grid_desc,
448
449 const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
450 out_n_hop_wop_k_grid_desc,
// The *Dot coefficients are negative (-dilation / gcd); the *Tilde coefficients are I1.
453 make_embed_transform(make_tuple(YDot, HTilde),
454 make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
455 make_embed_transform(make_tuple(XDot, WTilde),
456 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
460
461 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
463 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
// Keep the first *DotSlice taps and the non-padding *Tilde windows.
465 make_slice_transform(YDot, I0, YDotSlice),
466 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
467 make_slice_transform(XDot, I0, XDotSlice),
468 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
471 Sequence<1>{},
472 Sequence<2>{},
473 Sequence<3>{},
474 Sequence<4>{},
475 Sequence<5>{}),
477 Sequence<1>{},
478 Sequence<2>{},
479 Sequence<3>{},
480 Sequence<4>{},
// The last source dimension maps to two destination dimensions (5 and 6).
481 Sequence<5, 6>{}));
482
483 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
484 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// Merge (YDotSlice, XDotSlice, K0) into one dimension and (N, HTildeSlice, WTildeSlice) into another.
485 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
486 make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
490
491 // B weight tensor
492 const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
493 wei_k_y_x_c_grid_desc,
// Splits each filter axis into (*Dot, *Tilde) with coefficients (stride / gcd, I1).
495 make_embed_transform(make_tuple(YDot, YTilde),
496 make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
497 make_embed_transform(make_tuple(XDot, XTilde),
498 make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
502
503 const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
504 transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
// Slices act on the *Dot dimensions (1, 3); freezes pin the *Tilde dimensions (2, 4).
506 make_slice_transform(YDot, I0, YDotSlice),
507 make_slice_transform(XDot, I0, XDotSlice),
508 make_freeze_transform(i_ytilde),
509 make_freeze_transform(i_xtilde),
512 Sequence<1>{},
513 Sequence<3>{},
514 Sequence<2>{},
515 Sequence<4>{},
516 Sequence<5>{}),
518 Sequence<2>{},
519 Sequence<3>{},
// The frozen dimensions map to empty sequences, i.e. they do not appear in the result.
520 Sequence<>{},
521 Sequence<>{},
522 Sequence<4>{}));
523
524 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
525 wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
526 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
531
532 // C: input tensor
533 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
534 in_n_hi_wi_c_grid_desc,
536 make_pad_transform(Hi, InLeftPadH, InRightPadH),
537 make_pad_transform(Wi, InLeftPadW, InRightPadW),
541
542 const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
543 in_n_hip_wip_c_grid_desc,
// Splits each padded axis into (*Tilde filter index, *Tilde position) with coefficients (dilation, stride).
545 make_embed_transform(make_tuple(YTilde, HTilde),
546 make_tuple(ConvDilationH, ConvStrideH)),
547 make_embed_transform(make_tuple(XTilde, WTilde),
548 make_tuple(ConvDilationW, ConvStrideW)),
552
553 const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
554 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// Same (i_ytilde, i_xtilde) and the same windows as used for the output tensor above.
556 make_freeze_transform(i_ytilde),
557 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
558 make_freeze_transform(i_xtilde),
559 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
562 Sequence<1>{},
563 Sequence<2>{},
564 Sequence<3>{},
565 Sequence<4>{},
566 Sequence<5>{}),
568 Sequence<>{},
569 Sequence<1>{},
570 Sequence<>{},
571 Sequence<2>{},
572 Sequence<3>{}));
573
574 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
575 in_n_htildeslice_wtildeslice_c_grid_desc,
576 make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
580
581 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
582 wei_gemmk0_gemmn_gemmk1_grid_desc,
583 in_gemmm_gemmn_grid_desc);
584 }
585
586 } // function end
587
// Builds the GEMM-view tensor descriptors for the 3-D (D, H, W) case; this overload
// participates only when NDim == 3.
//
// NOTE(review): this text comes from a rendered source listing whose embedded line
// numbers skip (e.g. 590, 639, 646). The skipped lines — including the one that carries
// the function name and its first parameter — are not visible here, so the comments
// below describe only what the remaining lines show.
//
// Every spatial vector is read as [0] = D axis, [1] = H axis, [2] = W axis; tildes is
// read as [0] = i_ztilde, [1] = i_ytilde, [2] = i_xtilde.
//
// Returns make_tuple(A, B, C): A is the output-tensor descriptor, B the weight-tensor
// descriptor and C the input-tensor descriptor.
588 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
589 static auto
591 ck::index_t K,
592 ck::index_t C,
593 std::vector<ck::index_t> input_spatial_lengths,
594 std::vector<ck::index_t> filter_spatial_lengths,
595 std::vector<ck::index_t> output_spatial_lengths,
596 std::vector<ck::index_t> conv_filter_strides,
597 std::vector<ck::index_t> conv_filter_dilations,
598 std::vector<ck::index_t> input_left_pads,
599 std::vector<ck::index_t> input_right_pads,
600 std::vector<ck::index_t> tildes)
601 {
602 using namespace ck;
603
// Which (ZTilde, YTilde, XTilde) sub-problem this call describes.
604 const index_t i_ztilde = tildes[0];
605 const index_t i_ytilde = tildes[1];
606 const index_t i_xtilde = tildes[2];
607
608 const index_t Di = input_spatial_lengths[0];
609 const index_t Hi = input_spatial_lengths[1];
610 const index_t Wi = input_spatial_lengths[2];
611
612 const index_t Do = output_spatial_lengths[0];
613 const index_t Ho = output_spatial_lengths[1];
614 const index_t Wo = output_spatial_lengths[2];
615
616 const index_t Z = filter_spatial_lengths[0];
617 const index_t Y = filter_spatial_lengths[1];
618 const index_t X = filter_spatial_lengths[2];
619
620 const index_t InLeftPadD = input_left_pads[0];
621 const index_t InLeftPadH = input_left_pads[1];
622 const index_t InLeftPadW = input_left_pads[2];
623
624 const index_t InRightPadD = input_right_pads[0];
625 const index_t InRightPadH = input_right_pads[1];
626 const index_t InRightPadW = input_right_pads[2];
627
628 const index_t ConvStrideD = conv_filter_strides[0];
629 const index_t ConvStrideH = conv_filter_strides[1];
630 const index_t ConvStrideW = conv_filter_strides[2];
631
632 const index_t ConvDilationD = conv_filter_dilations[0];
633 const index_t ConvDilationH = conv_filter_dilations[1];
634 const index_t ConvDilationW = conv_filter_dilations[2];
635
// Integer division: any remainder of K / K1 is discarded here.
636 const auto K0 = K / K1;
637
638 const auto out_n_do_ho_wo_k_grid_desc =
640 const auto wei_k_z_y_x_c_grid_desc =
642 const auto in_n_di_hi_wi_c_grid_desc =
644
// Compile-time branch on the backward-data specialization.
// NOTE(review): the enumerator compared against is on listing line 646, which is missing;
// the (I1, Do) / (I1, Ho) / (I1, Wo) embeds below suggest a 1x1x1 filter — confirm.
645 if constexpr(ConvBackwardDataSpecialization ==
647 {
648 // A: output tensor
649 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// GemmM is the flattened N * Do * Ho * Wo extent, passed through unchanged.
651 make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
655
656 // B: weight tensor
657 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
663
664 // C: input tensor
665 const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
666 in_n_di_hi_wi_c_grid_desc,
668 make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
669 make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
670 make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
678 Sequence<7>{}));
679
680 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
681 in_n_z_do_y_ho_x_wo_c_grid_desc,
685 make_merge_transform(make_tuple(N, Do, Ho, Wo)),
688 Sequence<3>{},
689 Sequence<5>{},
691 Sequence<7>{}),
693
694 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
695 wei_gemmk0_gemmn_gemmk1_grid_desc,
696 in_gemmm_gemmn_grid_desc);
697 }
698 else
699 {
700 const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
701 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
702 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
703
// Sub-problem counts per axis: stride divided by gcd(stride, dilation).
704 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
705 const auto YTilde = ConvStrideH / GcdStrideDilationH;
706 const auto XTilde = ConvStrideW / GcdStrideDilationW;
707
// Upper bound on filter taps per sub-problem along each axis.
708 const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
709 const auto YDot = math::integer_divide_ceil(Y, YTilde);
710 const auto XDot = math::integer_divide_ceil(X, XTilde);
711
712 const auto DTilde =
713 Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
714 const auto HTilde =
715 Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
716 const auto WTilde =
717 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
718
719 // only work on DTilde, HTilde and WTilde that contribute to non-padding area of input tensor
720 const auto IDTildeSliceBegin = math::integer_divide_floor(
721 math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
722 const auto IHTildeSliceBegin = math::integer_divide_floor(
723 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
724 const auto IWTildeSliceBegin = math::integer_divide_floor(
725 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
726
727 const auto IDTildeSliceEnd = math::min(
728 DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
729 const auto IHTildeSliceEnd = math::min(
730 HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
731 const auto IWTildeSliceEnd = math::min(
732 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
733
734 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
735 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
736 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
737
738 // GemmK is different for each GEMM
// Filter taps that actually belong to sub-problem (i_ztilde, i_ytilde, i_xtilde).
739 const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
740 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
741 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
742
743 // A: output tensor
744 const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
745 out_n_do_ho_wo_k_grid_desc,
755
756 const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
758 out_n_dop_hop_wop_k_grid_desc,
// The *Dot coefficients are negative (-dilation / gcd); the *Tilde coefficients are I1.
761 make_embed_transform(make_tuple(ZDot, DTilde),
762 make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
763 make_embed_transform(make_tuple(YDot, HTilde),
764 make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
765 make_embed_transform(make_tuple(XDot, WTilde),
766 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
774 Sequence<7>{}));
775
776 const auto
777 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
779 out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
// Keep the first *DotSlice taps and the non-padding *Tilde windows.
781 make_slice_transform(ZDot, I0, ZDotSlice),
782 make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
783 make_slice_transform(YDot, I0, YDotSlice),
784 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
785 make_slice_transform(XDot, I0, XDotSlice),
786 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
789 Sequence<1>{},
790 Sequence<2>{},
791 Sequence<3>{},
792 Sequence<4>{},
793 Sequence<5>{},
794 Sequence<6>{},
795 Sequence<7>{}),
797 Sequence<1>{},
798 Sequence<2>{},
799 Sequence<3>{},
800 Sequence<4>{},
801 Sequence<5>{},
802 Sequence<6>{},
// The last source dimension maps to two destination dimensions (7 and 8).
803 Sequence<7, 8>{}));
804
805 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
806 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// Merge the tap dimensions with K0, and N with the three spatial windows.
808 make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
809 make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
813
814 // B weight tensor
815 const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
817 wei_k_z_y_x_c_grid_desc,
// Splits each filter axis into (*Dot, *Tilde) with coefficients (stride / gcd, I1).
820 make_embed_transform(make_tuple(ZDot, ZTilde),
821 make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
822 make_embed_transform(make_tuple(YDot, YTilde),
823 make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
824 make_embed_transform(make_tuple(XDot, XTilde),
825 make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
833 Sequence<7>{}));
834
835 const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
836 transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
// Slices act on the *Dot dimensions (1, 3, 5); freezes pin the *Tilde dimensions (2, 4, 6).
838 make_slice_transform(ZDot, I0, ZDotSlice),
839 make_slice_transform(YDot, I0, YDotSlice),
840 make_slice_transform(XDot, I0, XDotSlice),
841 make_freeze_transform(i_ztilde),
842 make_freeze_transform(i_ytilde),
843 make_freeze_transform(i_xtilde),
846 Sequence<1>{},
847 Sequence<3>{},
848 Sequence<5>{},
849 Sequence<2>{},
850 Sequence<4>{},
851 Sequence<6>{},
852 Sequence<7>{}),
854 Sequence<2>{},
855 Sequence<3>{},
856 Sequence<4>{},
// The frozen dimensions map to empty sequences, i.e. they do not appear in the result.
857 Sequence<>{},
858 Sequence<>{},
859 Sequence<>{},
860 Sequence<5>{}));
861
862 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
863 wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
864 make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
869
870 // C: input tensor
871 const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
872 in_n_di_hi_wi_c_grid_desc,
874 make_pad_transform(Di, InLeftPadD, InRightPadD),
875 make_pad_transform(Hi, InLeftPadH, InRightPadH),
876 make_pad_transform(Wi, InLeftPadW, InRightPadW),
882
883 const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
885 in_n_dip_hip_wip_c_grid_desc,
// Splits each padded axis into (*Tilde filter index, *Tilde position) with coefficients (dilation, stride).
887 make_embed_transform(make_tuple(ZTilde, DTilde),
888 make_tuple(ConvDilationD, ConvStrideD)),
889 make_embed_transform(make_tuple(YTilde, HTilde),
890 make_tuple(ConvDilationH, ConvStrideH)),
891 make_embed_transform(make_tuple(XTilde, WTilde),
892 make_tuple(ConvDilationW, ConvStrideW)),
900 Sequence<7>{}));
901
902 const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
904 in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// Same tilde indices and the same windows as used for the output tensor above.
906 make_freeze_transform(i_ztilde),
907 make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
908 make_freeze_transform(i_ytilde),
909 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
910 make_freeze_transform(i_xtilde),
911 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
914 Sequence<1>{},
915 Sequence<2>{},
916 Sequence<3>{},
917 Sequence<4>{},
918 Sequence<5>{},
919 Sequence<6>{},
920 Sequence<7>{}),
922 Sequence<>{},
923 Sequence<1>{},
924 Sequence<>{},
925 Sequence<2>{},
926 Sequence<>{},
927 Sequence<3>{},
928 Sequence<4>{}));
929
930 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
931 in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
933 make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
937
938 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
939 wei_gemmk0_gemmn_gemmk1_grid_desc,
940 in_gemmm_gemmn_grid_desc);
941 }
942
943 } // function end
944
// 1-D variant (enabled only when NDim == 1): forwards fixed placeholder arguments —
// sizes of 1, one-element vectors of 1 and a tilde index of 0.
// NOTE(review): the callee and the `return` are on listing line 948, which is missing
// here; presumably this exists only so the descriptor types can be deduced — confirm.
945 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
946 static auto GetABCGridDesc()
947 {
949 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
950 }
951
// 2-D variant (enabled only when NDim == 2): forwards fixed placeholder arguments —
// sizes of 1, two-element vectors of 1 and tilde indices of {0, 0}.
// NOTE(review): the callee and the `return` are on listing line 955, which is missing
// here; presumably this exists only so the descriptor types can be deduced — confirm.
952 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
953 static auto GetABCGridDesc()
954 {
956 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
957 }
958
// 3-D variant (enabled only when NDim == 3): forwards fixed placeholder arguments —
// three-element vectors of 1 and tilde indices of {0, 0, 0}.
// NOTE(review): the callee, the `return` and the first scalar argument are on listing
// line 962, which is missing here; presumably this exists only so the descriptor types
// can be deduced — confirm.
959 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
960 static auto GetABCGridDesc()
961 {
963 1,
964 1,
// Seven spatial-parameter vectors followed by the tilde indices.
965 {1, 1, 1},
966 {1, 1, 1},
967 {1, 1, 1},
968 {1, 1, 1},
969 {1, 1, 1},
970 {1, 1, 1},
971 {1, 1, 1},
972 {0, 0, 0});
973 }
974
976
980
981 // GridwiseGemm
// Alias template parameterized on NXdlPerWave_, so the same configuration can be
// instantiated with different per-wave N repeat counts.
// NOTE(review): the alias name and the underlying gridwise-GEMM template are on listing
// line 983, and one template argument is on line 988; both lines are missing here.
982 template <index_t NXdlPerWave_>
984 BlockSize,
985 ABDataType, // TODO: distinguish A/B datatype
986 AccDataType,
987 CDataType,
// NOTE(review): the element-wise operations are passed in In / Wei / Out order although
// A is the output tensor and C the input tensor (ADataType = OutDataType,
// CDataType = InDataType) — confirm the parameter order the gridwise GEMM expects.
989 InElementwiseOperation,
990 WeiElementwiseOperation,
991 OutElementwiseOperation,
992 MPerBlock,
993 NPerBlock,
994 K0PerBlock,
995 MPerXDL,
996 NPerXDL,
997 K1,
998 MXdlPerWave,
999 NXdlPerWave_,
1000 ABlockTransferThreadClusterLengths_K0_M_K1,
1001 ABlockTransferThreadClusterArrangeOrder,
1002 ABlockTransferSrcAccessOrder,
1003 ABlockTransferSrcVectorDim,
1004 ABlockTransferSrcScalarPerVector,
1005 ABlockTransferDstScalarPerVector_K1,
1006 false, // AThreadTransferSrcResetCoordinateAfterRun,
1007 ABlockLdsAddExtraM,
1008 BBlockTransferThreadClusterLengths_K0_N_K1,
1009 BBlockTransferThreadClusterArrangeOrder,
1010 BBlockTransferSrcAccessOrder,
1011 BBlockTransferSrcVectorDim,
1012 BBlockTransferSrcScalarPerVector,
1013 BBlockTransferDstScalarPerVector_K1,
1014 false, // BThreadTransferSrcResetCoordinateAfterRun,
1015 BBlockLdsAddExtraN,
1016 Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
// NOTE(review): the literal 7 is passed here rather than the device-level template
// parameter CThreadTransferSrcDstVectorDim — confirm that this is intended.
1017 7, // CThreadTransferSrcDstVectorDim,
1018 CThreadTransferDstScalarPerVector>;
1021
1022 // Argument
1023 struct Argument : public BaseArgument
1024 {
// Captures the raw tensor pointers and the convolution problem description.
//
//   p_in_grid   - input tensor; the only non-const pointer, stored as the GEMM C operand
//   p_wei_grid  - weight tensor, stored as the GEMM B operand
//   p_out_grid  - output tensor, stored as the GEMM A operand
//   N, K, C     - stored as Conv_N_, Conv_K_, Conv_C_
//   remaining vectors - copied into the same-named members with a trailing underscore
//
// NOTE(review): the constructor body is on listing line 1052, which is missing here;
// presumably it triggers creation of the grid descriptors — confirm.
1025 Argument(InDataType* p_in_grid,
1026 const WeiDataType* p_wei_grid,
1027 const OutDataType* p_out_grid,
1028 ck::index_t N,
1029 ck::index_t K,
1030 ck::index_t C,
1031 std::vector<ck::index_t> input_spatial_lengths,
1032 std::vector<ck::index_t> filter_spatial_lengths,
1033 std::vector<ck::index_t> output_spatial_lengths,
1034 std::vector<ck::index_t> conv_filter_strides,
1035 std::vector<ck::index_t> conv_filter_dilations,
1036 std::vector<ck::index_t> input_left_pads,
1037 std::vector<ck::index_t> input_right_pads)
// A <- out, B <- wei, C <- in.
1038 : p_a_grid_{p_out_grid},
1039 p_b_grid_{p_wei_grid},
1040 p_c_grid_{p_in_grid},
1041 Conv_N_{N},
1042 Conv_K_{K},
1043 Conv_C_{C},
1044 input_spatial_lengths_{input_spatial_lengths},
1045 filter_spatial_lengths_{filter_spatial_lengths},
1046 output_spatial_lengths_{output_spatial_lengths},
1047 conv_filter_strides_{conv_filter_strides},
1048 conv_filter_dilations_{conv_filter_dilations},
1049 input_left_pads_{input_left_pads},
1050 input_right_pads_{input_right_pads}
1051 {
1053 }
1054
// 1-D variant (enabled only when NDim == 1): appends one (A, B, C) descriptor triple to
// the three descriptor containers for every XTilde index with a non-empty filter slice.
// NOTE(review): the member-function name is on listing line 1056 and the descriptor
// builder call is on lines 1075 and 1079-1085; those lines are missing here.
1055 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
1057 {
1058 const index_t ConvStrideW = conv_filter_strides_[0];
1059 const index_t ConvDilationW = conv_filter_dilations_[0];
1060 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// Number of sub-problems: stride divided by gcd(stride, dilation).
1061 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1062
1063 const index_t X = filter_spatial_lengths_[0];
1064
1065 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1066 {
1067 // check slice is valid
// Sub-problems with no filter taps are skipped and get no container entry.
1068 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1069 if(XDotSlice <= 0)
1070 {
1071 continue;
1072 }
1073
1074 const auto descs =
1076 Conv_N_,
1077 Conv_K_,
1078 Conv_C_,
1086 {i_xtilde});
// descs is indexed as [I0] = A, [I1] = B, [I2] = C.
1087 a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1088 b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1089 c_grid_desc_m_n_container_.push_back(descs[I2]);
1090 }
1091 }
// 2-D variant (enabled only when NDim == 2): appends one (A, B, C) descriptor triple to
// the three descriptor containers for every (YTilde, XTilde) index pair with a non-empty
// filter slice, iterating i_ytilde in the outer loop and i_xtilde in the inner loop.
// NOTE(review): the member-function name is on listing line 1093 and the descriptor
// builder call is on lines 1122 and 1126-1132; those lines are missing here.
1092 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
1094 {
1095 const index_t ConvStrideH = conv_filter_strides_[0];
1096 const index_t ConvStrideW = conv_filter_strides_[1];
1097
1098 const index_t ConvDilationH = conv_filter_dilations_[0];
1099 const index_t ConvDilationW = conv_filter_dilations_[1];
1100
1101 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
1102 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
1103
// Sub-problem counts per axis: stride divided by gcd(stride, dilation).
1104 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1105 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1106
1107 const index_t Y = filter_spatial_lengths_[0];
1108 const index_t X = filter_spatial_lengths_[1];
1109 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1110 {
1111 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1112 {
1113 // check slice is valid
// The product is tested, so the pair is skipped when the product of the two slice lengths is <= 0.
1114 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
1115 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1116 if(YDotSlice * XDotSlice <= 0)
1117 {
1118 continue;
1119 }
1120
1121 const auto descs =
1123 Conv_N_,
1124 Conv_K_,
1125 Conv_C_,
1133 {i_ytilde, i_xtilde});
// descs is indexed as [I0] = A, [I1] = B, [I2] = C.
1134 a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1135 b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1136 c_grid_desc_m_n_container_.push_back(descs[I2]);
1137 }
1138 }
1139 }
// 3-D variant (enabled only when NDim == 3): appends one (A, B, C) descriptor triple to
// the three descriptor containers for every (ZTilde, YTilde, XTilde) index triple with a
// non-empty filter slice, iterating z outermost and x innermost.
// NOTE(review): the member-function name is on listing line 1141 and the descriptor
// builder call is on lines 1178 and 1182-1188; those lines are missing here.
1140 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
1142 {
1143 const index_t ConvStrideD = conv_filter_strides_[0];
1144 const index_t ConvStrideH = conv_filter_strides_[1];
1145 const index_t ConvStrideW = conv_filter_strides_[2];
1146
1147 const index_t ConvDilationD = conv_filter_dilations_[0];
1148 const index_t ConvDilationH = conv_filter_dilations_[1];
1149 const index_t ConvDilationW = conv_filter_dilations_[2];
1150
1151 const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
1152 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
1153 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
1154
// Sub-problem counts per axis: stride divided by gcd(stride, dilation).
1155 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
1156 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1157 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1158
1159 const index_t Z = filter_spatial_lengths_[0];
1160 const index_t Y = filter_spatial_lengths_[1];
1161 const index_t X = filter_spatial_lengths_[2];
1162 for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
1163 {
1164 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1165 {
1166 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1167 {
1168 // check slice is valid
// The product is tested, so the triple is skipped when the product of the three slice lengths is <= 0.
1169 const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
1170 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
1171 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1172 if(ZDotSlice * YDotSlice * XDotSlice <= 0)
1173 {
1174 continue;
1175 }
1176
1177 const auto descs =
1179 Conv_N_,
1180 Conv_K_,
1181 Conv_C_,
1189 {i_ztilde, i_ytilde, i_xtilde});
// descs is indexed as [I0] = A, [I1] = B, [I2] = C.
1190 a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1191 b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1192 c_grid_desc_m_n_container_.push_back(descs[I2]);
1193 }
1194 }
1195 }
1196 }
1197
// Grid descriptors appended by the CreateABCDesc overloads. The visible overloads always
// push one A, one B and one C entry together, so index i refers to the same slice in all
// three containers.
1201 std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
1202 std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
1203 std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
// Element-wise operators. Note the role mapping: A uses the Out operator type, B the Wei
// operator type and C the In operator type.
1204 OutElementwiseOperation a_element_op_;
1205 WeiElementwiseOperation b_element_op_;
1206 InElementwiseOperation c_element_op_;
1207 // for checking IsSupportedArgument()
// NOTE(review): the member declarations that followed the comment above are missing from
// this listing (presumably Conv_N_, Conv_K_ and Conv_C_ -- confirm against the header).
1211
// Convolution problem description; other code in this file indexes these vectors by
// spatial dimension.
1212 std::vector<ck::index_t> input_spatial_lengths_;
1213 std::vector<ck::index_t> filter_spatial_lengths_;
1214 std::vector<ck::index_t> output_spatial_lengths_;
1215 std::vector<ck::index_t> conv_filter_strides_;
1216 std::vector<ck::index_t> conv_filter_dilations_;
1217 std::vector<ck::index_t> input_left_pads_;
1218 std::vector<ck::index_t> input_right_pads_;
1219 };
1220
1221 // Invoker
// Invoker: launches one GEMM kernel per descriptor triple stored in an Argument and
// accumulates the values returned by launch_and_time_kernel.
// NOTE(review): several lines of this struct are missing from this listing (the line after
// the opening brace, the remaining CheckValidity arguments, part of the kernel template
// argument lists, the trailing launch arguments and the line before Run()).
1222 struct Invoker : public BaseInvoker
1223 {
1225
// Runs every slice of `arg` with the given GridwiseGemm instantiation.
// Returns the sum of the per-launch values reported by launch_and_time_kernel.
// Throws std::runtime_error when GridwiseGemm::CheckValidity rejects a slice.
1226 template <typename GridwiseGemm>
1227 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1228 {
1229 float ave_time = 0;
1230 for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
1231 {
// Optional diagnostics: print the lengths of the i-th A/B/C descriptors when the
// CK_LOGGING environment switch is enabled.
1232 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1233 {
1234 std::cout << "arg.a_grid_desc_k0_m_k1{"
1235 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
1236 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
1237 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
1238 << std::endl;
1239
1240 std::cout << "arg.b_grid_desc_k0_n_k1{"
1241 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
1242 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
1243 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
1244 << std::endl;
1245
1246 std::cout << "arg.c_grid_desc_m_n{"
1247 << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
1248 << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
1249 << std::endl;
1250 }
1251
// Refuse to launch a slice the gridwise GEMM considers invalid.
1252 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
1255 {
1256 throw std::runtime_error(
1257 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
1258 }
1259
1260 const auto [gdx, gdy, gdz] =
1261 GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
1262
// K is the product of the A descriptor's dimension-0 and dimension-2 lengths
// (K0 and K1 going by the descriptor's name).
1263 const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
1264 arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
1265
// The two branches differ in the last visible kernel template argument (true/false),
// selected by CalculateHasMainKBlockLoop(K).
1266 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
1267 {
1268 const auto kernel =
1269 kernel_gemm_xdlops_v2r3<GridwiseGemm,
1270 ADataType, // TODO: distinguish A/B datatype
1271 CDataType,
1275 true>;
1276
1277 ave_time += launch_and_time_kernel(stream_config,
1278 kernel,
1279 dim3(gdx, gdy, gdz),
1280 dim3(BlockSize),
1281 0,
1282 arg.p_a_grid_,
1283 arg.p_b_grid_,
1284 arg.p_c_grid_,
1288 }
1289 else
1290 {
1291 const auto kernel =
1292 kernel_gemm_xdlops_v2r3<GridwiseGemm,
1293 ADataType, // TODO: distinguish A/B datatype
1294 CDataType,
1298 false>;
1299
1300 ave_time += launch_and_time_kernel(stream_config,
1301 kernel,
1302 dim3(gdx, gdy, gdz),
1303 dim3(BlockSize),
1304 0,
1305 arg.p_a_grid_,
1306 arg.p_b_grid_,
1307 arg.p_c_grid_,
1311 }
1312 }
1313 return ave_time;
1314 }
1315
1317
// Type-erased entry point: forwards to a Run(const Argument&, const StreamConfig&)
// overload that is not visible in this listing.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so passing a
// null pointer or an argument of another type is undefined behaviour.
1318 float Run(const BaseArgument* p_arg,
1319 const StreamConfig& stream_config = StreamConfig{}) override
1320 {
1321 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1322 }
1323 };
1324
// Compile-time validity hook for this template instantiation.
// Currently a placeholder: it performs no checks and always returns true.
1325 static constexpr bool IsValidCompilationParameter()
1326 {
1327 // TODO: properly implement this check
1328 return true;
1329 }
1330
// Returns true only when every check below accepts `arg`; the first failing check
// returns false.
// NOTE(review): several lines are missing from this listing: the condition of the first
// guard, the enumerator compared against ConvBackwardDataSpecialization, and the bodies of
// both `if constexpr` branches inside the final loop.
1331 static bool IsSupportedArgument(const Argument& arg)
1332 {
// NOTE(review): guard condition not visible here (presumably a device-capability check --
// confirm against the header).
1334 {
1335 return false;
1336 }
1337 if constexpr(ConvBackwardDataSpecialization ==
1339 {
1340 // check if it's 1x1, stride=1 pad = 0 conv
1341 for(int i = 0; i < NDimSpatial; i++)
1342 {
1343 if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
1344 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
1345 {
1346 return false;
1347 }
1348 }
1349 }
1350
// A must be vectorised along dimension 2 and B along dimension 1, and K / C must be
// multiples of the respective source vector widths.
1351 // vector load A/B matrix from global memory
1352 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 &&
1353 arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
1354 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
1355 {
1356 return false;
1357 }
1358
1359 // vector store C matrix into global memory
1360 if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
1361 {
1362 return false;
1363 }
1364
// Every slice must be accepted. `valid` starts false, so a slice is rejected when the
// NXdlPerWave value for the detected wave size is not positive.
1365 // Gridwise GEMM size
1366 bool isWave64 = get_warp_size() == 64;
1367 for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
1368 {
1369 bool valid = false;
1370 if(isWave64)
1371 {
1372 if constexpr(NXdlPerWave64 > 0)
1373 {
// NOTE(review): body missing from this listing; presumably assigns `valid`.
1377 }
1378 }
1379 else
1380 {
1381 if constexpr(NXdlPerWave32 > 0)
1382 {
// NOTE(review): body missing from this listing; presumably assigns `valid`.
1386 }
1387 }
1388 if(!valid)
1389 return false;
1390 }
1391 return true;
1392 }
1393
1394 bool IsSupportedArgument(const BaseArgument* p_arg) override
1395 {
1396 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1397 }
1398
// Typed factory: returns an Argument by value, built from the three grid pointers and the
// convolution problem description. All parameters are passed on unchanged and in the same
// order.
// The vectors are taken by value here and passed on as lvalues, so each one is copied again
// when the Argument is constructed.
1399 static auto MakeArgument(InDataType* p_in_grid,
1400 const WeiDataType* p_wei_grid,
1401 const OutDataType* p_out_grid,
1402 ck::index_t N,
1403 ck::index_t K,
1404 ck::index_t C,
1405 std::vector<ck::index_t> input_spatial_lengths,
1406 std::vector<ck::index_t> filter_spatial_lengths,
1407 std::vector<ck::index_t> output_spatial_lengths,
1408 std::vector<ck::index_t> conv_filter_strides,
1409 std::vector<ck::index_t> conv_filter_dilations,
1410 std::vector<ck::index_t> input_left_pads,
1411 std::vector<ck::index_t> input_right_pads)
1412 {
1413 return Argument{p_in_grid,
1414 p_wei_grid,
1415 p_out_grid,
1416 N,
1417 K,
1418 C,
1419 input_spatial_lengths,
1420 filter_spatial_lengths,
1421 output_spatial_lengths,
1422 conv_filter_strides,
1423 conv_filter_dilations,
1424 input_left_pads,
1425 input_right_pads};
1426 }
1427
// Typed factory: returns a default-constructed Invoker by value.
1428 static auto MakeInvoker() { return Invoker{}; }
1429
1430 std::unique_ptr<BaseArgument>
1431 MakeArgumentPointer(void* p_in_grid,
1432 const void* p_wei_grid,
1433 const void* p_out_grid,
1434 ck::index_t N,
1435 ck::index_t K,
1436 ck::index_t C,
1437 std::vector<ck::index_t> input_spatial_lengths,
1438 std::vector<ck::index_t> filter_spatial_lengths,
1439 std::vector<ck::index_t> output_spatial_lengths,
1440 std::vector<ck::index_t> conv_filter_strides,
1441 std::vector<ck::index_t> conv_filter_dilations,
1442 std::vector<ck::index_t> input_left_pads,
1443 std::vector<ck::index_t> input_right_pads,
1444 InElementwiseOperation,
1445 WeiElementwiseOperation,
1446 OutElementwiseOperation) override
1447 {
1448 return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
1449 static_cast<const WeiDataType*>(p_wei_grid),
1450 static_cast<const OutDataType*>(p_out_grid),
1451 N,
1452 K,
1453 C,
1454 input_spatial_lengths,
1455 filter_spatial_lengths,
1456 output_spatial_lengths,
1457 conv_filter_strides,
1458 conv_filter_dilations,
1459 input_left_pads,
1460 input_right_pads);
1461 }
1462
// Type-erased factory: returns a heap-allocated Invoker, constructed from a temporary
// default-constructed Invoker, exposed through the BaseInvoker interface.
1463 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1464 {
1465 return std::make_unique<Invoker>(Invoker{});
1466 }
1467
// Builds a human-readable identifier for this instantiation: the op name followed by the
// main tuning parameters in angle brackets, with the suffix " Filter1x1Stride1Pad0" appended
// inside the `if constexpr` branch below.
// NOTE(review): the line holding the compared enumerator and the opening brace of that
// branch is missing from this listing.
1468 std::string GetTypeString() const override
1469 {
1470 auto str = std::stringstream();
1471
1472 // clang-format off
1473 str << "DeviceConvNdBwdDataNwcKxcNwk_Xdl"
1474 << "<"
1475 << BlockSize << ", "
1476 << MPerBlock << ", "
1477 << NPerBlock << ", "
1478 << K0PerBlock << ", "
1479 << K1 << ", "
1480 << MXdlPerWave << ", "
1481 << NXdlPerWave << ", "
1482 << ABlockTransferSrcScalarPerVector << ", "
1483 << ABlockTransferDstScalarPerVector_K1 << ", "
1484 << BBlockTransferSrcScalarPerVector << ", "
1485 << BBlockTransferDstScalarPerVector_K1
1486 << ">";
1487 if constexpr(ConvBackwardDataSpecialization ==
1489
1490 str<< " Filter1x1Stride1Pad0";
1491 }
1492
1493
1494 return str.str();
1495 }
1496};
1497
1498} // namespace device
1499} // namespace tensor_operation
1500} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_backward_data_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__global__ void kernel_gemm_xdlops_v2r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n)
Definition gridwise_gemm_xdlops_v2r3.hpp:34
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:142
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_conv_bwd_data.hpp:25
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1024
WeiElementwiseOperation b_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1205
Argument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1025
InElementwiseOperation c_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1206
std::vector< ck::index_t > conv_filter_dilations_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1216
std::vector< ck::index_t > input_left_pads_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1217
std::vector< ck::index_t > input_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1212
std::vector< BGridDesc_K0_N_K1 > b_grid_desc_k0_n_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1202
index_t Conv_N_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1208
CDataType * p_c_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1200
index_t Conv_C_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1210
std::vector< ck::index_t > input_right_pads_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1218
const BDataType * p_b_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1199
OutElementwiseOperation a_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1204
std::vector< CGridDesc_M_N > c_grid_desc_m_n_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1203
index_t Conv_K_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1209
std::vector< AGridDesc_K0_M_K1 > a_grid_desc_k0_m_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1201
void CreateABCDesc()
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1056
std::vector< ck::index_t > filter_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1213
std::vector< ck::index_t > output_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1214
const ADataType * p_a_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1198
std::vector< ck::index_t > conv_filter_strides_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1215
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1223
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1318
DeviceOp::Argument Argument
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1224
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1227
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:80
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:978
static constexpr auto I6
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:100
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:975
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, std::vector< ck::index_t > tildes)
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:115
static auto GetABCGridDesc()
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:946
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1463
static constexpr auto I5
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:99
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:84
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:979
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1019
static auto MakeInvoker()
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1428
static constexpr auto I7
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:101
static constexpr auto GemmK1Number
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:111
InDataType ABDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:92
static constexpr auto NXdlPerWave32
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:85
static constexpr auto I3
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:97
std::string GetTypeString() const override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1468
static constexpr auto I2
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:96
InDataType CDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:89
static constexpr auto I0
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:94
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_in_grid, const void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1431
static constexpr bool IsValidCompilationParameter()
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1325
OutDataType ADataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:87
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1394
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:977
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1020
DeviceConvNdBwdDataNwcKxcNwk_Xdl DeviceOp
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:81
static constexpr auto K1Number
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:110
static bool IsSupportedArgument(const Argument &arg)
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1331
WeiDataType BDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:88
static constexpr auto I1
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:95
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, 7, CThreadTransferDstScalarPerVector > GridwiseGemmBase
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:983
static constexpr auto I4
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:98
static auto MakeArgument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1399
#define CK_ENV(name)
Definition utility/env.hpp:129