gridwise_gemm_xdlops_v3r3.hpp Source File

gridwise_gemm_xdlops_v3r3.hpp Source File#

Composable Kernel: gridwise_gemm_xdlops_v3r3.hpp Source File
gridwise_gemm_xdlops_v3r3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
18namespace ck {
19
// Kernel entry point for the v3r3 XDLOPS gridwise GEMM. The symbol index of this page records the
// definition as kernel_gemm_xdlops_v3r3.
// NOTE(review): listing lines 35 and 37 (the line inside the CK_USE_LAUNCH_BOUNDS guard and the
// line carrying the kernel name) are absent from this extraction — restore them from the original
// header before compiling.
//
// Parameters:
//   p_a_grid, p_b_grid        read-only input pointers (element type FloatAB)
//   p_c_grid                  writable output pointer (element type FloatC)
//   p_c0_grid, p_c1_grid      two additional read-only inputs (element type FloatC)
//   *_grid_desc_*             descriptors passed by value and forwarded unchanged
//   a/b/c_element_op          element-wise functors forwarded unchanged
//   block_2_ctile_map         block-to-tile mapping object forwarded unchanged
//
// Behaviour: when one of the target macros tested below is defined and
// GridwiseGemm::IsValidCompilationParameter<>() is true, the kernel declares a statically sized
// __shared__ byte array and forwards every argument plus that array to
// GridwiseGemm::Run<HasMainKBlockLoop>. When none of the macros is defined, the body only assigns
// each parameter to `ignore`.
20template <typename GridwiseGemm,
21 typename FloatAB,
22 typename FloatC,
23 typename AGridDesc_K0_M_K1,
24 typename BGridDesc_K0_N_K1,
25 typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
26 typename C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27 typename C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CElementwiseOperation,
31 typename Block2CTileMap,
32 bool HasMainKBlockLoop>
33__global__ void
34#if CK_USE_LAUNCH_BOUNDS
36#endif
38 const FloatAB* __restrict__ p_a_grid,
39 const FloatAB* __restrict__ p_b_grid,
40 FloatC* __restrict__ p_c_grid,
41 const FloatC* __restrict__ p_c0_grid,
42 const FloatC* __restrict__ p_c1_grid,
43 const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
44 const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
45 const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
46 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
47 const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
48 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
49 const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
50 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
51 const AElementwiseOperation a_element_op,
52 const BElementwiseOperation b_element_op,
53 const CElementwiseOperation c_element_op,
54 const Block2CTileMap block_2_ctile_map)
55{
56#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
57 defined(__gfx12__)
58 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
59 {
    // Static shared-memory scratch whose byte size is fixed at compile time by
    // GridwiseGemm::GetSharedMemoryNumberOfByte(); it is handed to Run as p_shared.
60 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
61
62 GridwiseGemm::template Run<HasMainKBlockLoop>(
63 p_a_grid,
64 p_b_grid,
65 p_c_grid,
66 p_c0_grid,
67 p_c1_grid,
68 p_shared,
69 a_grid_desc_k0_m_k1,
70 b_grid_desc_k0_n_k1,
71 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
72 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
73 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
74 a_element_op,
75 b_element_op,
76 c_element_op,
77 block_2_ctile_map);
78 }
79#else
    // No supported target macro is defined: the kernel does no work. Each parameter is assigned
    // to `ignore` — presumably to silence unused-parameter diagnostics (confirm).
80 ignore = p_a_grid;
81 ignore = p_b_grid;
82 ignore = p_c_grid;
83 ignore = p_c0_grid;
84 ignore = p_c1_grid;
85 ignore = a_grid_desc_k0_m_k1;
86 ignore = b_grid_desc_k0_n_k1;
87 ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
88 ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
89 ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
90 ignore = a_element_op;
91 ignore = b_element_op;
92 ignore = c_element_op;
93 ignore = block_2_ctile_map;
94#endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
95}
96
97template <
98 index_t BlockSize,
99 typename FloatAB,
100 typename FloatAcc,
101 typename FloatC,
102 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
103 typename AGridDesc_K0_M_K1,
104 typename BGridDesc_K0_N_K1,
105 typename CGridDesc_M_N,
106 typename C0GridDesc_M_N,
107 typename C1GridDesc_M_N,
108 typename AElementwiseOperation,
109 typename BElementwiseOperation,
110 typename CElementwiseOperation,
111 index_t MPerBlock,
112 index_t NPerBlock,
113 index_t K0PerBlock,
114 index_t MPerXdl,
115 index_t NPerXdl,
116 index_t K1Value,
117 index_t MXdlPerWave,
118 index_t NXdlPerWave,
119 typename ABlockTransferThreadClusterLengths_K0_M_K1,
120 typename ABlockTransferThreadClusterArrangeOrder,
121 typename ABlockTransferSrcAccessOrder,
122 index_t ABlockTransferSrcVectorDim,
123 index_t ABlockTransferSrcScalarPerVector,
124 index_t ABlockTransferDstScalarPerVector_K1,
125 bool AThreadTransferSrcResetCoordinateAfterRun,
126 bool ABlockLdsExtraM,
127 typename BBlockTransferThreadClusterLengths_K0_N_K1,
128 typename BBlockTransferThreadClusterArrangeOrder,
129 typename BBlockTransferSrcAccessOrder,
130 index_t BBlockTransferSrcVectorDim,
131 index_t BBlockTransferSrcScalarPerVector,
132 index_t BBlockTransferDstScalarPerVector_K1,
133 bool BThreadTransferSrcResetCoordinateAfterRun,
134 bool BBlockLdsExtraN,
135 index_t CShuffleMXdlPerWavePerShuffle,
136 index_t CShuffleNXdlPerWavePerShuffle,
137 typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
138 index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
139 index_t NumGemmKPrefetchStage = 1,
142{
143 static constexpr auto I0 = Number<0>{};
144 static constexpr auto I1 = Number<1>{};
145 static constexpr auto I2 = Number<2>{};
146 static constexpr auto I3 = Number<3>{};
147 static constexpr auto I4 = Number<4>{};
148 static constexpr auto I5 = Number<5>{};
149 static constexpr auto I6 = Number<6>{};
150 static constexpr auto I7 = Number<7>{};
151
152 // K1 should be Number<...>
153 static constexpr auto K1 = Number<K1Value>{};
154
156
159
// Returns the compile-time descriptor used for the A tile held in LDS (it is the destination
// descriptor of the A blockwise copy in Run). One of two layouts is selected at compile time by
// ABlockLdsExtraM.
// NOTE(review): listing lines 168-170 (the ABlockLdsExtraM branch) and 174 (the start of the
// else-branch return statement) are absent from this extraction. The surviving argument list shows
// lengths (K0PerBlock, MPerBlock, K1) together with max_lds_align; the callee is presumably
// make_naive_tensor_descriptor_aligned, which appears in this page's symbol index — confirm
// against the original header.
160 __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
161 {
    // The alignment argument is K1 itself.
162 constexpr auto max_lds_align = K1;
163
164 // A matrix in LDS memory, dst of blockwise copy
165 constexpr auto a_block_desc_k0_m_k1 = [&]() {
166 if constexpr(ABlockLdsExtraM)
167 {
171 }
172 else
173 {
175 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
176 }
177 }();
178
179 return a_block_desc_k0_m_k1;
180 }
181
// Returns the compile-time descriptor used for the B tile held in LDS (it is the destination
// descriptor of the B blockwise copy in Run). One of two layouts is selected at compile time by
// BBlockLdsExtraN.
// NOTE(review): listing lines 190-192 (the BBlockLdsExtraN branch) and 196 (the start of the
// else-branch return statement) are absent from this extraction. The surviving argument list shows
// lengths (K0PerBlock, NPerBlock, K1) together with max_lds_align; the callee is presumably
// make_naive_tensor_descriptor_aligned, which appears in this page's symbol index — confirm
// against the original header.
182 __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
183 {
    // The alignment argument is K1 itself.
184 constexpr auto max_lds_align = K1;
185
186 // B matrix in LDS memory, dst of blockwise copy
187 constexpr auto b_block_desc_k0_n_k1 = [&]() {
188 if constexpr(BBlockLdsExtraN)
189 {
193 }
194 else
195 {
197 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
198 }
199 }();
200
201 return b_block_desc_k0_n_k1;
202 }
203
204 __host__ __device__ static constexpr auto
206 {
207 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
208 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
209
210 constexpr auto
211 c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
216 I1,
219
220 return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
221 }
222
// Returns the number of bytes of shared memory (LDS) one workgroup needs. The kernel entry point
// uses this value as the extent of its `__shared__ char p_shared[...]` array.
//
// The result is the larger of two requirements:
//   * (A tile space + B tile space, each rounded up to a multiple of K1) * sizeof(FloatAB)
//   * C-shuffle tile space * sizeof(FloatC)
// A maximum rather than a sum is taken; this matches Run, which casts the same p_shared pointer
// to FloatAB* for the A/B tiles and to FloatC* for the C shuffle buffer.
// NOTE(review): listing line 240 (the initializer of the C block descriptor) is absent from this
// extraction.
223 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
224 {
225 // LDS allocation for A and B: be careful of alignment
226 constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
227
228 constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
229
230 constexpr auto max_lds_align = K1;
231
    // Element counts (not bytes) rounded up to the next multiple of max_lds_align.
232 constexpr auto a_block_space_size_aligned =
233 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
234
235 constexpr auto b_block_space_size_aligned =
236 math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
237
238 // LDS allocation for C shuffle in LDS
239 constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
241
242 constexpr auto c_block_size =
243 c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
244 .GetElementSpaceSize();
245
    // Convert both element counts to bytes and keep the larger one.
246 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
247 sizeof(FloatAB),
248 c_block_size * sizeof(FloatC));
249 }
250
251 template <
252 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
253 __device__ static bool constexpr IsValidCompilationParameter()
254 {
255 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
256 BlockSize,
257 MPerBlock,
258 NPerBlock,
259 MPerXdl,
260 NPerXdl,
261 MXdlPerWave,
262 NXdlPerWave,
263 FloatC,
264 CGlobalMemoryDataOperation>();
265 }
266
267 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
268 template <typename Block2CTileMap>
269 __host__ __device__ static constexpr bool
270 CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
271 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
272 const CGridDesc_M_N& c_grid_desc_m_n,
273 const Block2CTileMap& block_2_ctile_map)
274 {
275 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
276 "wrong! K1 need to be known at compile-time");
277
278 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
279 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
280 "Invalid tuning param!");
281
282 const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
283 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
284 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
285
286 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
287 K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
288 K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
289 return false;
290
291 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
292 return false;
293
294 // check gridwise gemm pipeline
295 const auto num_k_loop = K0 / K0PerBlock;
296
297 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
298 {
299 return false;
300 }
301
302 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
303 {
304 return false;
305 }
306
307 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
308 return true;
309 }
310
311 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
312 {
313 const index_t num_loop = K / (K0PerBlock * K1);
314
315 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
316 }
317
318 template <typename CGridDesc_M_N_>
319 __host__ __device__ static constexpr auto
321 const CGridDesc_M_N_& c_grid_desc_m_n)
322 {
323 const auto M = c_grid_desc_m_n.GetLength(I0);
324 const auto N = c_grid_desc_m_n.GetLength(I1);
325
326 const auto MBlock = M / MPerBlock;
327 const auto NBlock = N / NPerBlock;
328
329 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
330 constexpr index_t NWave =
331 NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl);
332
333 const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
335 c_grid_desc_m_n,
342
343 return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
344 }
345
346 // return block_id to C matrix tile idx (m0, n0) mapping
347 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
348 const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
349 {
351 c_grid_desc_m_n);
352 }
356 CGridDesc_M_N{}))>;
357
361 C0GridDesc_M_N{}))>;
362
366 C1GridDesc_M_N{}))>;
367
369 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
370
371 template <bool HasMainKBlockLoop, typename Block2CTileMap>
372 __device__ static void
373 Run(const FloatAB* __restrict__ p_a_grid,
374 const FloatAB* __restrict__ p_b_grid,
375 FloatC* __restrict__ p_c_grid,
376 const FloatC* __restrict__ p_c0_grid,
377 const FloatC* __restrict__ p_c1_grid,
378 void* __restrict__ p_shared,
379 const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
380 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
382 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
384 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
386 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
387 const AElementwiseOperation& a_element_op,
388 const BElementwiseOperation& b_element_op,
389 const CElementwiseOperation& c_element_op,
390 const Block2CTileMap& block_2_ctile_map)
391 {
392 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
393 p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
394 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
395 p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
397 p_c_grid,
398 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
399 .GetElementSpaceSize());
401 p_c0_grid,
402 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
403 .GetElementSpaceSize());
405 p_c1_grid,
406 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
407 .GetElementSpaceSize());
408
409 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
410
411 // divide block work by [M, N]
412 const auto block_work_idx =
413 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
414
415 if(!block_2_ctile_map.ValidCTileIndex(
416 block_work_idx,
418 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
419 .GetLength(I0),
420 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
421 .GetLength(I3))))
422 {
423 return;
424 }
425
426 // HACK: this force m/n_block_data_idx_on_grid into SGPR
427 const index_t m_block_data_idx_on_grid =
428 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
429
430 const index_t n_block_data_idx_on_grid =
431 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
432
433 // lds max alignment
434 constexpr auto max_lds_align = K1;
435
436 // A matrix in LDS memory, dst of blockwise copy
437 constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
438
439 // B matrix in LDS memory, dst of blockwise copy
440 constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
441
442 // A matrix blockwise copy
443 auto a_blockwise_copy =
445 AElementwiseOperation,
449 ABlockTransferThreadClusterLengths_K0_M_K1,
450 ABlockTransferThreadClusterArrangeOrder,
451 FloatAB,
452 FloatAB,
453 decltype(a_grid_desc_k0_m_k1),
454 decltype(a_block_desc_k0_m_k1),
455 ABlockTransferSrcAccessOrder,
457 ABlockTransferSrcVectorDim,
458 2,
459 ABlockTransferSrcScalarPerVector,
460 ABlockTransferDstScalarPerVector_K1,
461 1,
462 1,
463 AThreadTransferSrcResetCoordinateAfterRun,
464 true>(
465 a_grid_desc_k0_m_k1,
466 make_multi_index(0, m_block_data_idx_on_grid, 0),
467 a_element_op,
468 a_block_desc_k0_m_k1,
469 make_multi_index(0, 0, 0),
471
472 // B matrix blockwise copy
473 auto b_blockwise_copy =
475 BElementwiseOperation,
479 BBlockTransferThreadClusterLengths_K0_N_K1,
480 BBlockTransferThreadClusterArrangeOrder,
481 FloatAB,
482 FloatAB,
483 decltype(b_grid_desc_k0_n_k1),
484 decltype(b_block_desc_k0_n_k1),
485 BBlockTransferSrcAccessOrder,
487 BBlockTransferSrcVectorDim,
488 2,
489 BBlockTransferSrcScalarPerVector,
490 BBlockTransferDstScalarPerVector_K1,
491 1,
492 1,
493 BThreadTransferSrcResetCoordinateAfterRun,
494 true>(
495 b_grid_desc_k0_n_k1,
496 make_multi_index(0, n_block_data_idx_on_grid, 0),
497 b_element_op,
498 b_block_desc_k0_n_k1,
499 make_multi_index(0, 0, 0),
501
502 // GEMM definition
503 // c_mtx += transpose(a_mtx) * b_mtx
504 // a_mtx[K0PerBlock, MPerBlock] is in LDS
505 // b_mtx[K0PerBlock, NPerBlock] is in LDS
506 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
507 // register
508 // sanity check
509
510 auto blockwise_gemm =
512 FloatAB,
513 FloatAB,
514 FloatAcc,
515 decltype(a_block_desc_k0_m_k1),
516 decltype(b_block_desc_k0_n_k1),
517 MPerXdl,
518 NPerXdl,
519 MXdlPerWave,
520 NXdlPerWave,
521 K1>{};
522
523 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
524
525 // LDS allocation for A and B: be careful of alignment
526 constexpr auto a_block_space_size_aligned =
527 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
528
530 static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
531
533 static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
534 b_block_desc_k0_n_k1.GetElementSpaceSize());
535
536 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
537 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
538
539 // gridwise GEMM pipeline
540 const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
541
542 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
543 a_block_desc_k0_m_k1,
544 a_blockwise_copy,
545 a_grid_buf,
546 a_block_buf,
547 a_block_slice_copy_step,
548 b_grid_desc_k0_n_k1,
549 b_block_desc_k0_n_k1,
550 b_blockwise_copy,
551 b_grid_buf,
552 b_block_buf,
553 b_block_slice_copy_step,
554 blockwise_gemm,
555 c_thread_buf,
556 K0BlockMainLoop);
557
558 // shuffle C and write out
559 {
560 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
561 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
562 "wrong!");
563
564 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
565 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
566
567 // TODO: hacky, fix it!
568 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
569 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
570
571 // TODO: hacky, fix it!
572 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
573 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
574 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
575
576 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
577 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
578 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
579 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
580 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
581 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
582 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
583 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
584
585 constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
587
589 static_cast<FloatC*>(p_shared),
590 c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
591 .GetElementSpaceSize());
592
593 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
594 c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
595 make_tuple(make_freeze_transform(I0), // freeze mblock
597 Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per
598 // shuffle
600 make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
601 make_freeze_transform(I0), // freeze nblock
603 Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per
604 // shuffle
606 make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl
608 Sequence<1>{},
609 Sequence<2>{},
610 Sequence<3>{},
611 Sequence<4>{},
612 Sequence<5>{}),
614 Sequence<0>{},
616 Sequence<>{},
617 Sequence<1>{},
619
620 );
621
622 // calculate origin of thread output tensor on global memory
623 // blockwise GEMM c matrix starting index
624 const auto c_thread_mtx_on_block =
625 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
626
627 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
628 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
629
630 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
632 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
635
636 const auto m_thread_data_on_block_idx =
637 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
638 make_multi_index(m_thread_data_on_block));
639
640 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
645
646 const auto n_thread_data_on_block_idx =
647 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
648 make_multi_index(n_thread_data_on_block));
649
650 // VGPR to LDS
651 auto c_thread_copy_vgpr_to_lds =
653 FloatC,
654 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
655 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
657 Sequence<CShuffleMXdlPerWavePerShuffle,
658 CShuffleNXdlPerWavePerShuffle,
659 I1,
660 I1,
661 M2,
662 I1,
663 M4,
664 I1>,
666 7,
667 1,
669 1,
670 true>{
671 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
673 0,
674 m_thread_data_on_block_idx[I1],
675 n_thread_data_on_block_idx[I1],
676 m_thread_data_on_block_idx[I2],
677 m_thread_data_on_block_idx[I3],
678 m_thread_data_on_block_idx[I4],
679 n_thread_data_on_block_idx[I2]),
681
682 auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
683 ThisThreadBlock, // ThreadGroup
684 CElementwiseOperation, // ElementwiseOperation,
685 CGlobalMemoryDataOperation, // DstInMemOp,
686 Sequence<1,
687 CShuffleMXdlPerWavePerShuffle,
688 MWave * MPerXdl,
689 1,
690 CShuffleNXdlPerWavePerShuffle,
691 NWave * NPerXdl>, // BlockSliceLengths,
692 CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
693 Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
694 FloatC, // typename Src0Data,
695 FloatC, // typename Src1Data,
696 FloatC, // typename Src2Data,
697 FloatC, // typename DstData,
698 decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
699 decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
700 decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
701 decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
702 Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
703 5, // index_t VectorDim,
704 CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
705 true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
706 false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
707 false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
708 false> // bool ThreadTransferDstResetCoordinateAfterRun>
709 {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
710 make_multi_index(0, 0, 0, 0, 0, 0),
711 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
712 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
713 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
714 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
715 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
716 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
717 c_element_op};
718
719 constexpr auto mxdlperwave_forward_step =
720 make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
721 constexpr auto nxdlperwave_forward_step =
722 make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
723 constexpr auto nxdlperwave_backward_step =
724 make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
725
727 constexpr auto mxdlperwave = mxdlperwave_iter;
728
729 static_for<0,
730 NXdlPerWave,
731 CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
732 constexpr bool nxdlperwave_forward_sweep =
733 (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
734
735 constexpr index_t nxdlperwave_value =
736 nxdlperwave_forward_sweep
737 ? nxdlperwave_iter
738 : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
739
740 constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
741
742 // make sure it's safe to do ds_write
744
745 // VGPR to LDS
746 c_thread_copy_vgpr_to_lds.Run(
747 c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
748 make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
749 c_thread_buf,
750 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
751 c_block_buf);
752
753 // make sure it's safe to do ds_read
755
756 // LDS to global
757 c_block_copy_lds_to_global.Run(
758 c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
759 c_block_buf,
760 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
761 c0_grid_buf,
762 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
763 c1_grid_buf,
764 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
765 c_grid_buf);
766
767 // move on nxdlperwave dimension
768 if constexpr(nxdlperwave_forward_sweep &&
769 (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
770 {
771 c_block_copy_lds_to_global.MoveSrc1SliceWindow(
772 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
773 nxdlperwave_forward_step);
774
775 c_block_copy_lds_to_global.MoveSrc2SliceWindow(
776 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
777 nxdlperwave_forward_step);
778
779 c_block_copy_lds_to_global.MoveDstSliceWindow(
780 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
781 nxdlperwave_forward_step);
782 }
783 else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
784 {
785 c_block_copy_lds_to_global.MoveSrc1SliceWindow(
786 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
787 nxdlperwave_backward_step);
788
789 c_block_copy_lds_to_global.MoveSrc2SliceWindow(
790 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
791 nxdlperwave_backward_step);
792
793 c_block_copy_lds_to_global.MoveDstSliceWindow(
794 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
795 nxdlperwave_backward_step);
796 }
797 });
798
799 // move on mxdlperwave dimension
800 if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
801 {
802 c_block_copy_lds_to_global.MoveSrc1SliceWindow(
803 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
804 mxdlperwave_forward_step);
805
806 c_block_copy_lds_to_global.MoveSrc2SliceWindow(
807 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
808 mxdlperwave_forward_step);
809
810 c_block_copy_lds_to_global.MoveDstSliceWindow(
811 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
812 mxdlperwave_forward_step);
813 }
814 });
815 }
816 }
817};
818
819} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__global__ void kernel_gemm_xdlops_v3r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const FloatC *__restrict__ p_c1_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdlops_v3r3.hpp:37
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition block_to_ctile_map.hpp:261
Definition blockwise_gemm_smfmac_xdlops.hpp:44
Definition gridwise_gemm_xdlops_v3r3.hpp:142
static __device__ void Run(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, const CDataType *__restrict__ p_c0_grid, const CDataType *__restrict__ p_c1_grid, void *__restrict__ p_shared, const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const InElementwiseOperation &a_element_op, const WeiElementwiseOperation &b_element_op, const OutElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdlops_v3r3.hpp:373
Definition utility/sequence.hpp:43
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r3.hpp:40
Definition threadwise_tensor_slice_transfer.hpp:39
Definition is_known_at_compile_time.hpp:14
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340