blockwise_gemm_pipeline_xdlops_v4.hpp Source File

blockwise_gemm_pipeline_xdlops_v4.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_v4.hpp Source File
blockwise_gemm_pipeline_xdlops_v4.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute-optimal pipeline with the highest resource request
11// GlobalPrefetchStages: 3
12// LocalPreFillStages: 2
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 2
15
// Primary template of the v4 blockwise GEMM pipeline. It is declared with an empty body;
// the working implementation is the partial specialization that follows, which drops the
// scheduler parameter from its own template list.
// NOTE(review): the last parameter is spelled KPacks here but KPack in the specialization —
// harmless (names of primary-template parameters are not visible to specializations).
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
37{
38};
39
// Partial specialization that implements the v4 pipeline. The MFMA wrapper (xdlops_gemm),
// the thread-level copies and the thread descriptors all come from Base; this class adds
// the global/local prefetch schedule (Run) and the instruction-scheduling hints
// (HotLoopScheduler).
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::I0;
122 using Base::I1;
123 using Base::KRepeat;
124 using Base::xdlops_gemm;
125 using typename Base::HotLoopInstList;
126
138
141
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
144
146
// Pipeline shape:
//   PrefetchStages  - global tiles read in Run()'s prologue before any MFMA (three RunRead
//                     calls); also the offset used in the hot-loop exit test.
//   PrefillStages   - local buffers written in the prologue (At(I0) and At(I1)).
//   GlobalBufferNum - not referenced by the members shown here; presumably read by callers
//                     (TODO confirm).
//   HotloopUnroll   - LoopFunc calls per do-while iteration (one per local buffer); also the
//                     modulus used by BlockLoopTailNum.
147 static constexpr index_t PrefetchStages = 3;
148 static constexpr index_t PrefillStages = 2;
149 static constexpr index_t GlobalBufferNum = 1;
150 static constexpr index_t HotloopUnroll = 2;
151
152 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
153 {
154 return num_loop > PrefetchStages;
155 }
156
157 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
158 {
159 if(num_loop % HotloopUnroll == 1)
160 {
161 return TailNumber::Odd;
162 }
163 else
164 {
165 return TailNumber::Even;
166 }
167 }
168
169 __device__ static constexpr void HotLoopScheduler()
170 {
171 // TODO: Take data type into consideration as pipe ver 3
172 // A-B splited schedule
173 constexpr auto num_ds_read_inst_a =
174 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
177 constexpr auto num_ds_read_inst_b =
178 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
181
182 constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
183 constexpr auto num_dswrite_per_issue_a =
184 (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a;
185 constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
186
187 constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
188 constexpr auto num_dswrite_per_issue_b =
189 (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b;
190 constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
191
192 constexpr auto num_mfma_per_issue =
193 HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);
194
195 static_for<0, num_issue_a, 1>{}([&](auto i) {
196 ignore = i;
197 static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
198 ignore = idsread;
199 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
200 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
201 });
202
203 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
204 ignore = idswrite;
205 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
206 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
207 });
208
209 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
210 __builtin_amdgcn_sched_group_barrier(0x008,
211 num_mfma_per_issue - num_dsread_per_issue_a -
212 num_dswrite_per_issue_a,
213 0); // MFMA
214 });
215
216 static_for<0, num_issue_b, 1>{}([&](auto i) {
217 ignore = i;
218 static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
219 ignore = idsread;
220 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
221 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
222 });
223
224 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
225 ignore = idswrite;
226 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
227 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
228 });
229
230 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
231 __builtin_amdgcn_sched_group_barrier(0x008,
232 num_mfma_per_issue - num_dsread_per_issue_a -
233 num_dswrite_per_issue_b,
234 0); // MFMA
235 });
236 __builtin_amdgcn_sched_barrier(0);
237 }
238
// Runs the whole K loop for this block and accumulates the result into c_thread_buf.
//   Prologue: three global reads (RunRead). The first two are written to local buffers
//             At(I0) and At(I1); the third stays pending in the blockwise copies until the
//             first RunWrite of the hot loop or tail. Local buffer 0 is copied into register
//             buffer 0, and c_thread_buf is cleared.
//   Hot loop: compiled only if HasMainLoop; each do-while iteration calls LoopFunc twice,
//             once per local/register buffer, and advances i by HotloopUnroll.
//   Tail:     TailNum selects how the tiles still in flight are drained.
// Parameters:
//   a/b_grid_desc, a/b_grid_buf   - global source of the A/B tiles.
//   a/b_block_desc, a/b_block_buf - double-buffered local destination (At(I0) / At(I1)).
//   a/b_blockwise_copy            - copy objects (RunRead / RunWrite / MoveSrcSliceWindow).
//   a/b_block_copy_step           - source-window step applied after every RunRead.
//   c_thread_buf                  - per-thread accumulator; cleared here, then accumulated into.
//   num_loop                      - number of K tiles (presumably K / KPerBlock; TODO confirm).
// NOTE(review): HasMainLoop / TailNum are presumably BlockHasHotloop(num_loop) /
// BlockLoopTailNum(num_loop) — confirm at the call site.
239 template <bool HasMainLoop,
240 TailNumber TailNum,
241 typename AGridDesc,
242 typename ABlockDesc,
243 typename ABlockTransfer,
244 typename AGridBuffer,
245 typename ABlockBuffer,
246 typename ABlockTransferStep,
247 typename BGridDesc,
248 typename BBlockDesc,
249 typename BBlockTransfer,
250 typename BGridBuffer,
251 typename BBlockBuffer,
252 typename BBlockTransferStep,
253 typename CThreadBuffer>
254 __device__ void Run(const AGridDesc& a_grid_desc,
255 const ABlockDesc& a_block_desc,
256 ABlockTransfer& a_blockwise_copy,
257 const AGridBuffer& a_grid_buf,
258 ABlockBuffer& a_block_buf,
259 const ABlockTransferStep& a_block_copy_step,
260 const BGridDesc& b_grid_desc,
261 const BBlockDesc& b_block_desc,
262 BBlockTransfer& b_blockwise_copy,
263 const BGridBuffer& b_grid_buf,
264 BBlockBuffer& b_block_buf,
265 const BBlockTransferStep& b_block_copy_step,
266 CThreadBuffer& c_thread_buf,
267 index_t num_loop) const
268 {
270 a_thread_desc_.GetElementSpaceSize());
272 b_thread_desc_.GetElementSpaceSize());
273
// Two register buffers per operand: one is filled from local memory while the other
// feeds the MFMAs.
274 StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
275 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
276
277 // Global prefetch 1
278 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
279 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
280
281 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
282 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
283
284 // Local prefill 1
285 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
286 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
287
288 // Local prefetch 1
290 static_for<0, KRepeat, 1>{}([&](auto k) {
291 static_for<0, MRepeat, 1>{}([&](auto m0) {
294 a_block_buf.At(I0),
296 make_tuple(m0, I0, k, I0),
297 a_thread_bufs(I0));
298 });
299 static_for<0, NRepeat, 1>{}([&](auto n0) {
302 b_block_buf.At(I0),
304 make_tuple(n0, I0, k, I0),
305 b_thread_bufs(I0));
306 });
307 });
308
309 // Global prefetch 2
310 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
311 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
312
313 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
314 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
315
316 // Local prefill 2
317 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
318 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
319
320 // Global prefetch 3
321 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
322 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
323
324 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
325 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
326
327 // Initialize C
328 c_thread_buf.Clear();
329
330 // main body
331 if constexpr(HasMainLoop)
332 {
333 index_t i = 0;
334 // Each do-while iteration calls LoopFunc twice, ping-ponging between buffers 0 and 1, to implement the double local buffer strategy
335 do
336 {
// LoopFunc: one pipeline step.
//   lds_read_buf     - local buffer copied into registers for the next step,
//   lds_read_reg_buf - register buffer receiving that copy,
//   lds_write_buf    - local buffer receiving the tile fetched by the previous RunRead,
//   mfma_reg_buf     - register buffer consumed by this step's MFMAs.
// Each call also issues the next global RunRead and advances the source windows.
337 auto LoopFunc = [&](auto lds_read_buf,
338 auto lds_read_reg_buf,
339 auto lds_write_buf,
340 auto mfma_reg_buf) {
342
343 static_for<0, KRepeat, 1>{}([&](auto k) {
344 static_for<0, MRepeat, 1>{}([&](auto m0) {
347 a_block_buf.At(lds_read_buf),
349 make_tuple(m0, I0, k, I0),
350 a_thread_bufs(lds_read_reg_buf));
351 });
352 static_for<0, NRepeat, 1>{}([&](auto n0) {
355 b_block_buf.At(lds_read_buf),
357 make_tuple(n0, I0, k, I0),
358 b_thread_bufs(lds_read_reg_buf));
359 });
360 });
361
362 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
363 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
364
365 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
366 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
367
368 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
369 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
370
// Gather KPack elements per (m0, n0, k0) from the MFMA register buffer and issue one
// xdlops_gemm.Run accumulating into the (m0, n0) slot of c_thread_buf.
371 static_for<0, KRepeat, 1>{}([&](auto k0) {
372 static_for<0, MRepeat, 1>{}([&](auto m0) {
373 static_for<0, NRepeat, 1>{}([&](auto n0) {
376
377 static_for<0, KPack, 1>{}([&](auto ik) {
378 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
379 a_thread_bufs[mfma_reg_buf]
380 [Number<a_thread_desc_.CalculateOffset(
381 make_tuple(m0, I0, k0, ik))>{}];
382 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
383 b_thread_bufs[mfma_reg_buf]
384 [Number<b_thread_desc_.CalculateOffset(
385 make_tuple(n0, I0, k0, ik))>{}];
386 });
387
388 using mfma_input_type =
390 xdlops_gemm.K1PerXdlops>::type;
391
392 constexpr index_t c_offset =
393 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
394
395 xdlops_gemm.Run(
396 a_thread_vec.template AsType<mfma_input_type>(),
397 b_thread_vec.template AsType<mfma_input_type>(),
398 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
399 });
400 });
401 });
402
404 };
405
406 LoopFunc(I1, I1, I0, I0);
407 LoopFunc(I0, I0, I1, I1);
408
409 i += HotloopUnroll;
410 } while(i < (num_loop - PrefetchStages));
411 }
412
// Tail helpers:
//   ReadWriteCompFunc - like LoopFunc but without a new global read / window move,
//   ReadCompFunc      - local->register copy plus compute (no local write),
//   CompFunc          - compute only, from an already-filled register buffer.
413 auto ReadWriteCompFunc = [&](auto lds_read_buf,
414 auto lds_read_reg_buf,
415 auto lds_write_buf,
416 auto mfma_reg_buf) {
418
419 static_for<0, KRepeat, 1>{}([&](auto k) {
420 static_for<0, MRepeat, 1>{}([&](auto m0) {
423 a_block_buf.At(lds_read_buf),
425 make_tuple(m0, I0, k, I0),
426 a_thread_bufs(lds_read_reg_buf));
427 });
428 static_for<0, NRepeat, 1>{}([&](auto n0) {
431 b_block_buf.At(lds_read_buf),
433 make_tuple(n0, I0, k, I0),
434 b_thread_bufs(lds_read_reg_buf));
435 });
436 });
437
438 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
439 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
440
441 static_for<0, KRepeat, 1>{}([&](auto k0) {
442 static_for<0, MRepeat, 1>{}([&](auto m0) {
443 static_for<0, NRepeat, 1>{}([&](auto n0) {
446
447 static_for<0, KPack, 1>{}([&](auto ik) {
448 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
449 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
450 make_tuple(m0, I0, k0, ik))>{}];
451 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
452 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
453 make_tuple(n0, I0, k0, ik))>{}];
454 });
455
456 using mfma_input_type =
457 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
458
459 constexpr index_t c_offset =
460 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
461
462 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
463 b_thread_vec.template AsType<mfma_input_type>(),
464 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
465 });
466 });
467 });
468
470 };
471
472 auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
474
475 static_for<0, KRepeat, 1>{}([&](auto k) {
476 static_for<0, MRepeat, 1>{}([&](auto m0) {
479 a_block_buf.At(lds_read_buf),
481 make_tuple(m0, I0, k, I0),
482 a_thread_bufs(lds_read_reg_buf));
483 });
484 static_for<0, NRepeat, 1>{}([&](auto n0) {
487 b_block_buf.At(lds_read_buf),
489 make_tuple(n0, I0, k, I0),
490 b_thread_bufs(lds_read_reg_buf));
491 });
492 });
493
494 static_for<0, KRepeat, 1>{}([&](auto k0) {
495 static_for<0, MRepeat, 1>{}([&](auto m0) {
496 static_for<0, NRepeat, 1>{}([&](auto n0) {
499
500 static_for<0, KPack, 1>{}([&](auto ik) {
501 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
502 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
503 make_tuple(m0, I0, k0, ik))>{}];
504 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
505 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
506 make_tuple(n0, I0, k0, ik))>{}];
507 });
508
509 using mfma_input_type =
510 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
511
512 constexpr index_t c_offset =
513 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
514
515 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
516 b_thread_vec.template AsType<mfma_input_type>(),
517 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
518 });
519 });
520 });
521
523 };
524
525 auto CompFunc = [&](auto mfma_reg_buf) {
526 static_for<0, KRepeat, 1>{}([&](auto k0) {
527 static_for<0, MRepeat, 1>{}([&](auto m0) {
528 static_for<0, NRepeat, 1>{}([&](auto n0) {
531
532 static_for<0, KPack, 1>{}([&](auto ik) {
533 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
534 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
535 make_tuple(m0, I0, k0, ik))>{}];
536 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
537 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
538 make_tuple(n0, I0, k0, ik))>{}];
539 });
540
541 using mfma_input_type =
542 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
543
544 constexpr index_t c_offset =
545 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
546
547 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
548 b_thread_vec.template AsType<mfma_input_type>(),
549 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
550 });
551 });
552 });
553 };
554 // tail
// Odd:  write the last pending tile to buffer 0 while computing, then read buffer 0 and
//       compute, then compute the final register buffer (three compute steps).
// Even: read buffer 1 and compute, then compute the final register buffer (two steps).
// NOTE(review): for even num_loop the prologue + hot loop issue num_loop + 1 RunRead calls;
// the last fetched tile is never written or consumed. For num_loop == 1 (Odd, no hot loop)
// the Odd tail still performs three compute steps. Presumably out-of-range global reads are
// safe and callers pass num_loop >= PrefetchStages for Odd — confirm.
555 if constexpr(TailNum == TailNumber::Odd)
556 {
557 ReadWriteCompFunc(I1, I1, I0, I0);
558 ReadCompFunc(I0, I0, I1);
559 CompFunc(I0);
560 }
561 else if constexpr(TailNum == TailNumber::Even)
562 {
563 ReadCompFunc(I1, I1, I0);
564 CompFunc(I1);
565 }
566 }
567
568 protected:
// Base-class thread-level copy objects and thread descriptors used by Run().
569 using Base::a_thread_copy_;
570 using Base::a_thread_desc_;
571 using Base::b_thread_copy_;
572 using Base::b_thread_desc_;
573 using Base::c_thread_desc_;
574};
575
576// Compute-optimal pipeline with the highest resource request
577// Implementation with direct load
578// GlobalPrefetchStages: 2
579// LocalPreFillStages: 2
580// LocalPreFetchStages: 1
581// LocalSharedMemoryBuffer: 2
582
// Primary template of the direct-load variant of the v4 pipeline; the working
// implementation is the partial specialization that follows.
583template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
584 index_t BlockSize,
585 typename ADataType,
586 typename BDataType,
587 typename ComputeDataType,
588 typename AccDataType,
589 typename ATileDesc,
590 typename BTileDesc,
591 typename AMmaTileDesc,
592 typename BMmaTileDesc,
593 index_t ABlockTransferSrcScalarPerVector,
594 index_t BBlockTransferSrcScalarPerVector,
595 index_t MPerBlock,
596 index_t NPerBlock,
597 index_t KPerBlock,
598 index_t MPerXDL,
599 index_t NPerXDL,
600 index_t MRepeat,
601 index_t NRepeat,
602 index_t KPacks>
606
// Partial specialization implementing the direct-load variant. Each A/B tile is moved with
// a single blockwise_copy.Run(grid desc/buf -> block desc/buf) call instead of the separate
// RunRead / RunWrite steps of the other v4 pipeline. MFMA and thread-copy machinery come
// from Base.
607template <index_t BlockSize,
608 typename ADataType,
609 typename BDataType,
610 typename ComputeDataType,
611 typename AccDataType,
612 typename ATileDesc,
613 typename BTileDesc,
614 typename AMmaTileDesc,
615 typename BMmaTileDesc,
616 index_t ABlockTransferSrcScalarPerVector,
617 index_t BBlockTransferSrcScalarPerVector,
618 index_t MPerBlock,
619 index_t NPerBlock,
620 index_t KPerBlock,
621 index_t MPerXDL,
622 index_t NPerXDL,
623 index_t MRepeat,
624 index_t NRepeat,
625 index_t KPack
626 // ,bool TransposeC //disable transposec right now...
627 >
629 BlockSize,
630 ADataType,
631 BDataType,
632 ComputeDataType,
633 AccDataType,
634 ATileDesc,
635 BTileDesc,
636 AMmaTileDesc,
637 BMmaTileDesc,
638 ABlockTransferSrcScalarPerVector,
639 BBlockTransferSrcScalarPerVector,
640 MPerBlock,
641 NPerBlock,
642 KPerBlock,
643 MPerXDL,
644 NPerXDL,
645 MRepeat,
646 NRepeat,
647 KPack>
649 ADataType,
650 BDataType,
651 ComputeDataType,
652 AccDataType,
653 ATileDesc,
654 BTileDesc,
655 AMmaTileDesc,
656 BMmaTileDesc,
657 ABlockTransferSrcScalarPerVector,
658 BBlockTransferSrcScalarPerVector,
659 MPerBlock,
660 NPerBlock,
661 KPerBlock,
662 MPerXDL,
663 NPerXDL,
664 MRepeat,
665 NRepeat,
666 KPack>
667
668{
670 ADataType,
671 BDataType,
672 ComputeDataType,
673 AccDataType,
674 ATileDesc,
675 BTileDesc,
676 AMmaTileDesc,
677 BMmaTileDesc,
678 ABlockTransferSrcScalarPerVector,
679 BBlockTransferSrcScalarPerVector,
680 MPerBlock,
681 NPerBlock,
682 KPerBlock,
683 MPerXDL,
684 NPerXDL,
685 MRepeat,
686 NRepeat,
687 KPack>;
688 using Base::I0;
689 using Base::I1;
690 using Base::KRepeat;
691 using Base::xdlops_gemm;
692 using typename Base::HotLoopInstList;
693
705
708
709 using Base::AMmaKStride;
710 using Base::BMmaKStride;
711
713
// Pipeline shape:
//   PrefetchStages  - tiles copied global -> block buffer in Run()'s prologue (two copy.Run
//                     calls); also the offset used in the hot-loop exit test.
//   PrefillStages   - block buffers filled in the prologue (At(I0) and At(I1)).
//   GlobalBufferNum - not referenced by the members shown here; presumably read by callers
//                     (TODO confirm).
//   HotloopUnroll   - LoopFunc calls per do-while iteration; also BlockLoopTailNum's modulus.
714 static constexpr index_t PrefetchStages = 2;
715 static constexpr index_t PrefillStages = 2;
716 static constexpr index_t GlobalBufferNum = 1;
717 static constexpr index_t HotloopUnroll = 2;
718
719 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
720 {
721 return num_loop > PrefetchStages;
722 }
723
724 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
725 {
726 if(num_loop % HotloopUnroll == 1)
727 {
728 return TailNumber::Odd;
729 }
730 else
731 {
732 return TailNumber::Even;
733 }
734 }
735
736 __device__ static constexpr void HotLoopScheduler()
737 {
738 // TODO: Take data type into consideration as pipe ver 3
739 // A-B splited schedule
740 constexpr auto num_ds_read_inst_a =
741 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
744 constexpr auto num_ds_read_inst_b =
745 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
748
749 constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
750 constexpr auto num_dswrite_per_issue_a = 0;
751 constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
752
753 constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
754 constexpr auto num_dswrite_per_issue_b = 0;
755 constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
756
757 constexpr auto num_mfma_per_issue =
758 HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);
759
760 static_for<0, num_issue_a, 1>{}([&](auto i) {
761 ignore = i;
762 static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
763 ignore = idsread;
764 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
765 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
766 });
767
768 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
769 ignore = idswrite;
770 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
771 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
772 });
773
774 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
775 __builtin_amdgcn_sched_group_barrier(0x008,
776 num_mfma_per_issue - num_dsread_per_issue_a -
777 num_dswrite_per_issue_a,
778 0); // MFMA
779 });
780
781 static_for<0, num_issue_b, 1>{}([&](auto i) {
782 ignore = i;
783 static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
784 ignore = idsread;
785 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
786 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
787 });
788
789 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
790 ignore = idswrite;
791 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
792 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
793 });
794
795 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
796 __builtin_amdgcn_sched_group_barrier(0x008,
797 num_mfma_per_issue - num_dsread_per_issue_a -
798 num_dswrite_per_issue_b,
799 0); // MFMA
800 });
801 __builtin_amdgcn_sched_barrier(0);
802 }
803
// Runs the whole K loop for this block (direct-load variant) and accumulates into c_thread_buf.
//   Prologue: two blockwise_copy.Run calls copy tiles from global memory straight into block
//             buffers At(I0) and At(I1), with a source-window move after each. Block buffer 0
//             is copied into register buffer 0, and c_thread_buf is cleared.
//   Hot loop: compiled only if HasMainLoop; each do-while iteration calls LoopFunc twice,
//             once per block/register buffer, and advances i by HotloopUnroll.
//   Tail:     TailNum selects how the tiles still in flight are drained.
// Parameters:
//   a/b_grid_desc, a/b_grid_buf   - global source of the A/B tiles.
//   a/b_block_desc, a/b_block_buf - double-buffered block destination (At(I0) / At(I1)).
//   a/b_blockwise_copy            - copy objects (Run / MoveSrcSliceWindow).
//   a/b_block_copy_step           - source-window step applied after every copy in the
//                                   prologue and hot loop.
//   c_thread_buf                  - per-thread accumulator; cleared here, then accumulated into.
//   num_loop                      - number of K tiles (presumably K / KPerBlock; TODO confirm).
// NOTE(review): HasMainLoop / TailNum are presumably BlockHasHotloop(num_loop) /
// BlockLoopTailNum(num_loop) — confirm at the call site.
804 template <bool HasMainLoop,
805 TailNumber TailNum,
806 typename AGridDesc,
807 typename ABlockDesc,
808 typename ABlockTransfer,
809 typename AGridBuffer,
810 typename ABlockBuffer,
811 typename ABlockTransferStep,
812 typename BGridDesc,
813 typename BBlockDesc,
814 typename BBlockTransfer,
815 typename BGridBuffer,
816 typename BBlockBuffer,
817 typename BBlockTransferStep,
818 typename CThreadBuffer>
819 __device__ void Run(const AGridDesc& a_grid_desc,
820 const ABlockDesc& a_block_desc,
821 ABlockTransfer& a_blockwise_copy,
822 const AGridBuffer& a_grid_buf,
823 ABlockBuffer& a_block_buf,
824 const ABlockTransferStep& a_block_copy_step,
825 const BGridDesc& b_grid_desc,
826 const BBlockDesc& b_block_desc,
827 BBlockTransfer& b_blockwise_copy,
828 const BGridBuffer& b_grid_buf,
829 BBlockBuffer& b_block_buf,
830 const BBlockTransferStep& b_block_copy_step,
831 CThreadBuffer& c_thread_buf,
832 index_t num_loop) const
833 {
835 a_thread_desc_.GetElementSpaceSize());
837 b_thread_desc_.GetElementSpaceSize());
838
// Two register buffers per operand: one is filled from the block buffer while the other
// feeds the MFMAs.
839 StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
840 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
841
842 // Global prefetch 1
843 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(I0));
844 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(I0));
845
846 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
847 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
848
850
851 // Local prefetch 1
852 static_for<0, KRepeat, 1>{}([&](auto k) {
853 static_for<0, MRepeat, 1>{}([&](auto m0) {
856 a_block_buf.At(I0),
858 make_tuple(m0, I0, k, I0),
859 a_thread_bufs(I0));
860 });
861 static_for<0, NRepeat, 1>{}([&](auto n0) {
864 b_block_buf.At(I0),
866 make_tuple(n0, I0, k, I0),
867 b_thread_bufs(I0));
868 });
869 });
870
871 // Global prefetch 2
872 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(I1));
873 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(I1));
874
875 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
876 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
877
878 // Initialize C
879 c_thread_buf.Clear();
880
881 // main body
882 if constexpr(HasMainLoop)
883 {
884 index_t i = 0;
885 // Each do-while iteration calls LoopFunc twice, ping-ponging between buffers 0 and 1, to implement the double local buffer strategy
886 do
887 {
// LoopFunc: one pipeline step.
//   lds_read_buf     - block buffer copied into registers for the next step,
//   lds_read_reg_buf - register buffer receiving that copy,
//   lds_write_buf    - block buffer receiving the next global tile (blockwise_copy.Run),
//   mfma_reg_buf     - register buffer consumed by this step's MFMAs.
888 auto LoopFunc = [&](auto lds_read_buf,
889 auto lds_read_reg_buf,
890 auto lds_write_buf,
891 auto mfma_reg_buf) {
893
894 static_for<0, KRepeat, 1>{}([&](auto k) {
895 static_for<0, MRepeat, 1>{}([&](auto m0) {
898 a_block_buf.At(lds_read_buf),
900 make_tuple(m0, I0, k, I0),
901 a_thread_bufs(lds_read_reg_buf));
902 });
903 static_for<0, NRepeat, 1>{}([&](auto n0) {
906 b_block_buf.At(lds_read_buf),
908 make_tuple(n0, I0, k, I0),
909 b_thread_bufs(lds_read_reg_buf));
910 });
911 });
912
913 a_blockwise_copy.Run(
914 a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
915 b_blockwise_copy.Run(
916 b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));
917
918 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
919 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
920
// Gather KPack elements per (m0, n0, k0) from the MFMA register buffer and issue one
// xdlops_gemm.Run accumulating into the (m0, n0) slot of c_thread_buf.
921 static_for<0, KRepeat, 1>{}([&](auto k0) {
922 static_for<0, MRepeat, 1>{}([&](auto m0) {
923 static_for<0, NRepeat, 1>{}([&](auto n0) {
926
927 static_for<0, KPack, 1>{}([&](auto ik) {
928 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
929 a_thread_bufs[mfma_reg_buf]
930 [Number<a_thread_desc_.CalculateOffset(
931 make_tuple(m0, I0, k0, ik))>{}];
932 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
933 b_thread_bufs[mfma_reg_buf]
934 [Number<b_thread_desc_.CalculateOffset(
935 make_tuple(n0, I0, k0, ik))>{}];
936 });
937
938 using mfma_input_type =
940 xdlops_gemm.K1PerXdlops>::type;
941
942 constexpr index_t c_offset =
943 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
944
945 xdlops_gemm.Run(
946 a_thread_vec.template AsType<mfma_input_type>(),
947 b_thread_vec.template AsType<mfma_input_type>(),
948 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
949 });
950 });
951 });
952
954 };
955
956 LoopFunc(I1, I1, I0, I0);
957 LoopFunc(I0, I0, I1, I1);
958
959 i += HotloopUnroll;
960 } while(i < (num_loop - PrefetchStages));
961 }
962
// Tail helpers:
//   ReadWriteCompFunc - like LoopFunc but without the source-window move,
//   ReadCompFunc      - block->register copy plus compute (no global copy),
//   CompFunc          - compute only, from an already-filled register buffer.
963 auto ReadWriteCompFunc = [&](auto lds_read_buf,
964 auto lds_read_reg_buf,
965 auto lds_write_buf,
966 auto mfma_reg_buf) {
968
969 static_for<0, KRepeat, 1>{}([&](auto k) {
970 static_for<0, MRepeat, 1>{}([&](auto m0) {
973 a_block_buf.At(lds_read_buf),
975 make_tuple(m0, I0, k, I0),
976 a_thread_bufs(lds_read_reg_buf));
977 });
978 static_for<0, NRepeat, 1>{}([&](auto n0) {
981 b_block_buf.At(lds_read_buf),
983 make_tuple(n0, I0, k, I0),
984 b_thread_bufs(lds_read_reg_buf));
985 });
986 });
987
988 a_blockwise_copy.Run(
989 a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
990 b_blockwise_copy.Run(
991 b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));
992
993 static_for<0, KRepeat, 1>{}([&](auto k0) {
994 static_for<0, MRepeat, 1>{}([&](auto m0) {
995 static_for<0, NRepeat, 1>{}([&](auto n0) {
998
999 static_for<0, KPack, 1>{}([&](auto ik) {
1000 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1001 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
1002 make_tuple(m0, I0, k0, ik))>{}];
1003 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1004 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
1005 make_tuple(n0, I0, k0, ik))>{}];
1006 });
1007
1008 using mfma_input_type =
1009 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
1010
1011 constexpr index_t c_offset =
1012 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1013
1014 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
1015 b_thread_vec.template AsType<mfma_input_type>(),
1016 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1017 });
1018 });
1019 });
1020
1022 };
1023
1024 auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
1026
1027 static_for<0, KRepeat, 1>{}([&](auto k) {
1028 static_for<0, MRepeat, 1>{}([&](auto m0) {
1031 a_block_buf.At(lds_read_buf),
1033 make_tuple(m0, I0, k, I0),
1034 a_thread_bufs(lds_read_reg_buf));
1035 });
1036 static_for<0, NRepeat, 1>{}([&](auto n0) {
1039 b_block_buf.At(lds_read_buf),
1041 make_tuple(n0, I0, k, I0),
1042 b_thread_bufs(lds_read_reg_buf));
1043 });
1044 });
1045
1046 static_for<0, KRepeat, 1>{}([&](auto k0) {
1047 static_for<0, MRepeat, 1>{}([&](auto m0) {
1048 static_for<0, NRepeat, 1>{}([&](auto n0) {
1051
1052 static_for<0, KPack, 1>{}([&](auto ik) {
1053 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1054 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
1055 make_tuple(m0, I0, k0, ik))>{}];
1056 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1057 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
1058 make_tuple(n0, I0, k0, ik))>{}];
1059 });
1060
1061 using mfma_input_type =
1062 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
1063
1064 constexpr index_t c_offset =
1065 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1066
1067 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
1068 b_thread_vec.template AsType<mfma_input_type>(),
1069 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1070 });
1071 });
1072 });
1073
1075 };
1076
1077 auto CompFunc = [&](auto mfma_reg_buf) {
1078 static_for<0, KRepeat, 1>{}([&](auto k0) {
1079 static_for<0, MRepeat, 1>{}([&](auto m0) {
1080 static_for<0, NRepeat, 1>{}([&](auto n0) {
1083
1084 static_for<0, KPack, 1>{}([&](auto ik) {
1085 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1086 a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
1087 make_tuple(m0, I0, k0, ik))>{}];
1088 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1089 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
1090 make_tuple(n0, I0, k0, ik))>{}];
1091 });
1092
1093 using mfma_input_type =
1094 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
1095
1096 constexpr index_t c_offset =
1097 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1098
1099 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
1100 b_thread_vec.template AsType<mfma_input_type>(),
1101 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1102 });
1103 });
1104 });
1105 };
1106 // tail
// NOTE(review): PrefetchStages is 2 in this variant, but this tail sequence is the same as in
// the 3-stage pipeline above. Tracing the buffers: for odd num_loop, the prologue plus hot
// loop leave only tile num_loop-1 unprocessed (already in register buffer 0). Yet the Odd
// branch performs three compute steps, i.e. two tiles beyond num_loop (e.g. num_loop == 3
// computes 5 tiles; num_loop == 1 computes 3). This is only correct if out-of-range copies
// yield zeros, and it wastes work even then. CompFunc(I0) alone looks like the intended Odd
// tail — confirm before changing. The Even branch matches the 2-stage schedule.
1107 if constexpr(TailNum == TailNumber::Odd)
1108 {
1109 ReadWriteCompFunc(I1, I1, I0, I0);
1110 ReadCompFunc(I0, I0, I1);
1111 CompFunc(I0);
1112 }
1113 else if constexpr(TailNum == TailNumber::Even)
1114 {
1115 ReadCompFunc(I1, I1, I0);
1116 CompFunc(I1);
1117 }
1118 }
1119
1120 protected:
// Base-class thread-level copy objects and thread descriptors used by Run().
1121 using Base::a_thread_copy_;
1122 using Base::a_thread_desc_;
1123 using Base::b_thread_copy_;
1124 using Base::b_thread_desc_;
1125 using Base::c_thread_desc_;
1126};
1127
1128} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
__device__ void block_sync_lds_direct_load()
Definition synchronization.hpp:43
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v4.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v4.hpp:254
Definition blockwise_gemm_pipeline_xdlops.hpp:103
static __device__ constexpr auto HotLoopScheduler()
Definition blockwise_gemm_pipeline_xdlops.hpp:373
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v4.hpp:669
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v4.hpp:819
Definition blockwise_gemm_pipeline_xdlops_v4.hpp:604
Definition functional2.hpp:33
Definition dtype_vector.hpp:10