gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp Source File

gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp Source File

Composable Kernel: gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp Source File
gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/utility/env.hpp"
19
20namespace ck {
21
22// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
23// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
24template <typename ADataType,
25 typename B0DataType,
26 typename Acc0DataType,
27 typename B1DataType,
28 typename Acc1DataType,
29 typename CShuffleDataType,
30 typename CDataType,
31 typename AElementwiseOperation,
32 typename B0ElementwiseOperation,
33 typename AccElementwiseOperation,
34 typename B1ElementwiseOperation,
35 typename CElementwiseOperation,
36 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
37 typename AGridDesc,
38 typename B0GridDesc,
39 typename B1GridDesc,
40 typename CGridDesc_M_N,
41 index_t MPerBlock,
42 index_t LPerBlock,
43 index_t KPerBlock,
44 index_t AK1Value,
45 index_t BK1Value,
46 index_t NPerBlock,
47 index_t LTilePerBlock,
48 index_t L1Value,
49 index_t MPerWmma,
50 index_t LPerWmma,
51 index_t NPerWmma,
52 index_t MRepeat,
53 index_t LRepeat,
54 index_t NRepeat,
55 index_t BlockSize,
56 typename ABlockTransferThreadClusterLengths_K0_M_K1,
57 typename ABlockTransferThreadClusterArrangeOrder,
58 typename ABlockTransferSrcAccessOrder,
59 index_t ABlockTransferSrcVectorDim,
60 index_t ABlockTransferSrcScalarPerVector,
61 index_t ABlockTransferDstScalarPerVector_K1,
62 bool AThreadTransferSrcResetCoordinateAfterRun,
63 bool AEnableLds,
64 bool ABlockLdsExtraM,
65 typename B0BlockTransferThreadClusterLengths_K0_L_K1,
66 typename B0BlockTransferThreadClusterArrangeOrder,
67 typename B0BlockTransferSrcAccessOrder,
68 index_t B0BlockTransferSrcVectorDim,
69 index_t B0BlockTransferSrcScalarPerVector,
70 index_t B0BlockTransferDstScalarPerVector_K1,
71 bool B0ThreadTransferSrcResetCoordinateAfterRun,
72 bool B0EnableLds,
73 bool B0BlockLdsExtraL,
74 typename B1BlockTransferThreadClusterLengths_L0_N_L1,
75 typename B1BlockTransferThreadClusterArrangeOrder,
76 typename B1BlockTransferSrcAccessOrder,
77 index_t B1BlockTransferSrcVectorDim,
78 index_t B1BlockTransferSrcScalarPerVector,
79 index_t B1BlockTransferDstScalarPerVector_L1,
80 bool B1ThreadTransferSrcResetCoordinateAfterRun,
81 bool B1EnableLds,
82 bool B1BlockLdsExtraN,
83 index_t CShuffleMRepeatPerShuffle,
84 index_t CShuffleNRepeatPerShuffle,
85 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
86 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
87 bool PadN,
89 index_t NumGemmKPrefetchStage = 1,
93{
// Compile-time dimension indices used with GetLength() and descriptor transforms below.
94 static constexpr auto I0 = Number<0>{};
95 static constexpr auto I1 = Number<1>{};
96 static constexpr auto I2 = Number<2>{};
97 static constexpr auto I3 = Number<3>{};
98 static constexpr auto I4 = Number<4>{};
99 static constexpr auto I5 = Number<5>{};
100 static constexpr auto I6 = Number<6>{};
101 static constexpr auto I7 = Number<7>{};
102
// K1 vector widths of the gemm0 operands. BK0 = KPerBlock / BK1Value is the number of
// BK1-wide K slices per block for B0.
103 static constexpr auto AK1 = Number<AK1Value>{};
104 static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
105 static constexpr auto BK1 = Number<BK1Value>{};
106
// Gemm1 reduces over L one LTilePerBlock tile at a time (see the inner-loop check in
// CheckValidity). L0PerBlock = LTilePerBlock / L1Value is the number of L1Value-wide slices
// in one such tile.
// NOTE(review): AL0 is half of L0PerBlock while BL0 is the full L0PerBlock. The reason for the
// factor 2 is not visible in this listing — confirm against the A1 (VGPR) layout.
107 static constexpr auto L0PerBlock = LTilePerBlock / L1Value;
108 static constexpr auto AL0 = Number<L0PerBlock / 2>{};
109 static constexpr auto AL1 = Number<L1Value>{};
110 static constexpr auto BL0 = Number<L0PerBlock>{};
111 static constexpr auto BL1 = Number<L1Value>{};
112
// Waves per block along each dimension: block extent / (repeats per wave * per-WMMA extent).
113 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
114 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
115 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
// Reduction extent consumed per WMMA step for gemm0 (K) and gemm1 (L). For example, the
// descriptor builders use KPerBlock / WmmaK as the number of WMMA steps per K block.
116 static constexpr auto WmmaK = 16;
117 static constexpr auto WmmaL = 16;
118
// The declaration below is only partially shown in this listing. It is parameterized on
// NumGemmKPrefetchStage, LoopSched, AEnableLds and B0EnableLds.
// NOTE(review): presumably it defines GridwiseGemmPipe, which CheckValidity and
// CalculateHasMainKBlockLoop use — confirm.
120
123 NumGemmKPrefetchStage,
124 LoopSched,
125 AEnableLds,
126 B0EnableLds>())>;
127
// Describes how one A tile (K extent KPerBlock, M extent MPerBlock) is laid out per block.
//  - AEnableLds: an LDS tile viewed as (K0PerBlock, MPerBlock, AK1), with
//    K0PerBlock = KPerBlock / AK1. The non-padded variant is aligned to max_lds_align = AK1.
//    ABlockLdsExtraM selects a different layout whose construction is elided in this listing
//    (presumably padded along M).
//  - otherwise: the tile is held per thread in the 7-D order given by the comment in the
//    else branch. KWmmaPerblock = KPerBlock / WmmaK sizes the leading (KWmma) axis.
128 __host__ __device__ static constexpr auto MakeABlockDescriptor()
129 {
130 constexpr auto a_block_desc = [&]() {
131 if constexpr(AEnableLds)
132 {
133 // K0->M->K1 Per Block
134 constexpr auto K0PerBlock = KPerBlock / AK1;
135 constexpr auto max_lds_align = AK1;
136
137 if constexpr(ABlockLdsExtraM)
138 {
142 }
143 else
144 {
146 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, AK1), max_lds_align);
147 }
148 }
149 else
150 {
151 constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// NOTE(review): the "/ 2" presumably corresponds to the KRow axis in the layout below
// (two K rows per WMMA step) — confirm.
152 constexpr auto K0PerWmma = WmmaK / 2 / AK1;
153 // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
157 I1,
159 I1,
160 I1,
161 AK1),
165 AK1,
166 AK1,
167 AK1,
168 I1));
169 }
170 }();
171
172 return a_block_desc;
173 }
174
// Describes how one B0 tile (K extent KPerBlock, L extent LPerBlock) is laid out per block.
//  - B0EnableLds: an LDS tile viewed as (K0PerBlock, LPerBlock, BK1), with
//    K0PerBlock = KPerBlock / BK1. The non-padded variant is aligned to max_lds_align = BK1.
//    B0BlockLdsExtraL selects a different layout whose construction is elided in this listing
//    (presumably padded along L).
//  - otherwise: the tile is held per thread in the 7-D order given by the comment in the
//    else branch.
175 __host__ __device__ static constexpr auto MakeB0BlockDescriptor()
176 {
177 constexpr auto b0_block_desc = [&]() {
178 if constexpr(B0EnableLds)
179 {
180 // K0->L->BK1 Per Block
181 constexpr auto K0PerBlock = KPerBlock / BK1;
182 constexpr auto max_lds_align = BK1;
183
184 if constexpr(B0BlockLdsExtraL)
185 {
189 }
190 else
191 {
193 make_tuple(Number<K0PerBlock>{}, Number<LPerBlock>{}, BK1), max_lds_align);
194 }
195 }
196 else
197 {
198 constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
199 constexpr auto K0PerWmma = WmmaK / 2 / BK1;
// NOTE(review): the axis names below mirror the A descriptor. For B0 they presumably read
// KWmma->LRepeat->LWave->K0PerWmma->KRow->LPerWmma->K1 (cf. the B0 thread-wise copy comment
// in Run) — confirm.
200 // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
204 I1,
206 I1,
207 I1,
208 BK1),
212 BK1,
213 BK1,
214 BK1,
215 I1));
216 }
217 }();
218
219 return b0_block_desc;
220 }
221
// Describes how the gemm1 B1 operand (L x NPerBlock) is laid out per block.
//  - B1EnableLds: an LDS tile viewed as (L0PerBlock, NPerBlock, BL1), with
//    L0PerBlock = LTilePerBlock / L1Value, i.e. one LTilePerBlock slice of L.
//    B1BlockLdsExtraN selects a different layout whose construction is elided in this listing.
//  - otherwise: the tile is held per thread in the 7-D order given by the comment in the
//    else branch.
222 __host__ __device__ static constexpr auto MakeB1BlockDescriptor()
223 {
224 constexpr auto b1_block_desc = [&]() {
225 if constexpr(B1EnableLds)
226 {
227 // L0->N->BL1 Per Block
228 constexpr auto max_lds_align = BL1;
229
230 if constexpr(B1BlockLdsExtraN)
231 {
235 }
236 else
237 {
239 make_tuple(Number<L0PerBlock>{}, Number<NPerBlock>{}, BL1), max_lds_align);
240 }
241 }
242 else
243 {
// NOTE(review): this branch sizes the L extent by LPerBlock. The LDS branch above and the
// non-LDS branch of MakeB1BlockSliceCopyStep both use the LTilePerBlock tile instead —
// confirm this difference is intended.
244 constexpr auto LWmmaPerblock = LPerBlock / WmmaL;
245 constexpr auto L0PerWmma = WmmaL / 2 / BL1;
// NOTE(review): for B1 the wave / per-WMMA axes are presumably NWave / NPerWmma rather than
// the M-named ones in the comment below — confirm.
246 // LWmma->NRepeat->MWave->L0PerWmma->LRow->MPerWmma->L1 Per Thread
250 I1,
252 I1,
253 I1,
254 BL1),
258 BL1,
259 BL1,
260 BL1,
261 I1));
262 }
263 }();
264
265 return b1_block_desc;
266 }
267
268 __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
269 {
270 constexpr auto a_block_copy_step = [&]() {
271 if constexpr(AEnableLds)
272 {
273 constexpr auto K0PerBlock = KPerBlock / AK1;
274
275 return make_multi_index(K0PerBlock, 0, 0);
276 }
277 else
278 {
279 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
280
281 return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
282 }
283 }();
284
285 return a_block_copy_step;
286 }
287
288 __host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep()
289 {
290 constexpr auto b0_block_copy_step = [&]() {
291 if constexpr(B0EnableLds)
292 {
293 constexpr auto K0PerBlock = KPerBlock / BK1;
294
295 return make_multi_index(K0PerBlock, 0, 0);
296 }
297 else
298 {
299 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
300
301 return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
302 }
303 }();
304
305 return b0_block_copy_step;
306 }
307
308 __host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep()
309 {
310 constexpr auto b1_block_copy_step = [&]() {
311 if constexpr(B1EnableLds)
312 {
313 return make_multi_index(L0PerBlock, 0, 0);
314 }
315 else
316 {
317 constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
318
319 return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0, 0);
320 }
321 }();
322
323 return b1_block_copy_step;
324 }
325
326 // Describe how data is read from the (LDS/VGPR) buffer
// Builds the view of the A block descriptor whose type is handed to BlockwiseGemmWMMA in Run.
// Only the argument's type is used (the parameter is unnamed). Both branches rebuild the view
// from ABlockDesc_{} lengths; the transform lists are partly elided in this listing.
327 template <typename ABlockDesc_>
328 __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
329 {
330
331 constexpr auto a_wave_desc = [&]() {
332 if constexpr(AEnableLds)
333 {
334 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
335 constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
336 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
// A always uses a single K row here; the B0/B1 wave descriptors use 2 rows on gfx12.
337 constexpr auto A_KRow = I1;
339 ABlockDesc_{},
346 }
347 else
348 {
349 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
350 constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
351 constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
352 constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
353 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
354
357 I1,
359 I1,
360 Number<A_K1>{}));
361 }
362 }();
363
364 return a_wave_desc;
365 }
366
// B0 counterpart of MakeAWaveDescriptor. Its return type is handed to BlockwiseGemmWMMA in Run,
// and only the argument's type is used. The transform lists are partly elided in this listing.
367 template <typename B0BlockDesc_>
368 __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&)
369 {
370
371 constexpr auto b0_wave_desc = [&]() {
372 if constexpr(B0EnableLds)
373 {
374 // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
375 constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
376 constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
// The K-row count is chosen at compile time: 2 on gfx12 targets, 1 otherwise.
377#ifdef __gfx12__
378 constexpr auto B_KRow = I2;
379#else
380 constexpr auto B_KRow = I1;
381#endif
383 B0BlockDesc_{},
390 }
391 else
392 {
// NOTE(review): the axis names below are copied from the A variant. For B0 the M-named axes
// presumably refer to L — confirm.
393 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
394 constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0);
395 constexpr auto K0PerWmma = B0BlockDesc_{}.GetLength(I3);
396 constexpr auto B_KRow = B0BlockDesc_{}.GetLength(I4);
397 constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I6);
398
399 // Workaround, Freeze transform
402 I1,
404 I1,
405 Number<B_K1>{}));
406 }
407 }();
408
409 return b0_wave_desc;
410 }
411
// Builds the (L0, M0, M1, M2, L1) view of the gemm1 A operand (A1) from an (AL0, M, AL1)
// descriptor. In Run, A1 comes from the gemm0 accumulator ("A1 matrix in VGPR").
// LRow is fixed to 1. The transform body is elided in this listing.
412 template <typename A1BlockDesc_AL0_M_AL1>
413 __host__ __device__ static constexpr auto
414 MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
415 {
416 constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
417 constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
418 constexpr auto A_LRow = I1;
420 A1BlockDesc_AL0_M_AL1{},
426 }
427
// Gemm1 B1 counterpart of MakeAWaveDescriptor; only the argument's type is used.
// The LDS view uses 2 L rows on gfx12 and 1 otherwise. The direct (non-LDS) path reads
// LWmma, L0PerWmma, LRow and L1 from dimensions 0, 3, 4 and 6 of the block descriptor.
// The transform lists are partly elided in this listing.
428 template <typename B1BlockDesc_>
429 __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&)
430 {
431
432 constexpr auto b1_wave_desc = [&]() {
433 if constexpr(B1EnableLds)
434 {
435 // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
436 constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
437 constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
// The L-row count is chosen at compile time: 2 on gfx12 targets, 1 otherwise.
438#ifdef __gfx12__
439 constexpr auto B_LRow = I2;
440#else
441 constexpr auto B_LRow = I1;
442#endif
444 B1BlockDesc_{},
451 }
452 else
453 {
454 constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0);
455 constexpr auto L0PerWmma = B1BlockDesc_{}.GetLength(I3);
456 constexpr auto B_LRow = B1BlockDesc_{}.GetLength(I4);
457 constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I6);
458
461 I1,
463 I1,
464 Number<B_L1>{}));
465 }
466 }();
467
468 return b1_wave_desc;
469 }
470
// Builds the block-level descriptor used by the C shuffle stage; the function-name line is
// elided in this listing. Judging by the result's name, its dims are
// (MShRepeat, MPerShRepeat, NShRepeat, NPerShRepeat).
// MWave/NWave below use the same formulas as the MWaves/NWaves members.
471 __host__ __device__ static constexpr auto
472 // *Caution Here repeat is shuffle repeat
474 {
475 constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
476 constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
477
478 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
482 I1,
484
485 return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
486 }
487
// Returns the number of LDS bytes the kernel needs.
// SharedMemTrait starts the A tile, the B1 tile and the reduction workspace at offset 0, and
// the C-shuffle end offset below is just its size (so it also starts at 0). The phases
// therefore share one allocation, and the requirement is the maximum of their end offsets
// rather than their sum. Some of the expressions are elided in this listing.
488 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
489 {
490 // LDS allocation for A and B: be careful of alignment
// Trait sizes are element counts; each is scaled by its element type's size here.
491 const index_t gemm0_bytes_end =
492 (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
494
495 const index_t gemm1_bytes_end =
498
499 const index_t softmax_bytes_end =
502
503 const index_t c_block_bytes_end =
504 SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType);
505
506 return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
507 }
508
509 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
510 template <typename Block2CTileMap>
511 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
512 const B0GridDesc& b0_grid_desc,
513 const B1GridDesc& b1_grid_desc,
514 const CGridDesc_M_N& c_grid_desc_m_n,
515 const Block2CTileMap& block_2_ctile_map)
516 {
517 static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
518 (LPerBlock % (LPerWmma * LRepeat)) == 0,
519 "Invalid tuning param!");
520
521 const auto GetAProblemsizeMK = [&]() {
522 if constexpr(AEnableLds)
523 {
524 return make_tuple(a_grid_desc.GetLength(I1),
525 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
526 }
527 else
528 {
529 return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
530 a_grid_desc.GetLength(I5),
531 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
532 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
533 }
534 };
535
536 const auto GetB0ProblemsizeLK = [&]() {
537 if constexpr(B0EnableLds)
538 {
539 return make_tuple(b0_grid_desc.GetLength(I1),
540 b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I2));
541 }
542 else
543 {
544 return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) *
545 b0_grid_desc.GetLength(I5),
546 b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) *
547 b0_grid_desc.GetLength(I4) * b0_grid_desc.GetLength(I6));
548 }
549 };
550
551 const auto GetB1ProblemsizeNL = [&]() {
552 if constexpr(B1EnableLds)
553 {
554 return make_tuple(b1_grid_desc.GetLength(I1),
555 b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I2));
556 }
557 else
558 {
559 return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) *
560 b1_grid_desc.GetLength(I5),
561 b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) *
562 b1_grid_desc.GetLength(I4) * b1_grid_desc.GetLength(I6));
563 }
564 };
565
566 const auto M = GetAProblemsizeMK()[I0];
567 const auto L = GetB0ProblemsizeLK()(I0);
568 const auto K = GetAProblemsizeMK()[I1];
569 const auto N = GetB1ProblemsizeNL()(I0);
570
571 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
572 {
573 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
574 {
575 printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n",
576 M,
577 N,
578 c_grid_desc_m_n.GetLength(I0),
579 c_grid_desc_m_n.GetLength(I1));
580 }
581 return false;
582 }
583
584 if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
585 {
586 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
587 {
588 printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | "
589 "M/L/K/NPerBlock = "
590 "%d, %d, %d, %d\n",
591 M,
592 L,
593 K,
594 N,
595 MPerBlock,
596 LPerBlock,
597 KPerBlock,
598 NPerBlock);
599 }
600 return false;
601 }
602
603 // check gemm0 gridwise gemm pipeline
604 const auto num_gemm0_k_loop = K / KPerBlock;
605 if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
606 {
607 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
608 {
609 printf("GridwiseOp: outer loop unsupport\n");
610 }
611 return false;
612 }
613
614 // check gemm1 gridwise gemm pipeline
615 if(!(LPerBlock % LTilePerBlock == 0))
616 {
617 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
618 {
619 printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n",
620 LPerBlock,
621 LTilePerBlock);
622 }
623 return false;
624 }
625
626 const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock;
627 if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
628 {
629 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
630 {
631 printf("GridwiseOp: inner loop unsupport\n");
632 }
633 return false;
634 }
635
636 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
637 {
638 return false;
639 }
640
641 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
642 return true;
643 }
644
645 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
646 {
647 const index_t num_loop = math::integer_divide_ceil(K, KPerBlock);
648
649 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
650 }
651
// Reshapes C from (M, N) into (MBlock, MPerBlock, NBlock, NPerBlock), with
// MBlock = M / MPerBlock and NBlock = N / NPerBlock. The division truncates; CheckValidity
// rejects an M or N that is not a multiple of the block size.
// The transform list is elided in this listing.
652 __host__ __device__ static constexpr auto
653 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
654 {
655 const auto M = c_grid_desc_m_n.GetLength(I0);
656 const auto N = c_grid_desc_m_n.GetLength(I1);
657
658 const auto MBlock = M / MPerBlock;
659 const auto NBlock = N / NPerBlock;
660
661 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
662 c_grid_desc_m_n,
667
668 return c_grid_desc_mblock_mperblock_nblock_nperblock;
669 }
670
671 // return block_id to C matrix tile idx (m0, n0) mapping
// Builds the default block-id -> C tile (m0, n0) map from C's (M, N) descriptor.
// The M01/N01 arguments are unused (their parameter names are commented out).
// The map type's name is elided in this listing.
672 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
673 const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
674 {
676 c_grid_desc_m_n);
677 }
678
// Type aliases derived from the factory functions above (declaration heads elided in this
// listing).
// NOTE(review): the first presumably aliases the MBlock/MPerBlock/NBlock/NPerBlock C grid
// descriptor type. The second is built from MakeDefaultBlock2CTileMap and presumably declares
// DefaultBlock2CTileMap, Run's default Block2CTileMap — confirm.
681 CGridDesc_M_N{}))>;
683 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
684
// LDS layout plan: the body of a nested trait struct whose header line is elided in this
// listing. Sizes are element counts (GetSharedMemoryNumberOfByte scales them by sizeof).
// An operand that bypasses LDS gets size 0.
686 {
687 // LDS allocation for A and B: be careful of alignment
// LDS tile sizes are rounded up with integer_least_multiple(..., max_lds_align), where
// max_lds_align = lcm(AK1, BK1, BL1).
688 static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);
689
690 static constexpr auto a_block_space_size_aligned =
691 AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
693 : 0;
694 static constexpr auto b0_block_space_size_aligned =
695 B0EnableLds ? math::integer_least_multiple(
696 MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align)
697 : 0;
698 static constexpr auto b1_block_space_size_aligned =
699 B1EnableLds ? math::integer_least_multiple(
700 MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align)
701 : 0;
702
// The A tile and the B1 tile (and the reduction workspace below) all start at offset 0, so
// these regions overlap. GetSharedMemoryNumberOfByte sizes the allocation as the maximum over
// phases.
// NOTE(review): the B0 offset line is elided here. The overlap is presumably safe because Run
// separates the phases with barriers — confirm.
703 static constexpr auto a_block_space_offset = 0;
705 static constexpr auto b1_block_space_offset = 0;
706
707 // LDS allocation for reduction
708 // Feature to add, IntraThread Reduction
711
712 static constexpr auto reduction_space_offset = 0;
713
714 // LDS allocation for C shuffle in LDS
715 static constexpr auto c_block_space_size =
717 .GetElementSpaceSize();
718 };
719
720 template <bool HasMainKBlockLoop,
721 typename C0MatrixMask,
722 typename Block2CTileMap = DefaultBlock2CTileMap>
723 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
724 const B0DataType* __restrict__ p_b0_grid,
725 const B1DataType* __restrict__ p_b1_grid,
726 CDataType* __restrict__ p_c_grid,
727 void* __restrict__ p_shared,
728 const AGridDesc& a_grid_desc,
729 const B0GridDesc& b0_grid_desc,
730 const B1GridDesc& b1_grid_desc,
732 c_grid_desc_mblock_mperblock_nblock_nperblock,
733 const AElementwiseOperation& a_element_op,
734 const B0ElementwiseOperation& b0_element_op,
735 const AccElementwiseOperation& acc_element_op,
736 const B1ElementwiseOperation& b1_element_op,
737 const CElementwiseOperation& c_element_op,
738 const C0MatrixMask& c0_matrix_mask,
739 const Block2CTileMap& block_2_ctile_map)
740 {
741 // clang-format off
742/*******************************************************************************/
743// Memory buffer zone.
744 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
745 p_a_grid, a_grid_desc.GetElementSpaceSize());
746 const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
747 p_b0_grid, b0_grid_desc.GetElementSpaceSize());
748 const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
749 p_b1_grid, b1_grid_desc.GetElementSpaceSize());
751 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
752
753/*******************************************************************************/
754// BlockIdx.x -> [BlockId.m, BlockId.n]
755 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
756 if(!block_2_ctile_map.ValidCTileIndex(
757 block_work_idx,
758 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
759 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
760 { return; }
761
762 // Store BlockId into SGPR
763 const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
764 const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
765
766/*******************************************************************************/
767// set up Gemm0
768/*******************************************************************************/
769
770/*******************************************************************************/
771// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
772 constexpr auto a_block_desc = MakeABlockDescriptor();
773 constexpr auto b0_block_desc = MakeB0BlockDescriptor();
774
775 auto a_block_trait = [&](){
776 // A matrix blockwise copy
777 if constexpr(AEnableLds)
778 {
779 constexpr auto AK0PerBlock = KPerBlock/ AK1;
781 static_cast<ADataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
783
784 auto a_blockwise_copy =
786/* typename SrcElementwiseOperation, */ AElementwiseOperation,
787/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
788/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
789/* typename BlockSliceLengths, */ Sequence<AK0PerBlock, MPerBlock, AK1>,
790/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
791/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
792/* typename SrcData, */ ADataType,
793/* typename DstData, */ ADataType,
794/* typename SrcDesc, */ decltype(a_grid_desc),
795/* typename DstDesc, */ decltype(a_block_desc),
796/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
797/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
798/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
799/* index_t DstVectorDim, */ 2,
800/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
801/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
802/* index_t SrcScalarStrideInVector, */ 1,
803/* index_t DstScalarStrideInVector, */ 1,
804/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
805/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
806 NumGemmKPrefetchStage>(
807 a_grid_desc,
808 make_multi_index(0, m_block_data_idx_on_grid, 0),
809 a_element_op,
810 a_block_desc,
811 make_multi_index(0, 0, 0),
813
814 return make_tuple(a_block_buf, a_blockwise_copy);
815 }
816 else
817 {
818 // Thread-wise copy
819 // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
820 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
821 constexpr auto K0PerWmma = WmmaK/2/AK1Value;
823 a_block_desc.GetElementSpaceSize());
824
825 // Limitation: NumDim of Src and Dst descriptor should be identical
826 auto a_blockwise_copy =
828 ADataType,
829 decltype(a_grid_desc),
830 decltype(a_block_desc),
833 I1,
835 I1,
836 I1,
839 6,
840 ABlockTransferSrcScalarPerVector,
841 AThreadTransferSrcResetCoordinateAfterRun,
842 true>(
843 a_grid_desc,
845 m_block_data_idx_on_grid/(MWaves * MPerWmma),
847 0,
848 (get_thread_local_1d_id() % 32 )/ 16,
850 0));
851
852 return make_tuple(a_block_buf, a_blockwise_copy);
853 }
854 };
855
856 auto b0_block_trait = [&](){
857 if constexpr(B0EnableLds)
858 {
860 static_cast<B0DataType*>(p_shared) + SharedMemTrait::b0_block_space_offset,
862
863 auto b0_blockwise_copy =
865 B0ElementwiseOperation,
869 B0BlockTransferThreadClusterLengths_K0_L_K1,
870 B0BlockTransferThreadClusterArrangeOrder,
871 B0DataType,
872 B0DataType,
873 decltype(b0_grid_desc),
874 decltype(b0_block_desc),
875 B0BlockTransferSrcAccessOrder,
877 B0BlockTransferSrcVectorDim,
878 2,
879 B0BlockTransferSrcScalarPerVector,
880 B0BlockTransferDstScalarPerVector_K1,
881 1,
882 1,
883 B0ThreadTransferSrcResetCoordinateAfterRun,
884 true,
885 NumGemmKPrefetchStage>(
886 b0_grid_desc,
887 make_multi_index(0, 0, 0),
888 b0_element_op,
889 b0_block_desc,
890 make_multi_index(0, 0, 0),
892
893 return make_tuple(b0_block_buf, b0_blockwise_copy);
894 }
895 else
896 {
897 // Thread-wise copy
898 // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
899 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
900 constexpr auto K0PerWmma = WmmaK/2/BK1Value;
902 b0_block_desc.GetElementSpaceSize());
903
904 // Limitation: NumDim of Src and Dst descriptor should be identical
905 auto b0_blockwise_copy =
907 B0DataType,
908 decltype(b0_grid_desc),
909 decltype(b0_block_desc),
912 I1,
914 I1,
915 I1,
918 6,
919 B0BlockTransferSrcScalarPerVector,
920 B0ThreadTransferSrcResetCoordinateAfterRun,
921 true>(
922 b0_grid_desc,
924 0/(LWaves * LPerWmma),
926 0,
927 (get_thread_local_1d_id() % 32 )/ 16,
929 0));
930
931 return make_tuple(b0_block_buf, b0_blockwise_copy);
932 }
933 };
934
935 auto a_block_buf = a_block_trait()[I0];
936 auto a_blockwise_copy = a_block_trait()[I1];
937
938 auto b0_block_buf = b0_block_trait()[I0];
939 auto b0_blockwise_copy = b0_block_trait()[I1];
940
941/*******************************************************************************/
942 // Gemm0
943 constexpr auto KPack = math::integer_least_multiple(math::integer_least_multiple(AK1Value,BK1Value), WmmaK);
944
945 auto blockwise_gemm0 = BlockwiseGemmWMMA<
946 BlockSize,
947 ADataType,
948 B0DataType,
949 Acc0DataType,
950 decltype(MakeAWaveDescriptor(a_block_desc)),
951 decltype(MakeB0WaveDescriptor(b0_block_desc)),
952 MPerBlock,
953 LPerBlock,
954 KPerBlock,
955 MPerWmma,
956 LPerWmma,
957 MRepeat,
958 LRepeat,
959 KPack,
960 AEnableLds,
961 B0EnableLds,
962 true>{}; // C' = B' x A'
963
964
965 // Prepare Register for A*B0 matrix
966 auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();
967
968 constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
969 blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
970
971 constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
972 constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
973 constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
974 constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
975 constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
976 constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
977 constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
978
979 constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
980 acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
982 make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
983 make_pass_through_transform(laccvgprs)),
986
987/*******************************************************************************/
988 // Shift Per SUB_K
989 constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
990 constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep();
991
992 const auto a_block_reset_copy_step = [&](){
993 if constexpr(AEnableLds){
994 return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0);
995 }
996 else{
997 return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0, 0);
998 }
999 }();
1000
1001 const auto b0_block_reset_copy_step = [&](){
1002 if constexpr(B0EnableLds){
1003 return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0);
1004 }
1005 else{
1006 return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0, 0);
1007 }
1008 }();
1009
1010 const auto K = [&](){
1011 if constexpr(AEnableLds){
1012 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
1013 }
1014 else{
1015 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
1016 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
1017 }
1018 }();
1019
1020 const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
1021/*******************************************************************************/
1022// softmax
1023/*******************************************************************************/
1024 auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1025 static_cast<Acc0DataType*>(p_shared) + SharedMemTrait::reduction_space_offset,
1027 // get acc0 7D thread cluster
1028 constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
1029 blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() /
1030 blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
1031 constexpr auto t_mrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0);
1032 constexpr auto t_mwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1);
1033 constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2);
1034 constexpr auto t_lrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3);
1035 constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
1036 constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
1037 constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);
1038 // get acc0 thread map
1039 constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
1040 make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
1044 constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor(
1045 make_tuple(
1047 make_tuple(t_mrepeat * t_mwave, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs, t_mthreadpersubgroup))),
1050 const auto threadid_to_l_n_thread_cluster_adaptor =
1051 chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor);
1052
1053 // get acc0 2D thread cluster & 2D thread slice
1054 constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed(
1055 make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs));
1056
1057 constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed(
1058 make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs));
1059
1060 auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
1061 Acc0DataType,
1062 decltype(threadid_to_l_n_thread_cluster_adaptor),
1063 decltype(thread_cluster_desc_m_l),
1064 decltype(thread_slice_desc_m_l)>{};
1065
1066 // Initialize running sum and max of exponentiating row vectors
1067 using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
1068 SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
1069 running_sum = 0;
1070 running_sum_new = 0;
1072 running_max_new = NumericLimits<Acc0DataType>::Lowest();
1073/*******************************************************************************/
1074// set up Gemm1
1075/*******************************************************************************/
1076 // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm
1077 // A1 matrix in VGPR
1078 constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
1082
1083 constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0];
1084 constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1];
1085 constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2];
1086
1087 constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor(
1088 make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1),
1089 make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1));
1090
1091 // A1 matrix blockwise copy
1092 auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
1093 Acc0DataType,
1094 ADataType,
1095 decltype(acc0_thread_desc_l0perblock_mperblock_l1),
1096 decltype(a1_thread_desc_l0perblock_mperblock_l1),
1100 2,
1102
1104 a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());
1105
1106 constexpr auto b1_block_desc = MakeB1BlockDescriptor();
1107
1108 auto b1_block_trait = [&](){
1109 if constexpr(B1EnableLds)
1110 {
1112 static_cast<B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
1114
1115 auto b1_blockwise_copy =
1117/* typename SrcElementwiseOperation, */ B1ElementwiseOperation,
1118/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough,
1119/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
1120/* typename BlockSliceLengths, */ Sequence<BL0, NPerBlock, BL1>,
1121/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1,
1122/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder,
1123/* typename SrcData, */ B1DataType,
1124/* typename DstData, */ B1DataType,
1125/* typename SrcDesc, */ decltype(b1_grid_desc),
1126/* typename DstDesc, */ decltype(b1_block_desc),
1127/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder,
1128/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
1129/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim,
1130/* index_t DstVectorDim, */ 2,
1131/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector,
1132/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1,
1133/* index_t SrcScalarStrideInVector, */ 1,
1134/* index_t DstScalarStrideInVector, */ 1,
1135/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun,
1136/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord
1137 NumGemmKPrefetchStage>(
1138 b1_grid_desc,
1139 make_multi_index(0, n_block_data_idx_on_grid, 0),
1140 b1_element_op,
1141 b1_block_desc,
1142 make_multi_index(0, 0, 0),
1144
1145 return make_tuple(b1_block_buf, b1_blockwise_copy);
1146 }
1147 else
1148 {
1149 // Thread-wise copy
1150 // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
1151 constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
1152 constexpr auto L0PerWmma = WmmaL/2/L1Value;
1154 b1_block_desc.GetElementSpaceSize());
1155
1156 // Limitation: NumDim of Src and Dst descriptor should be identical
1157 auto b1_blockwise_copy =
1159 B1DataType,
1160 decltype(b1_grid_desc),
1161 decltype(b1_block_desc),
1164 I1,
1166 I1,
1167 I1,
1168 Number<L1Value>{}>,
1170 6,
1171 B1BlockTransferSrcScalarPerVector,
1172 B1ThreadTransferSrcResetCoordinateAfterRun,
1173 true>(
1174 b1_grid_desc,
1176 n_block_data_idx_on_grid/(NWaves * NPerWmma),
1178 0,
1179 (get_thread_local_1d_id() % 32 )/ 16,
1181 0));
1182
1183 return make_tuple(b1_block_buf, b1_blockwise_copy);
1184 }
1185 };
1186
1187 auto b1_block_buf = b1_block_trait()[I0];
1188 auto b1_blockwise_copy = b1_block_trait()[I1];
1189
1190 constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep();
1191
1192 auto blockwise_gemm1 =
1193 BlockwiseGemmWMMA<BlockSize,
1194 ADataType,
1195 B1DataType,
1196 Acc1DataType,
1197 decltype(MakeA1WaveDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
1198 decltype(MakeB1WaveDescriptor(b1_block_desc)),
1199 MPerBlock,
1200 NPerBlock,
1201 LTilePerBlock,
1202 MPerWmma,
1203 NPerWmma,
1204 MRepeat,
1205 NRepeat,
1206 KPack,
1207 false,
1208 B1EnableLds,
1209 true>{make_tuple(0, 0, 0, 0, 0, 0)};
1210
1211 auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
1212
1213 const auto L = [&](){
1214 if constexpr(B0EnableLds){
1215 return b0_grid_desc.GetLength(I1);
1216 }
1217 else{
1218 return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I5);
1219 }
1220 }();
1221
1222 const index_t num_gemm1_l_block_outer_loop = L / LPerBlock;
1223 constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock;
1224
1225 // Initialize C
1226 StaticBuffer<AddressSpaceEnum::Vgpr, Acc1DataType, acc1_thread_buf.Size(), true> c_thread_buf;
1227 c_thread_buf.Clear();
1228
1229/*******************************************************************************/
1230 //
1231 // Kernel Main Stage
1232 //
1233 // Flash Attention
1234 // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022).
1235 index_t gemm1_l_block_outer_index = 0;
1236 // Outer loop, along GEMM_L
1237 // Inner loop, along GEMM_K
1238 do{
1239 auto l_block_data_idx_on_grid =
1240 __builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock);
1241 if(c0_matrix_mask.IsTileSkippable(
1242 m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock))
1243 {
1244 continue;
1245 }
1246 // gemm0 start, A-B swaped
1247 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
1248 a_block_desc,
1249 a_blockwise_copy,
1250 a_grid_buf,
1251 a_block_buf,
1252 a_block_slice_copy_step,
1253 b0_grid_desc,
1254 b0_block_desc,
1255 b0_blockwise_copy,
1256 b0_grid_buf,
1257 b0_block_buf,
1258 b0_block_slice_copy_step,
1259 blockwise_gemm0,
1260 acc0_thread_buf,
1261 KBlockMainLoop);
1262 // do MNK padding or upper triangular masking
1263 if constexpr(MaskOutUpperTriangle || PadN)
1264 {
1265 // 7d thread_desc in thread scope
1266 constexpr auto c_thread_lengths =
1267 blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
1268
1269 // 7d block_desc in block scope
1270 constexpr auto c_block_lengths =
1271 blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
1272
1273 constexpr auto MREPEAT = c_block_lengths[I0];
1274 constexpr auto MWAVE = c_block_lengths[I1];
1275 constexpr auto MTHREADSubGroup = c_block_lengths[I2];
1276 constexpr auto LREPEAT = c_block_lengths[I3];
1277 constexpr auto LWAVE = c_block_lengths[I4];
1278 constexpr auto LSUBGROUP = c_block_lengths[I5];
1279 constexpr auto LACCVGPRS = c_block_lengths[I6];
1280
1281 // works like multi-dimension static_for (static_ford), but provides both the linear
1282 // index as well as n-d index
1283 using Acc0TileIterator = SpaceFillingCurve<
1284 decltype(c_thread_lengths),
1285 typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
1286 typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
1287 false>; // SnakeCurved
1288
1289 auto acc0_thread_origin = blockwise_gemm0.CalculateCThreadOriginDataIndex7D(
1290 Number<0>{}, Number<0>{});
1291
1292 constexpr auto block_idx_to_m_l_adaptor = make_single_stage_tensor_adaptor(
1293 make_tuple(make_unmerge_transform(make_tuple(MREPEAT, MWAVE, MTHREADSubGroup)),
1294 make_unmerge_transform(make_tuple(LREPEAT, LWAVE, LSUBGROUP, LACCVGPRS))),
1297
1298 static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
1299 auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
1300 auto m_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
1301 auto l_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
1302 auto m_global = m_local + m_block_data_idx_on_grid;
1303 auto l_global = l_local + l_block_data_idx_on_grid;
1304 if(c0_matrix_mask.IsMaskedElement(m_global, l_global))
1305 {
1306 acc0_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
1307 }
1308 else
1309 {
1310 acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]);
1311 }
1312 });
1313 }
1314 else
1315 { static_for<0, acc0_thread_buf.Size(), 1>{}(
1316 [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
1317 }
1318
1320 // Tiled softmax start
1321 // softmax
1322 SoftmaxBuf& max = blockwise_softmax.max_value_buf;
1323 SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
1324
1325 blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
1326
1327 // TODO: may convert to log domain
1328 running_max_new = mathext::max(max, running_max);
1329 running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
1330 mathext::exp(max - running_max_new) * sum;
1331
1332 // gemm1
1333 {
1334 // TODO: explore using dynamic buffer for a1 thread buffer
1335 // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
1336 // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
1337 // the A1 source buffer is static buffer holding the output of first GEMM and
1338 // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
1339 // explicitly in Run() below.
1340
1341 // Initialize acc1
1342 acc1_thread_buf.Clear();
1343
1344 // preload data into LDS
1345 b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);
1346
1347 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
1348 b1_block_slice_copy_step);
1349
1350 block_sync_lds(); // wait for reduction LDS read
1351
1352 b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
1353
1354 // main body
1355 if constexpr(num_gemm1_l_block_inner_loop > 1)
1356 {
1357 static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
1358 // Data cast from Acc0DataType to ADataType happen here
1359 a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
1361 acc0_thread_buf,
1362 a1_thread_desc_l0perblock_mperblock_l1,
1363 make_tuple(I0, I0, I0),
1364 a1_thread_buf);
1365
1366 b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);
1367
1369
1370 blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
1371
1373
1374 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
1375 b1_block_slice_copy_step);
1376
1377 b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
1378 });
1379 }
1380 // tail
1381 {
1382 a1_blockwise_copy.Run(
1383 acc0_thread_desc_l0perblock_mperblock_l1,
1384 make_tuple(
1385 Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0),
1386 acc0_thread_buf,
1387 a1_thread_desc_l0perblock_mperblock_l1,
1388 make_tuple(I0, I0, I0),
1389 a1_thread_buf);
1390
1392
1393 blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
1394 }
1395 } // end gemm1
1396
1397 constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
1398 blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
1399 constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
1400 constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
1401 constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
1402 constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
1403 constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
1404 constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
1405 constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
1406
1407 constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
1408 make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
1409 c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs));
1410 constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
1411 constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
1412
1415 auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
1416 Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
1417 Acc1DataType c = c_thread_buf[I]; // O
1418 Acc1DataType c_new =
1419 (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
1420 math::exp(max[iM] - running_max_new[iM]) * acc1) /
1421 running_sum_new[iM];
1422
1423 c_thread_buf(I) = c_new; // O_new
1424 });
1425 });
1426
1427 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
1428 a_block_reset_copy_step); // rewind K
1429 b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc,
1430 b0_block_reset_copy_step); // rewind K and step N
1431
1432 // update before next j iteration
1433 running_max = running_max_new;
1434 running_sum = running_sum_new;
1435
1436 block_sync_lds(); // wait for gemm1 LDS read
1437 }while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop);
1438/*******************************************************************************/
1439 // write out to C, implement shuffle
1440 {
1441 constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
1442 blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
1443
1444 // This API Provide All dimension (size) you need
1445 constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
1446 blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
1447
1448 constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
1449 constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
1450 constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
1451 constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
1452 constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
1453
1454 // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
1455 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
1457
1458 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1459 static_cast<CShuffleDataType*>(p_shared),
1460 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
1461
1462 constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
1463 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1464 make_tuple(
1467 Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
1468 MWave, // MWave
1469 MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma
1470 )),
1473 Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
1474 NWave, // NWave
1475 NSubGroup,
1476 NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma
1479
1480 // calculate origin of thread output tensor on global memory
1481 // blockwise GEMM c matrix starting index
1482 const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0);
1483
1484 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1485 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1486
1487 const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor =
1489 make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))),
1492
1493 const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor =
1495 make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))),
1498
1499 const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex(
1500 make_multi_index(m_thread_data_on_block));
1501
1502 const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex(
1503 make_multi_index(n_thread_data_on_block));
1504
1505 // shuffle: threadwise copy C from VGPR to LDS
1506 auto c_thread_copy_vgpr_to_lds =
1508 CShuffleDataType,
1509 decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
1510 decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
1512 Sequence<CShuffleMRepeatPerShuffle,
1513 I1,
1514 I1,
1515 CShuffleNRepeatPerShuffle,
1516 I1,
1517 I1,
1518 NAccVgprs>,
1520 6,
1521 8, // vector write pixel
1523 1,
1524 true>{
1525 c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
1527 m_thread_data_on_block_idx[I1],
1528 m_thread_data_on_block_idx[I2],
1529 0,
1530 n_thread_data_on_block_idx[I1],
1531 n_thread_data_on_block_idx[I2],
1532 n_thread_data_on_block_idx[I3]),
1534
1535 // shuffle: blockwise copy C from LDS to global
1536 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1537 ThisThreadBlock, // ThreadGroup
1538 CElementwiseOperation, // ElementwiseOperation,
1539 CGlobalMemoryDataOperation, // DstInMemOp,
1540 Sequence<1,
1541 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1542 1,
1543 CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
1544 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1545 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1546 CShuffleDataType, // typename SrcData,
1547 CDataType, // typename DstData,
1548 decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
1549 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1550 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1551 3, // index_t VectorDim,
1552 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1553 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1554 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1555 {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1556 make_multi_index(0, 0, 0, 0),
1557 c_grid_desc_mblock_mperblock_nblock_nperblock,
1558 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1559 c_element_op};
1560
1561 // space filling curve for local reg & global memory
1562 // space filling curve for threadwise C in VGPR
1563 constexpr auto sfc_c_vgpr =
1566 Sequence<CShuffleMRepeatPerShuffle,
1567 1,
1568 1,
1569 CShuffleNRepeatPerShuffle,
1570 1,
1571 1,
1572 NAccVgprs>>{};
1573
1574 // space filling curve for shuffled blockwise C in global mem
1575 constexpr auto sfc_c_global =
1578 Sequence<1,
1579 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1580 1,
1581 CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
1582
1583 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1584
1585 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1586
1587 static_for<0, num_access, 1>{}([&](auto access_id) {
1588 // make sure it's safe to write to LDS
1590
1591 // each thread write its data from VGPR to LDS
1592 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
1593 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1594 c_thread_buf,
1595 c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
1596 c_shuffle_block_buf);
1597
1598 // make sure it's safe to read from LDS
1600
1601 // each block copy its data from LDS to global
1602 c_shuffle_block_copy_lds_to_global.Run(
1603 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1604 c_shuffle_block_buf,
1605 c_grid_desc_mblock_mperblock_nblock_nperblock,
1606 c_grid_buf);
1607
1608 if constexpr(access_id < num_access - 1)
1609 {
1610 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1611 // move on C
1612 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1613 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1614 }
1615 });
1616 }
1617 // clang-format on
1618 }
1619};
1620
1621} // namespace ck
__host__ T exp(T x)
Definition math_v2.hpp:391
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
__host__ __device__ constexpr auto exp(const Tuple< Xs... > &x)
Definition statically_indexed_array_multi_index.hpp:124
__host__ __device__ constexpr auto max(const Tuple< Xs... > &x, const Y &y)
Definition statically_indexed_array_multi_index.hpp:134
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition block_to_ctile_map.hpp:261
Definition blockwise_gemm_wmma.hpp:550
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_wmma.hpp:585
Blockwise softmax.
Definition blockwise_softmax.hpp:32
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:686
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::SharedMemTrait::b1_block_space_offset
static constexpr auto b1_block_space_offset
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:705
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::SharedMemTrait::reduction_space_offset
static constexpr auto reduction_space_offset
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:712
static constexpr auto max_lds_align
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:688
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::SharedMemTrait::c_block_space_size
static constexpr auto c_block_space_size
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:715
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::SharedMemTrait::b1_block_space_size_aligned
static constexpr auto b1_block_space_size_aligned
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:698
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::SharedMemTrait::b0_block_space_size_aligned
static constexpr auto b0_block_space_size_aligned
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:694
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::SharedMemTrait::reduction_space_size_aligned
static constexpr index_t reduction_space_size_aligned
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:709
static constexpr auto b0_block_space_offset
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:704
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::SharedMemTrait::a_block_space_size_aligned
static constexpr auto a_block_space_size_aligned
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:690
static constexpr auto a_block_space_offset
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:703
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:93
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation &a_element_op, const B0ElementwiseOperation &b0_element_op, const AccElementwiseOperation &acc_element_op, const B1ElementwiseOperation &b1_element_op, const CElementwiseOperation &c_element_op, const C0MatrixMask &c0_matrix_mask, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:723
__host__ static __device__ constexpr auto MakeB1BlockDescriptor()
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:222
__host__ static __device__ constexpr auto MakeABlockSliceCopyStep()
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:268
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::WmmaK
static constexpr auto WmmaK
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:116
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::ThisThreadBlock
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:119
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::I5
static constexpr auto I5
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:99
__host__ static __device__ constexpr auto MakeB0BlockDescriptor()
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:175
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MWaves
static constexpr auto MWaves
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:113
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::AK1
static constexpr auto AK1
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:103
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::I2
static constexpr auto I2
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:96
__host__ static __device__ constexpr auto MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1 &)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:414
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::BK0
static constexpr auto BK0
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:104
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::BK1
static constexpr auto BK1
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:105
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::I4
static constexpr auto I4
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:98
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:682
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:672
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::L0PerBlock
static constexpr auto L0PerBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:107
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::NWaves
static constexpr auto NWaves
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:115
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::BL0
static constexpr auto BL0
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:110
__host__ static __device__ constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_ &)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:368
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::I0
static constexpr auto I0
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:94
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:488
__host__ static __device__ constexpr auto MakeB0BlockSliceCopyStep()
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:288
__host__ static __device__ constexpr auto MakeABlockDescriptor()
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:128
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::AL0
static constexpr auto AL0
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:108
__host__ static __device__ constexpr auto MakeB1BlockSliceCopyStep()
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:308
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::WmmaL
static constexpr auto WmmaL
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:117
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::I1
static constexpr auto I1
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:95
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::I6
static constexpr auto I6
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:100
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:679
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:653
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::LWaves
static constexpr auto LWaves
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:114
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:473
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::GridwiseGemmPipe
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage, LoopSched, AEnableLds, B0EnableLds >())> GridwiseGemmPipe
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:121
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::AL1
static constexpr auto AL1
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:109
__host__ static __device__ constexpr auto MakeAWaveDescriptor(const ABlockDesc_ &)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:328
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:645
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::I3
static constexpr auto I3
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:97
__host__ static __device__ constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_ &)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:429
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::I7
static constexpr auto I7
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:101
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::BL1
static constexpr auto BL1
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:111
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:511
__host__ static __device__ constexpr T Lowest()
Definition numeric_limits.hpp:312
__host__ static __device__ constexpr T Infinity()
Definition numeric_limits.hpp:317
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:16
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Threadwise data transfer.
Definition threadwise_tensor_slice_transfer.hpp:1720
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/sequence.hpp:256
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition utility/sequence.hpp:289
#define CK_ENV(name)
Definition utility/env.hpp:129